0
点赞
收藏
分享

微信扫一扫

DRL | 07 Sarsa: On-policy 时序差分控制方法

导读



深度强化学习是近几年比较热门的技术,也是被很多大牛看做是实现真正的人工智能的最理想的工具。


今天这篇文章,我们一起来了解一下On-policy 时序差分方法 Sarsa,我们首先一起回顾TD算法,然后讲解Sarsa算法的思想,算法的细节,并通过一个具体的案例讲解Sarsa算法在特定场景下的实现



接下来让我们一起走进今天的文章吧!



1

回顾



1、TD思想


时序差分(temporal-difference (TD) learning)使用了一种自举(bootstrapping)的方法。这种方法可以简单理解为:利用已有信息自己估算未来的值,即自己推举一个值。


通过自举,TD算法无需获得完整的模型,即可以实现免模型预测与控制;同志TD算法无需等待回合结束就可以得到预测结果。也就是说,TD方法融合了蒙特卡罗方法和动态规划方法的优点,并避免了二者的缺点:一方面,TD方法可以在没有环境动态模型的情况下(即免模型)直接从原始经验中学习。另一方面,TD方法在一定程度上基于其他已知的估计来更新估计,而不需要等待最终结果(自举)。


2、TD(0)


TD方法无需等到回合结束,只需要等待到下一个时间步骤。如TD(0)算法,在在t+1时刻,使用观察到的奖励  估计的  进行更新,更新的目标为:  。更新的公式为:





完整算法如下:


DRL | 07 Sarsa: On-policy 时序差分控制方法_初始化


3、TD方法的优点


TD算法主要用来解决动态规划和蒙特卡罗所解决不了的问题,因此,他们的缺点,就是TD算法的优点。主要有:


与动态规划方法相比:TD方法无需环境、奖励和下一状态概率分布的模型;即在免模型情况下。


与蒙特卡罗方法相比:TD方法可以以在线的、完全增量的方式实现,只需要一个时间步,而无需像蒙特卡罗一样要等到回合结束。这可以极大的提升方案的性能。


2

Sarsa算法




接下来我们学习一个具体的比较经典的时序差分方法。


1、简述


前面我们讲到了TD方法做预测,预测的目的是为了更好地决策,因此,我们应该考虑如何使用TD方法来解决控制问题。


TD方法也分为同策略和异策略两种,同策略TD方法,最经典的就是Sarsa。


2、算法


1、算法思想


首先,对于一个回合,智能体获取当前的状态,并执行某个动作,环境根据智能体的动作给予对应的奖励并更新动作,因此,我们可以得到如下的序列:




我们从中取出一个五元组:  ,表示从当前状态更新到下一个状态的变化。即智能体得到当前的状态  ,并根据某个策略执行动作  ,环境根据智能体的动作,给智能体对应的奖励  ,并得到新的状态动作对 


为什么要这样设置五元组呢?


强化学习的目的,是为了得到最优的奖励,最简单的描述形式,就是在某状态下执行某个动作的奖励最大——即动作价值函数Q函数,Q函数与环境状态和智能体的动作有关,而智能体在执行动作的时候,是按照某种策略执行的,所以一般我们用 


我们使用Q函数带入到之前的TD(0)算法中:




就变成了:




只要  不是最终状态,就会一直更新,当 


这个表示从一个状态-动作对过渡到下一个状态-动作对的五元组,将其应用到TD算法中,我们称这个算法为Sarsa。这五个字母,分别表示:状态、动作、奖励、新状态、新动作。


2、算法


理解了上面之后,我们就可以来看Sarsa算法本身了。


DRL | 07 Sarsa: On-policy 时序差分控制方法_差分_02


Sarsa算法与前面的思想基本类似,包含三个部分:动作奖励更新


具体为:


1、初始化状态
2、在状态中,基于Q并使用贪婪策略选择动作
3、循环执行直到终止:
(1)更新状态
(2)在状态中,基于Q并使用贪婪策略更新动作
(3)使用公式更新Q
(4)将新状态与新动作存为当前状态与当前动作


3

示例:Windy Gridworld



1、Windy Gridworld介绍


我们以一个比较经典的案例来介绍Sarsa算法


Windy Gridworld是说,智能体在一个有风的环境中,想从某一个位置到达另外一个位置。


DRL | 07 Sarsa: On-policy 时序差分控制方法_差分_03


以上图为例,这个环境是一个7×10的矩形区域。智能体从位置S出发,想到达目的地G。如果不考虑其他因素,最快的路径是从S直接向右走,走到G。


有风的环境在这个案例中,是说有从下往上吹得风,最下面的一行数字表示的是风力大小,如果走到图中的某个位置,风力会让智能体原本到达的位置再向上移动风力大小个位置(智能体的移动不会超出边界)


DRL | 07 Sarsa: On-policy 时序差分控制方法_初始化_04


