通过指数衰减的方法设置[优化神经网络损失函数的]学习率

一,什么是学习率?
在神经网络中会有一个损失函数J(X)来衡量神经网络的输出结果和正确结果之间的误差,如下图J(X)为某个神经网络的损失函数,其实X假设为神经网络中的参数 。优化神经网络的目的就是使它的损失函数值达到最小 。
由下图所示,设置当前的参数和损失函数值对应下图小圆点的位置,那么梯度下降算法会将参数向X轴左侧移动,从而使得小圆点朝着箭头方向移动 。参数的梯度可以通过求偏导数的方式计算,对应参数X,其梯度为
。有了梯度,还需要学习率
来定义每次参数更新的幅度 。

通过指数衰减的方法设置[优化神经网络损失函数的]学习率

文章插图
通过参数的梯度和学习率,参数更新的公式为:
二,通过指数衰减的方法设置[优化神经网络损失函数的]学习率
学习率决定了参数(例如网络中的权值)每次更新的幅度,如果幅度过大,那么可能导致参数在极优值的两侧来回移动 。而如果学习率过小,虽然也能保证极优值的收敛,但是会大大降低优化速度 。
为了解决设定学习率的问题,提供了一种更加灵活的学习率设置方法-----指数衰减法( tf.train.函数 ) 。通过这个函数,可以先用一个较大的学习率来快速得到一个比较优的解,然后随着迭代的继续逐步减少学习率,使模型在训练后期更加稳定 。tf.train.函数会指数级地减少学习率,这个函数的实现代码如下:
其中,e ------当前衰减过后的学习率
----- 初始学习率
----- 衰减系数
/控制衰减的速度,其中 是一个人为指定的常量,而 则是从0开始慢慢加上去的一个变量,因此 /便会随着迭代慢慢变大,使得学习率衰减的程度变得越来越大 。
接下来是 tf.train.函数 的参数说明:
learning_rate = tf.train.exponential_decay(0.1,global_step,100,0.96,staircase=True)
上式的参数表示:
初始学习率为0.1 。
通过指数衰减的方法设置[优化神经网络损失函数的]学习率

文章插图
当 = True时,学习率在衰减过程中呈阶梯状,(/)会被转化为整数 。当 = False时,学习率在衰减过程中为连续的曲线 。(下面的图会清晰地展示) 。[默认为False] 。
上面的100和0.96分别是 和。若=True,则可以表示每经过100轮的迭代,学习率乘0.96 。
三,代码输出通过 tf.train.函数衰减的学习率
通过指数衰减的方法设置[优化神经网络损失函数的]学习率

文章插图
蓝色为 =True时的学习率(y轴)随 (x轴)分布 。
红色为 =False时的学习率(y轴)随 (x轴)分布 。
上图的代码:
import tensorflow as tfimport numpy as npimport matplotlib.pyplot as pltlearning_rate = 0.1 //初始学习率为0.1decay_rate = 0.96 //衰减系数为0.96global_steps = 1000 //表示函数x轴为0~1000decay_steps = 100 //衰减速度,它越小,学习率衰减越快global_ = tf.Variable(tf.constant(0))c = tf.train.exponential_decay(learning_rate,global_,decay_steps,decay_rate,staircase=True) //staircase=True的情况d = tf.train.exponential_decay(learning_rate,global_,decay_steps,decay_rate,staircase=False) //staircase=False的情况C_ = [] //保存staircase=True的情况的exponential_decay()输出值D_ = [] //保存staircase=False的情况的exponential_decay()输出值with tf.Session() assess:for i in range(global_steps):c_ = sess.run(c,feed_dict={global_:i})C_.append(c_)d_ = sess.run(d,feed_dict={global_:i})D_.append(d_)plt.figure(1)plt.plot(range(global_steps),D_,'r-') //显示x轴为0~global_steps,y轴为D_的函数图像 。并把图像画成红色plt.plot(range(global_steps),C_,'b-') //显示x轴为0~global_steps,y轴为C_的函数图像 。并把图像画成蓝色plt.show()