# Source video: https://www.bilibili.com/video/av22530538/?p=17
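# Lesson 17: learning rate
# learning_rate: the size of each parameter update
# w_{n+1} = w_n - learning_rate * gradient
# (updated parameter = current parameter - learning rate * gradient (derivative) of the loss function)
# e.g. loss function loss = (w+1)^2; gradient = d(loss)/d(w) = 2w + 2
# With w initialized to 5 and a learning rate of 0.2:
# step 1: w = 5     ->  5    - 0.2*(2*5+2)    = 2.6
# step 2: w = 2.6   ->  2.6  - 0.2*(2*2.6+2)  = 1.16
# step 3: w = 1.16  ->  1.16 - 0.2*(2*1.16+2) = 0.296
# ...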
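# What learning rate is appropriate?
# Again take loss = (w+1)^2 as the example:
# first set learning_rate = 1,
# then set learning_rate = 0.001.
# A learning rate that is too large oscillates and does not converge; one that is too small converges slowly.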
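# Exponentially decaying learning rate
# learning_rate = LEARNING_RATE_BASE * LEARNING_RATE_DECAY ^ (global_step / LEARNING_RATE_STEP)
#   LEARNING_RATE_BASE:  base (initial) learning rate
#   LEARNING_RATE_DECAY: learning-rate decay rate, in (0, 1)
#   LEARNING_RATE_STEP:  how many rounds between learning-rate updates = total samples / BATCH_SIZE
#   (with staircase=True the exponent global_step / LEARNING_RATE_STEP is truncated to an integer)
#
# global_step = tf.Variable(0, trainable=False)
# learning_rate = tf.train.exponential_decay(
#     LEARNING_RATE_BASE, global_step, LEARNING_RATE_STEP, LEARNING_RATE_DECAY, staircase=True)
# Let the loss function be loss = (w+1)^2 and let w start from a constant (the code below initializes it to 5).
# Back-propagation searches for the optimal w, i.e. the value of w that minimizes loss.
# An exponentially decaying learning rate gives a fast descent in the early iterations,
# so good convergence can be reached in fewer training rounds.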
import tensorflow as tf
# Demo: minimize loss = (w + 1)^2 with gradient descent whose learning rate
# decays exponentially as training progresses (TensorFlow 1.x graph API).

# Initial learning rate.
LEARNING_RATE_BASE = 0.1
# Learning-rate decay rate (the name says DELAY, but the value is passed to
# exponential_decay as its decay rate).
LEARNING_RATE_DELAY = 0.99
# Number of batches fed between two learning-rate updates; usually set to
# total sample count / BATCH_SIZE.
LEARNING_RATE_STEP = 1
# Counter of how many batches have been run; starts at 0 and is excluded
# from training.
global_step = tf.Variable(0, trainable=False)
# Exponentially decaying learning rate.
learning_rate = tf.train.exponential_decay(
    LEARNING_RATE_BASE,
    global_step,
    LEARNING_RATE_STEP,
    LEARNING_RATE_DELAY,
    staircase=True,
)
# Parameter to optimize, initialized to 5.
w = tf.Variable(tf.constant(5, dtype=tf.float32))
# Loss function.
loss = tf.square(w + 1)
# Back-propagation method; global_step is handed to minimize() so the
# optimizer advances it on every training step.
train_step = tf.train.GradientDescentOptimizer(learning_rate).minimize(
    loss, global_step=global_step
)
# Create a session and train for 40 rounds.
with tf.Session() as sess:
    init_op = tf.global_variables_initializer()
    sess.run(init_op)
    for i in range(40):
        sess.run(train_step)
        # Fetch all reported values in one call; none of these fetches
        # runs train_step, so they all reflect the same post-update state.
        learning_rate_val, global_step_val, w_val, loss_val = sess.run(
            [learning_rate, global_step, w, loss]
        )
        # Apply the format with %; passing the tuple as a second print
        # argument would emit the raw template and tuple instead.
        print(
            "After %s steps:global_step is %f, w is %f,learning rate is %f,loss is %f"
            % (i, global_step_val, w_val, learning_rate_val, loss_val)
        )
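# Results (output recorded from a Python 2 run in which print was given the
# format string and the value tuple separately, so each line is the raw
# (format, values) tuple rather than a formatted message):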
('After %s steps:global_step is %f, w is %f,learning rate is %f,loss is %f', (0, 1, 3.8, 0.099, 23.04))
('After %s steps:global_step is %f, w is %f,learning rate is %f,loss is %f', (1, 2, 2.8495998, 0.09801, 14.819419))
('After %s steps:global_step is %f, w is %f,learning rate is %f,loss is %f', (2, 3, 2.0950012, 0.0970299, 9.579033))
('After %s steps:global_step is %f, w is %f,learning rate is %f,loss is %f', (3, 4, 1.494386, 0.096059605, 6.2219615))
('After %s steps:global_step is %f, w is %f,learning rate is %f,loss is %f', (4, 5, 1.0151665, 0.09509901, 4.060896))
('After %s steps:global_step is %f, w is %f,learning rate is %f,loss is %f', (5, 6, 0.6318858, 0.094148025, 2.6630511))
('After %s steps:global_step is %f, w is %f,learning rate is %f,loss is %f', (6, 7, 0.32460818, 0.09320655, 1.7545869))
('After %s steps:global_step is %f, w is %f,learning rate is %f,loss is %f', (7, 8, 0.077683866, 0.09227448, 1.1614027))
('After %s steps:global_step is %f, w is %f,learning rate is %f,loss is %f', (8, 9, -0.121201575, 0.09135174, 0.77228665))
('After %s steps:global_step is %f, w is %f,learning rate is %f,loss is %f', (9, 10, -0.2817611, 0.09043822, 0.5158671))
('After %s steps:global_step is %f, w is %f,learning rate is %f,loss is %f', (10, 11, -0.4116736, 0.089533836, 0.34612796))
('After %s steps:global_step is %f, w is %f,learning rate is %f,loss is %f', (11, 12, -0.51702386, 0.0886385, 0.23326595))
('After %s steps:global_step is %f, w is %f,learning rate is %f,loss is %f', (12, 13, -0.60264444, 0.08775212, 0.15789144))
('After %s steps:global_step is %f, w is %f,learning rate is %f,loss is %f', (13, 14, -0.672382, 0.0868746, 0.107333556))
('After %s steps:global_step is %f, w is %f,learning rate is %f,loss is %f', (14, 15, -0.7293054, 0.08600585, 0.07327557))
('After %s steps:global_step is %f, w is %f,learning rate is %f,loss is %f', (15, 16, -0.77586806, 0.085145794, 0.050235126))
('After %s steps:global_step is %f, w is %f,learning rate is %f,loss is %f', (16, 17, -0.81403583, 0.084294334, 0.03458267))
('After %s steps:global_step is %f, w is %f,learning rate is %f,loss is %f', (17, 18, -0.8453873, 0.08345139, 0.023905093))
('After %s steps:global_step is %f, w is %f,learning rate is %f,loss is %f', (18, 19, -0.8711926, 0.08261688, 0.016591353))
('After %s steps:global_step is %f, w is %f,learning rate is %f,loss is %f', (19, 20, -0.8924759, 0.08179071, 0.011561431))
('After %s steps:global_step is %f, w is %f,learning rate is %f,loss is %f', (20, 21, -0.9100648, 0.080972806, 0.008088337))
('After %s steps:global_step is %f, w is %f,learning rate is %f,loss is %f', (21, 22, -0.92462945, 0.08016308, 0.00568072))
('After %s steps:global_step is %f, w is %f,learning rate is %f,loss is %f', (22, 23, -0.93671334, 0.079361446, 0.0040052016))
('After %s steps:global_step is %f, w is %f,learning rate is %f,loss is %f', (23, 24, -0.9467584, 0.07856783, 0.002834669))
('After %s steps:global_step is %f, w is %f,learning rate is %f,loss is %f', (24, 25, -0.95512456, 0.077782154, 0.0020138053))
('After %s steps:global_step is %f, w is %f,learning rate is %f,loss is %f', (25, 26, -0.9621056, 0.077004336, 0.0014359877))
('After %s steps:global_step is %f, w is %f,learning rate is %f,loss is %f', (26, 27, -0.96794164, 0.076234296, 0.0010277383))
('After %s steps:global_step is %f, w is %f,learning rate is %f,loss is %f', (27, 28, -0.9728295, 0.07547195, 0.00073823496))
('After %s steps:global_step is %f, w is %f,learning rate is %f,loss is %f', (28, 29, -0.97693074, 0.07471723, 0.0005321909))
('After %s steps:global_step is %f, w is %f,learning rate is %f,loss is %f', (29, 30, -0.9803781, 0.073970065, 0.0003850193))
('After %s steps:global_step is %f, w is %f,learning rate is %f,loss is %f', (30, 31, -0.98328096, 0.073230356, 0.0002795264))
('After %s steps:global_step is %f, w is %f,learning rate is %f,loss is %f', (31, 32, -0.98572963, 0.07249805, 0.00020364333))
('After %s steps:global_step is %f, w is %f,learning rate is %f,loss is %f', (32, 33, -0.9877988, 0.071773075, 0.00014886903))
('After %s steps:global_step is %f, w is %f,learning rate is %f,loss is %f', (33, 34, -0.98955023, 0.071055345, 0.00010919763))
('After %s steps:global_step is %f, w is %f,learning rate is %f,loss is %f', (34, 35, -0.9910353, 0.0703448, 8.036616e-05))
('After %s steps:global_step is %f, w is %f,learning rate is %f,loss is %f', (35, 36, -0.9922965, 0.069641344, 5.9343653e-05))
('After %s steps:global_step is %f, w is %f,learning rate is %f,loss is %f', (36, 37, -0.99336946, 0.06894493, 4.396406e-05))
('After %s steps:global_step is %f, w is %f,learning rate is %f,loss is %f', (37, 38, -0.99428374, 0.068255484, 3.2675678e-05))
('After %s steps:global_step is %f, w is %f,learning rate is %f,loss is %f', (38, 39, -0.9950641, 0.06757293, 2.436331e-05))
('After %s steps:global_step is %f, w is %f,learning rate is %f,loss is %f', (39, 40, -0.9957312, 0.066897206, 1.8222867e-05))