例如智能体从A走到B1,只需要向右走,但是由于A所在列的风力为1,所以智能体实际走到的位置是B2位置。


追求的目标,就是让智能体从S走到G,并且,奖励最大(走的路程最短)。为此,我们设置每走一步,奖励为-1。


2、分析与实现


1、环境建模


首先,我们要对环境进行建模,主要包括两方面:


环境的静态属性(大小,风力);
环境对智能体的动态反馈(包括影响及奖励)


静态属性代码如下:


# 环境
HEIGHT = 7 # 高
WIDTH = 10 # 宽
WIND = [0, 0, 0, 1, 1, 1, 2, 2, 1, 0] # 每一列风力大小


环境对智能体的反馈,我们在智能体相关的代码中实现即可。


一方面是环境会对智能体的动作造成影响,改变环境状态;另一方面,环境会根据智能体的动作给予对应的奖励。

2、智能体建模


对于智能体,我们首先要定义智能体的动作空间:


# 动作空间
ACTION_UP = 0 # 上
ACTION_DOWN = 1 # 下
ACTION_LEFT = 2 # 左
ACTION_RIGHT = 3 # 右
ACTIONS = [ACTION_UP, ACTION_DOWN, ACTION_LEFT, ACTION_RIGHT] # 动作空间


然后需要定义智能体的起点和重点,为了更好地描述,我们使用格式:


[行,列]


来表示位置,我们定义最左上角为[0,0]点,最右下角为[6,9]点,因此,初始位置S为[3,0],目标位置G为[3,7]:


START = [3, 0] #初始位置
GOAL = [3, 7] #目标位置


对于每个动作,都会更新智能体的位置,即更新环境。因为智能体有四个动作(上下左右),我们分别考虑他们的值。我们用i表示当前行,j表示当前列。


不管向哪里移动,都会受到风的影响,即:


i-WIND[i]


向上移动时,会修改行,使行数-1。最上面也就是索引为0的行。不能超出边界,因此向上移动一次的位置从[i,j]变为:


[max(i - 1 - WIND[j], 0), j]


向下移动时,会修改行,使行数+1。最上面也就是索引为0的行,最下面,也就是索引为HEIGHT - 1。不能超出边界,因此向上移动一次的位置从[i,j]变为:


[max(min(i + 1 - WIND[j], HEIGHT - 1), 0), j]


向左移动时,会修改列,使列数-1。最左面也就是索引为0的列。不能超出边界,因此向左移动一次的位置从[i,j]变为:


[max(i - WIND[j], 0), max(j - 1, 0)]


向右移动时,会修改列,使列数+1。最右面也就是索引为WIDTH- 1的列。不能超出边界,因此向上移动一次的位置从[i,j]变为:


[max(i - WIND[j], 0), min(j + 1, WIDTH - 1)]


全部代码如下:


# 移动
def move(state, action):
i, j = state
if action == ACTION_UP:
return [max(i - 1 - WIND[j], 0), j] #纵向上:向上移动一格,也要受到风的影响,最小是最上面,即0。横向上:不变
elif action == ACTION_DOWN:
return [max(min(i + 1 - WIND[j], HEIGHT - 1), 0), j] #纵向上:向下移动一格,也要受到风的影响,最小是最上面,即0;最大是最下面,即高度-1。横向上:不变
elif action == ACTION_LEFT:
return [max(i - WIND[j], 0), max(j - 1, 0)] #纵向上:受到风的影响,最小是最上面,即0。横向上:向左运动,最小是最左面,即0
else:
return [max(i - WIND[j], 0), min(j + 1, WIDTH - 1)] #纵向上:受到风的影响,最小是最上面,即0。横向上:向右运动,最大是最右面,即宽度-1


3、Sarsa算法


然后我们考虑执行程序,实现Sarsa算法。


在该算法中,我们根据前面自己写的流程描述完成代码:


1、初始化状态
2、在状态中,基于Q并使用贪婪策略选择动作
3、循环执行直到终止:
(1)更新状态
(2)在状态中,基于Q并使用贪婪策略更新动作
(3)使用公式更新Q
(4)将新状态与新动作存为当前状态与当前动作


从上面的流程中,我们需要做如下预备工作:


1、实现贪婪策略算法并设置初始ε
2、初始化Q计算中的参数α与γ


对于贪婪策略算法,我们设  ,  -贪婪算法一般采用二项分布来实现,我们可以使用numpy中的函数:


np.random.binomial(n, p, size=None)


其中参数和函数返回值的含义如下:


n:表示进行了n次实验
p:表示试验中事件成功的概率
size:表示试验做了多少组,如果默认不写,则为None,进行了一组试验
返回值:表示事件A发生的次数。


对于我们来说,我们只用:


np.random.binomial(1, eps)


也就是说,进行一次试验,事件(随机选择动作,而非贪婪选择最大值的动作)的概率为eps。


