大家好,这里是小琳AI课堂。今天我们来深入探讨SARSA算法,这是一种强化学习中的策略优化方法。🚀
SARSA(State-Action-Reward-State-Action)算法是一种基于值函数的方法,与Q-learning类似,但它采用了一种不同的策略。SARSA使用当前策略来选择下一个动作,而不是选择具有最大Q值的动作。这意味着SARSA是一种“在线策略迭代”方法,它学习的是与数据生成策略相同的策略。
SARSA算法的基本步骤包括初始化Q表、选择初始动作、循环执行动作并根据当前策略选择下一个动作,最后更新Q表。
SARSA的更新规则是基于以下公式:
[ Q(s_t, a_t) \leftarrow Q(s_t, a_t) + \alpha \left[ R_{t+1} + \gamma Q(s_{t+1}, a_{t+1}) - Q(s_t, a_t) \right] ]
这个公式考虑了当前状态下的动作价值、执行该动作后的即时奖励、在新状态下根据当前策略选择的动作价值,以及折扣因子。
下面是一个使用OpenAI Gym的CartPole-v1
环境的SARSA算法的Python示例。这个环境的目标是通过平衡杆子来使小车保持在中心位置。
import gym
import numpy as np
# 创建环境
env = gym.make('CartPole-v1')
# 参数设置
num_episodes = 1000
max_steps_per_episode = 200
learning_rate = 0.1
discount_rate = 0.95
exploration_rate = 1.0
min_exploration_rate = 0.01
exploration_decay_rate = 0.001
# 初始化 Q 表
num_states = (env.observation_space.high - env.observation_space.low) * \
np.array([10, 100, 10, 50])
num_states = np.round(num_states, 0).astype(int) + 1
q_table = np.zeros(shape=(tuple(num_states), env.action_space.n))
# 训练智能体
for episode in range(num_episodes):
state = env.reset()[0] # 重置环境并获取初始状态
state = np.round(state, decimals=0).astype(int)
done = False
t = 0
# 选择初始动作
if np.random.uniform(0, 1) < exploration_rate:
action = env.action_space.sample() # 探索
else:
action = np.argmax(q_table[state]) # 利用
while not done and t < max_steps_per_episode:
# 执行动作并观察结果
next_state, reward, done, _, _ = env.step(action)
next_state = np.round(next_state, decimals=0).astype(int)
# 选择下一个动作
if np.random.uniform(0, 1) < exploration_rate:
next_action = env.action_space.sample() # 探索
else:
next_action = np.argmax(q_table[next_state]) # 利用
# 更新 Q 表
old_value = q_table[state][action]
next_max = q_table[next_state][next_action]
new_value = (1 - learning_rate) * old_value + learning_rate * (reward + discount_rate * next_max)
q_table[state][action] = new_value
# 设置新的状态和动作
state = next_state
action = next_action
# 衰减探索率
exploration_rate = min_exploration_rate + \
(max_exploration_rate - min_exploration_rate) * np.exp(-exploration_decay_rate * episode)
t += 1
# 测试智能体
test_episodes = 10
for episode in range(test_episodes):
state = env.reset()[0]
state = np.round(state, decimals=0).astype(int)
done = False
t = 0
while not done and t < max_steps_per_episode:
env.render() # 显示图形界面
action = np.argmax(q_table[state])
state, reward, done, _, _ = env.step(action)
state = np.round(state, decimals=0).astype(int)
t += 1
env.close()
这个示例展示了如何使用SARSA算法在一个简单的环境中训练智能体,并展示其学习效果。希望这个示例能帮助你更好地理解SARSA算法及其在强化学习中的应用。
如果你有任何问题或想法,欢迎在评论区留言分享!👇
本期的小琳AI课堂就到这里,希望你喜欢今天的内容!下期见!👋