A quick numpy implementation of univariate linear regression (fitting a cubic function)

alanwhy · 2022-02-03
import numpy as np
import matplotlib.pyplot as plt

learning_rate=15    # tuned by hand over several runs
epochs=1000
# input_features=1
# input_size=1000
# output_features=1
# output_size=1000

w=np.ones((1000,))    # note: one w per sample (element-wise), not a single scalar
b=np.ones((1000,))
x=np.random.randn(1000,)
y=x**3                # cubic target; equivalent to [xi**3 for xi in x]
print(x.shape,y.shape) 
plt.scatter(x,y)
plt.show()
(1000,) (1000,)

(figure: scatter plot of the cubic data y = x³)

$$
loss=\frac{1}{n}(wx+b-y)^2
\\
\frac{\partial\, loss}{\partial w}=\frac{2x}{n}(wx+b-y)
\\
\frac{\partial\, loss}{\partial b}=\frac{2}{n}(wx+b-y)
$$
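As a sanity check on the gradients above, a central finite-difference comparison can confirm the analytic forms. This is a minimal sketch with scalar w, b and a small hypothetical sample (the seed and sizes are assumptions, not from the original post):

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.standard_normal(50)
y = x**3
w, b, eps = 0.5, 0.1, 1e-6

def mse(w, b):
    # loss = (1/n) * sum((w*x + b - y)^2)
    return np.mean((w*x + b - y)**2)

# analytic gradients from the derivation above
grad_w = np.mean(2*x*(w*x + b - y))
grad_b = np.mean(2*(w*x + b - y))

# central finite differences
num_w = (mse(w + eps, b) - mse(w - eps, b)) / (2*eps)
num_b = (mse(w, b + eps) - mse(w, b - eps)) / (2*eps)
print(abs(grad_w - num_w), abs(grad_b - num_b))  # both differences should be tiny
```

If the analytic and numerical gradients agree to within floating-point noise, the update rules used below are implementing the derivation correctly.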

def getloss(pred,label):
    """
    pred: prediction array of shape (n,)
    label: label array of shape (n,)
    """
    # mean squared error (MSE), matching the derivation above
    n=len(pred)
    loss=np.sum((pred-label)**2)/n
    return loss

def gradient_descent(init_weight,init_bias,x_train,y_train,epochs,lr):
    loss=0.
    pred=0.
    w=init_weight
    b=init_bias
    n=len(x_train)
    for epoch in range(epochs):
        if (epoch+1)%100==0:
            print("Epoch {}/{}:".format(epoch+1,epochs))
        # forward pass
        pred=w*x_train+b
        loss=getloss(pred,y_train)
        # gradients and parameter update (element-wise, since w and b are vectors)
        grad_w=(pred-y_train)*(2*x_train)/n
        grad_b=(pred-y_train)*(2/n)
        w=w-lr*grad_w
        b=b-lr*grad_b
        if (epoch+1)%100==0:
            print("Loss:{}".format(loss))
    return w,b

w,b=gradient_descent(w,b,x,y,epochs,learning_rate)
plot_x=np.linspace(-3,3,1000)
prediction=w*plot_x+b
print(plot_x.shape,prediction.shape)
plt.scatter(plot_x,prediction,c='r')
plt.scatter(x,y)
plt.show()
Epoch 100/1000:
Loss:4.1350661425936813e-16
Epoch 200/1000:
Loss:4.1350661425936813e-16
Epoch 300/1000:
Loss:4.1350661425936813e-16
Epoch 400/1000:
Loss:4.1350661425936813e-16
Epoch 500/1000:
Loss:4.1350661425936813e-16
Epoch 600/1000:
Loss:4.1350661425936813e-16
Epoch 700/1000:
Loss:4.1350661425936813e-16
Epoch 800/1000:
Loss:4.1350661425936813e-16
Epoch 900/1000:
Loss:4.1350661425936813e-16
Epoch 1000/1000:
Loss:4.1350661425936813e-16
(1000,) (1000,)

(figure: predictions in red plotted over the original cubic scatter)
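One thing worth noting: in the code above, w and b are length-1000 vectors, so every sample effectively gets its own parameters and can be fit exactly, which is why the loss collapses to ~1e-16. With a single shared w and b, as the derivation suggests, the same loop looks like the sketch below (the learning rate, epoch count, and seed are assumed values; for standard-normal x, the best linear fit to y = x³ has slope E[x⁴]/E[x²] ≈ 3):

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.standard_normal(1000)
y = x**3

w, b = 0.0, 0.0          # scalar parameters shared by all samples
lr, epochs = 0.1, 1000   # assumed values for this sketch

for _ in range(epochs):
    pred = w*x + b
    # batch-averaged gradients of the MSE loss
    w -= lr * np.mean(2*x*(pred - y))
    b -= lr * np.mean(2*(pred - y))

print(w, b)  # w converges near 3, b near 0
```

This version cannot drive the loss to zero, because a single straight line can only approximate the cubic, but it is the model the MSE derivation actually describes.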