返回值只可能是0或者1,如果返回1,说明这一次试验的结果就是随机选择动作,我们从所有动作中随机选择即可;如果返回0,说明这次试验的结果就是贪婪选择最大值的动作。


在我们的案例中,我们的函数如下:


import numpy as np

EPSILON = 0.1

def epsilon_greedy(eps, q_value, state):
if np.random.binomial(1, eps) == 1:
action = np.random.choice(ACTIONS)
else:
values_ = q_value[state[0], state[1], :]
action = np.random.choice([action_ for action_, value_ in enumerate(values_) if value_ == np.max(values_)])
return action


然后我们初始化α和γ,由于我们这个案例中,每操作一次,奖励都是-1,走哪个位置,没有先后的必然意义,也不存在比较泛化的强化学习场景中的奖励折扣问题,因此我们初始化为:


ALPHA = 0.5
GAMMA = 1


γ就可以直接省略不写了。


然后我们就可以根据前面的流程写Sarsa算法了:


# Sarsa算法
def Sarsa(q_value):
state = START # 初始化状态
action = epsilon_greedy(EPSILON, q_value, state) # 在状态中,基于Q并使用贪婪策略选择动作

while state != GOAL: #循环执行直到终止
next_state = move(state, action) # 更新状态
next_action = epsilon_greedy(EPSILON, q_value, state) # 在状态中,基于Q并使用贪婪策略更新动作

# 使用公式更新Q
q_value[state[0], state[1], action] += \
ALPHA * (REWARD + q_value[next_state[0], next_state[1], next_action] -
q_value[state[0], state[1], action])
state = next_state #将新状态与新动作存为当前状态与当前动作
action = next_action


4、策略优化


然后我们就可以使用Sarsa优化策略,以做最优控制:


首先我们要初始化Q表格的值为0,与行、列、动作三个参数有关,然后我们需要设置回合限制,防止陷入死循环(设置太小可能会找不到最优解,甚至找不到正确解):


q_value = np.zeros((HEIGHT, WIDTH, 4))
episode_limit = 5000
ep = 0


然后在回合次数中,不断执行Sarsa算法优化策略:


while ep < episode_limit:
Sarsa(q_value)
ep += 1


然后我们就可以将优化的策略保存下来:


optimal_policy = []
for i in range(0, HEIGHT):
optimal_policy.append([])
for j in range(0, WIDTH):
if [i, j] == GOAL:
optimal_policy[-1].append('G')
continue
bestAction = np.argmax(q_value[i, j, :])
if bestAction == ACTION_UP:
optimal_policy[-1].append('U')
elif bestAction == ACTION_DOWN:
optimal_policy[-1].append('D')
elif bestAction == ACTION_LEFT:
optimal_policy[-1].append('L')
elif bestAction == ACTION_RIGHT:
optimal_policy[-1].append('R')
print('Optimal policy is:')
for row in optimal_policy:
print(row)


输出的结果为:


['D', 'R', 'U', 'R', 'R', 'R', 'R', 'R', 'R', 'D']
['R', 'R', 'R', 'R', 'U', 'R', 'R', 'D', 'R', 'D']
['R', 'D', 'U', 'R', 'R', 'R', 'R', 'U', 'R', 'D']
['R', 'R', 'R', 'R', 'R', 'R', 'R', 'G', 'R', 'D']
['D', 'D', 'R', 'U', 'R', 'R', 'U', 'D', 'L', 'L']
['R', 'R', 'R', 'R', 'R', 'U', 'U', 'D', 'D', 'D']
['R', 'R', 'R', 'R', 'U', 'U', 'U', 'U', 'U', 'L']


从起点开始到达目标位置G的最优路径在下图中:


DRL | 07 Sarsa: On-policy 时序差分控制方法_强化学习_05


在本公众号中回复“Sarsa”,可以获取完整代码。


4

说在后面的话



讲完了这些内容,我们能够对Sarsa有比较清晰的认识了,也能通过一个小案例,加深对Sarsa的理解。下一篇文章,我们会讲解另外一个时序差分控制方法,即Off-policy 时序差分控制Q-learning算法。


让我们继续学习后面的内容吧!


DRL | 07 Sarsa: On-policy 时序差分控制方法_初始化_06


DRL | 07 Sarsa: On-policy 时序差分控制方法_强化学习_07

AI与区块链技术

生活不只有诗和远方,还有人工智能和区块链。本公众号与大家分享以数学为基础的相关前沿科学,包括人工智能与区块链及相关应用技术(如机器学习,深度学习,视觉,密码学,区块链等)分享数学学习知识,信息学及编程技术等,欢迎有梦想的你与我一同交流。

233篇原创内容


公众号

DRL | 07 Sarsa: On-policy 时序差分控制方法_强化学习_08

举报

相关推荐

0 条评论