微信扫码
与创始人交个朋友
我要投稿
在强化学习中,策略梯度方法通过直接优化策略来最大化累积奖励。传统的策略梯度方法,如REINFORCE,存在高方差和收敛速度慢的问题。为了解决这些问题,Schulman等人提出了近端策略优化算法(Proximal Policy Optimization,PPO),它在更新策略时引入了信赖域约束,既保证了策略的更新幅度不过大,又简化了计算过程,被广泛应用于各种强化学习任务中。
PPO算法的核心思想是通过限制新旧策略之间的变化,防止策略更新过度。具体来说,PPO通过以下目标函数来更新策略:
其中:
优势函数 可以通过广义优势估计(Generalized Advantage Estimation,GAE)来计算:
其中,TD残差 定义为:
是折扣因子, 是用于平衡偏差和方差的超参数。
PPO的策略更新通过最大化 来实现。由于引入了 操作,损失函数对 的变化在 范围之外不再敏感,从而限制了每次更新的步幅。
除了策略网络,PPO还使用价值网络来估计状态值函数 ,其损失函数为:
其中, 是对真实价值的估计,例如使用TD目标:
综合考虑策略损失和价值函数损失,以及可能的熵正则项,PPO的总损失函数为:
其中:
为了更好地理解PPO算法,我们在经典的CartPole-v1环境上进行了实验。该环境的目标是控制小车移动,以保持竖立的杆子不倒下。
以下是PPO算法在CartPole-v1环境上的部分实现代码:
class PPO:
'''PPO算法'''
def __init__(self, state_dim, hidden_dim, action_dim, actor_lr, critic_lr, gamma,
lmbda, epsilon, epochs, device):
self.action_dim = action_dim
self.actor_critic = ActorCritic(state_dim, hidden_dim, action_dim).to(device)
self.actor_optimizer = optim.Adam(self.actor_critic.actor_parameters(), lr=actor_lr)
self.critic_optimizer = optim.Adam(self.actor_critic.critic_parameters(), lr=critic_lr)
self.gamma = gamma # 折扣因子
self.lmbda = lmbda # GAE参数
self.epsilon = epsilon # PPO截断范围
self.epochs = epochs # PPO的更新次数
self.device = device
def take_action(self, state):
'''根据策略网络选择动作'''
state = torch.tensor([state], dtype=torch.float).to(self.device)
with torch.no_grad():
action_probs, _ = self.actor_critic(state)
dist = torch.distributions.Categorical(action_probs)
action = dist.sample()
return action.item()
def update(self, transition_dict):
'''更新策略网络和价值网络'''
states = torch.tensor(transition_dict['states'], dtype=torch.float).to(self.device)
actions = torch.tensor(transition_dict['actions']).view(-1).to(self.device)
rewards = torch.tensor(transition_dict['rewards'], dtype=torch.float).view(-1, 1).to(self.device)
next_states = torch.tensor(transition_dict['next_states'], dtype=torch.float).to(self.device)
dones = torch.tensor(transition_dict['dones'], dtype=torch.float).view(-1, 1).to(self.device)
# 计算TD误差和优势函数
_, state_values = self.actor_critic(states)
_, next_state_values = self.actor_critic(next_states)
td_target = rewards + self.gamma * next_state_values * (1 - dones)
delta = td_target - state_values
delta = delta.detach().cpu().numpy()
# Generalized Advantage Estimation (GAE)
advantage_list = []
advantage = 0.0
for delta_t in delta[::-1]:
advantage = self.gamma * self.lmbda * advantage + delta_t[0]
advantage_list.append([advantage])
advantage_list.reverse()
advantages = torch.tensor(advantage_list, dtype=torch.float).to(self.device)
# 计算旧策略的log概率
with torch.no_grad():
action_probs_old, _ = self.actor_critic(states)
dist_old = torch.distributions.Categorical(action_probs_old)
log_probs_old = dist_old.log_prob(actions)
# 更新策略网络和价值网络
for _ in range(self.epochs):
action_probs, state_values = self.actor_critic(states)
dist = torch.distributions.Categorical(action_probs)
log_probs = dist.log_prob(actions)
ratio = torch.exp(log_probs - log_probs_old)
surr1 = ratio * advantages.squeeze()
surr2 = torch.clamp(ratio, 1 - self.epsilon, 1 + self.epsilon) * advantages.squeeze()
actor_loss = -torch.mean(torch.min(surr1, surr2))
critic_loss = F.mse_loss(state_values, td_target.detach())
# 更新策略网络
self.actor_optimizer.zero_grad()
actor_loss.backward()
self.actor_optimizer.step()
# 更新价值网络
self.critic_optimizer.zero_grad()
critic_loss.backward()
self.critic_optimizer.step()
Iteration 1: 100%|██████████| 30/30 [00:00<00:00, 66.19it/s, Episode=30/300, Average Return=10.00]
Iteration 2: 100%|██████████| 30/30 [00:00<00:00, 36.67it/s, Episode=60/300, Average Return=162.90]
Iteration 3: 100%|██████████| 30/30 [00:01<00:00, 24.94it/s, Episode=90/300, Average Return=278.70]
Iteration 4: 100%|██████████| 30/30 [00:01<00:00, 19.59it/s, Episode=120/300, Average Return=287.80]
Iteration 5: 100%|██████████| 30/30 [00:01<00:00, 17.57it/s, Episode=150/300, Average Return=240.70]
Iteration 6: 100%|██████████| 30/30 [00:01<00:00, 21.10it/s, Episode=180/300, Average Return=354.60]
Iteration 7: 100%|██████████| 30/30 [00:02<00:00, 12.90it/s, Episode=210/300, Average Return=450.50]
Iteration 8: 100%|██████████| 30/30 [00:02<00:00, 11.59it/s, Episode=240/300, Average Return=500.00]
Iteration 9: 100%|██████████| 30/30 [00:02<00:00, 11.52it/s, Episode=270/300, Average Return=475.50]
Iteration 10: 100%|██████████| 30/30 [00:02<00:00, 11.31it/s, Episode=300/300, Average Return=500.00]
运行上述代码,可以观察到在训练过程中,智能体的平均回报逐渐提高,最终稳定在较高水平。这表明PPO算法有效地学习到了保持杆子平衡的策略。
从学习曲线可以看出,经过大约200个回合的训练,智能体的表现达到了环境的最高分。这验证了PPO算法在处理连续动作空间和策略优化问题上的有效性。
注:由于完整代码过长,请关注公众号回复“交流”领取。
PPO算法通过引入概率比率的截断和优势函数的估计,实现了高效稳定的策略更新。在CartPole-v1环境上的实验表明,PPO能够快速收敛到最优策略,具有较好的性能和稳定性。由于其简单高效的特点,PPO在强化学习领域得到了广泛的应用和认可。
53AI,企业落地应用大模型首选服务商
产品:大模型应用平台+智能体定制开发+落地咨询服务
承诺:先做场景POC验证,看到效果再签署服务协议。零风险落地应用大模型,已交付160+中大型企业
2025-01-09
解码通用 AI Agent:七步构建你的智能系统
2025-01-08
dify案例分享-基于文本模型实现Fine-tune 语料构造工作流
2025-01-08
架构师必备LLM推理优化全解析:Nvidia分享的实用技巧,简单易懂!
2025-01-06
模型Prompt调优的实用技巧与经验分享
2025-01-06
大模型推理框架:Ollama和vLLM到底应该选哪个?
2025-01-06
大模型高效训练一体框架 LLaMA Factory
2025-01-06
增强大模型的推理能力:从思维链到连续思维链(上)
2025-01-06
LLM之模型评估:情感评估/EQ评估/幻觉评估等
2024-09-18
2024-07-11
2024-07-11
2024-07-26
2024-07-09
2024-06-11
2024-10-20
2024-07-20
2024-07-23
2024-07-12