0
点赞
收藏
分享

微信扫一扫

小琳AI课堂 强化学习策略优化:深入解析SARSA算法及其Python实现


大家好,这里是小琳AI课堂。今天我们来深入探讨SARSA算法,这是一种强化学习中的策略优化方法。🚀
SARSA(State-Action-Reward-State-Action)算法是一种基于值函数的方法,与Q-learning类似,但它采用了一种不同的策略。SARSA使用当前策略来选择下一个动作,而不是选择具有最大Q值的动作。这意味着SARSA是一种“在线策略迭代”方法,它学习的是与数据生成策略相同的策略。
SARSA算法的基本步骤包括初始化Q表、选择初始动作、循环执行动作并根据当前策略选择下一个动作,最后更新Q表。
SARSA的更新规则是基于以下公式:
[ Q(s_t, a_t) \leftarrow Q(s_t, a_t) + \alpha \left[ R_{t+1} + \gamma Q(s_{t+1}, a_{t+1}) - Q(s_t, a_t) \right] ]
这个公式考虑了当前状态下的动作价值、执行该动作后的即时奖励、在新状态下根据当前策略选择的动作价值,以及折扣因子。
下面是一个使用OpenAI Gym的CartPole-v1环境的SARSA算法的Python示例。这个环境的目标是通过平衡杆子来使小车保持在中心位置。

import gym
import numpy as np

# 创建环境
env = gym.make('CartPole-v1')

# 参数设置
num_episodes = 1000
max_steps_per_episode = 200
learning_rate = 0.1
discount_rate = 0.95
exploration_rate = 1.0
min_exploration_rate = 0.01
exploration_decay_rate = 0.001

# 初始化 Q 表
num_states = (env.observation_space.high - env.observation_space.low) * \
             np.array([10, 100, 10, 50])
num_states = np.round(num_states, 0).astype(int) + 1
q_table = np.zeros(shape=(tuple(num_states), env.action_space.n))

# 训练智能体
for episode in range(num_episodes):
    state = env.reset()[0]  # 重置环境并获取初始状态
    state = np.round(state, decimals=0).astype(int)
    
    done = False
    t = 0
    
    # 选择初始动作
    if np.random.uniform(0, 1) < exploration_rate:
        action = env.action_space.sample()  # 探索
    else:
        action = np.argmax(q_table[state])  # 利用
    
    while not done and t < max_steps_per_episode:
        # 执行动作并观察结果
        next_state, reward, done, _, _ = env.step(action)
        next_state = np.round(next_state, decimals=0).astype(int)
        
        # 选择下一个动作
        if np.random.uniform(0, 1) < exploration_rate:
            next_action = env.action_space.sample()  # 探索
        else:
            next_action = np.argmax(q_table[next_state])  # 利用
        
        # 更新 Q 表
        old_value = q_table[state][action]
        next_max = q_table[next_state][next_action]
        new_value = (1 - learning_rate) * old_value + learning_rate * (reward + discount_rate * next_max)
        q_table[state][action] = new_value
        
        # 设置新的状态和动作
        state = next_state
        action = next_action
        
        # 衰减探索率
        exploration_rate = min_exploration_rate + \
                           (max_exploration_rate - min_exploration_rate) * np.exp(-exploration_decay_rate * episode)
        
        t += 1

# 测试智能体
test_episodes = 10
for episode in range(test_episodes):
    state = env.reset()[0]
    state = np.round(state, decimals=0).astype(int)
    done = False
    t = 0
    
    while not done and t < max_steps_per_episode:
        env.render()  # 显示图形界面
        action = np.argmax(q_table[state])
        state, reward, done, _, _ = env.step(action)
        state = np.round(state, decimals=0).astype(int)
        t += 1

env.close()

这个示例展示了如何使用SARSA算法在一个简单的环境中训练智能体,并展示其学习效果。希望这个示例能帮助你更好地理解SARSA算法及其在强化学习中的应用。
如果你有任何问题或想法,欢迎在评论区留言分享!👇
本期的小琳AI课堂就到这里,希望你喜欢今天的内容!下期见!👋


举报

相关推荐

0 条评论