A Comparison of Different Optimizers
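This script builds four copies of the same small network, fits each one to the same noisy quadratic data with a different optimizer (plain SGD, SGD with momentum, RMSprop, Adam), and plots the four training-loss curves on one figure. As a reference for reading the code below, the update rules are, schematically (g is the gradient, lr the learning rate; PyTorch's implementations follow these shapes, with details such as Adam's bias-corrected moments, written with hats here):

$$\begin{aligned}
\text{SGD:} &\quad \theta \leftarrow \theta - \mathrm{lr}\cdot g\\
\text{Momentum:} &\quad v \leftarrow \mu v + g, \qquad \theta \leftarrow \theta - \mathrm{lr}\cdot v\\
\text{RMSprop:} &\quad s \leftarrow \alpha s + (1-\alpha)\,g^2, \qquad \theta \leftarrow \theta - \mathrm{lr}\cdot \frac{g}{\sqrt{s}+\epsilon}\\
\text{Adam:} &\quad m \leftarrow \beta_1 m + (1-\beta_1)\,g, \quad v \leftarrow \beta_2 v + (1-\beta_2)\,g^2, \qquad \theta \leftarrow \theta - \mathrm{lr}\cdot \frac{\hat m}{\sqrt{\hat v}+\epsilon}
\end{aligned}$$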
import torch
import torch.utils.data as Data
import torch.nn.functional as F
import matplotlib.pyplot as plt

# Hyperparameters
LR = 0.01
BATCH_SIZE = 32
EPOCH = 12

# Generate fake data
# torch.unsqueeze() turns the 1-D tensor into a 2-D one of shape (1000, 1);
# nn.Linear layers expect input of shape (batch, features)
x = torch.unsqueeze(torch.linspace(-1, 1, 1000), dim=1)  # x data (tensor), shape (1000, 1)
y = x.pow(2) + 0.1 * torch.randn(x.size())  # quadratic curve plus Gaussian noise

# Plot the raw data
# plt.scatter(x.numpy(), y.numpy())
# plt.show()

# TensorDataset now takes the tensors positionally
# (the old data_tensor=/target_tensor= keywords no longer exist)
torch_dataset = Data.TensorDataset(x, y)
# With 1000 samples and BATCH_SIZE = 32 the loader yields 32 batches per epoch (the last one has 8 samples)
loader = Data.DataLoader(dataset=torch_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)

class Net(torch.nn.Module):
    # initialization: one hidden layer with 20 units, one output layer
    def __init__(self):
        super(Net, self).__init__()
        self.hidden = torch.nn.Linear(1, 20)
        self.predict = torch.nn.Linear(20, 1)

    # forward pass
    def forward(self, x):
        x = F.relu(self.hidden(x))
        x = self.predict(x)
        return x

# Four identical networks, one per optimizer, so the comparison is fair
net_SGD = Net()
net_Momentum = Net()
net_RMSProp = Net()
net_Adam = Net()

nets = [net_SGD, net_Momentum, net_RMSProp, net_Adam]

opt_SGD = torch.optim.SGD(net_SGD.parameters(), lr=LR)
opt_Momentum = torch.optim.SGD(net_Momentum.parameters(), lr=LR, momentum=0.8)  # momentum=0.8 is the velocity coefficient
opt_RMSProp = torch.optim.RMSprop(net_RMSProp.parameters(), lr=LR, alpha=0.9)   # alpha is the smoothing constant of the squared-gradient average
opt_Adam = torch.optim.Adam(net_Adam.parameters(), lr=LR, betas=(0.9, 0.99))    # betas are the decay rates of the two moment estimates
optimizers = [opt_SGD, opt_Momentum, opt_RMSProp, opt_Adam]

loss_func = torch.nn.MSELoss()

loss_his = [[], [], [], []]  # record the loss history of each net

for epoch in range(EPOCH):
    print(epoch)
    # (no Variable wrapper needed: since PyTorch 0.4 tensors carry autograd directly)
    for step, (b_x, b_y) in enumerate(loader):
        for net, opt, l_his in zip(nets, optimizers, loss_his):
            output = net(b_x)              # get output for every net
            loss = loss_func(output, b_y)  # compute loss for every net
            opt.zero_grad()                # clear gradients for next train
            loss.backward()                # backpropagation, compute gradients
            opt.step()                     # apply gradients
            l_his.append(loss.item())      # loss recorder

labels = ['SGD', 'Momentum', 'RMSprop', 'Adam']
for i, l_his in enumerate(loss_his):
    plt.plot(l_his, label=labels[i])
plt.legend(loc='best')
plt.xlabel('Steps')
plt.ylabel('Loss')
plt.ylim((0, 0.2))
plt.show()
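
Two notes if you run this yourself. The data noise, the weight initializations, and the batch shuffling are all random, so the curves differ between runs; calling torch.manual_seed at the top of the script makes a run repeatable. The comparison is also easy to extend with more optimizers; a minimal sketch (Adagrad is my example here, not part of the original script), assuming the appends are placed before the training loop:

torch.manual_seed(1)  # at the top of the script: fixes the data noise, init weights, and shuffling

net_Adagrad = Net()  # a fifth net with the same architecture
nets.append(net_Adagrad)
optimizers.append(torch.optim.Adagrad(net_Adagrad.parameters(), lr=LR))
loss_his.append([])  # one more history list to record into

The label list before plotting then needs a matching fifth entry, e.g. labels = ['SGD', 'Momentum', 'RMSprop', 'Adam', 'Adagrad'].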