Playing the CartPole game with policy gradients: reinforcement learning instead of a PID controller for balancing the pole
Source: cnblogs  Author: 高颜值的殺生丸  Date: 2024/5/13

 

In the CartPole game, a freely swinging pole is mounted on a cart, and the goal is to keep the pole balanced: whenever the pole starts to fall to one side, the cart moves so that the pole stays dynamically upright. The policy function here is a simple two-layer neural network. Its input is the 4-dimensional state (cart position, cart velocity, pole angle, pole angular velocity), and its output is the action: move left or move right. In my experiments the network needs at least 3 of the state inputs to stay balanced for a while; with only 2 it never learns anything useful, and with all 4 it learns a very stable policy.
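For comparison with the PID approach mentioned in the title, the same balancing task can also be handled by a classical controller driven only by the pole-angle error. The following is a minimal, illustrative PID loop for CartPole-v1 using the same gym API as the code further down; the gains Kp, Ki, Kd and the angle-only error signal are my own assumptions, not values from the original post.

import gym

# Minimal PID baseline on the pole angle (illustrative gains, not tuned values from the post)
Kp, Ki, Kd = 10.0, 0.1, 2.0
env = gym.make('CartPole-v1')
state, _ = env.reset()
integral, prev_error = 0.0, 0.0
done, steps = False, 0
while not done:
    error = state[2]                      # pole angle; the target is 0 (upright)
    integral += error
    derivative = error - prev_error
    u = Kp * error + Ki * integral + Kd * derivative
    action = 1 if u > 0 else 0            # push right when the pole leans right
    state, reward, terminated, truncated, _ = env.step(action)
    done = terminated or truncated
    prev_error = error
    steps += 1
print(f"PID baseline survived {steps} steps")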

 

 

Policy-gradient implementation, using torch to build the simple neural network:

import gym
import torch
import torch.nn as nn
import torch.optim as optim
import pygame
import sys
from collections import deque
import numpy as np

# Policy network definition
class PolicyNetwork(nn.Module):
    def __init__(self):
        super(PolicyNetwork, self).__init__()
        self.fc = nn.Sequential(
            nn.Linear(4, 10),   # 4 state inputs, 10 hidden units
            nn.Tanh(),
            nn.Linear(10, 2),   # probabilities for the 2 actions
            nn.Softmax(dim=-1)
        )

    def forward(self, x):
        # state layout: cart position, cart velocity, pole angle, pole angular velocity
        selected_values = x[:, [0, 1, 2, 3]]  # use all four state components
        return self.fc(selected_values)

# Training function
def train(policy_net, optimizer, trajectories):
    policy_net.zero_grad()
    loss = 0
    print(trajectories[0])
    for trajectory in trajectories:
        # if trajectory["returns"] > 90:
        #     returns = torch.tensor(trajectory["returns"]).float()
        # else:
        # subtract the moving-average return of recent episodes as a baseline
        returns = torch.tensor(trajectory["returns"]).float() - torch.tensor(trajectory["step_mean_reward"]).float()
        # print(f"return {returns}")
        log_probs = trajectory["log_prob"]
        loss += -(log_probs * returns).sum()  # policy-gradient (REINFORCE) loss
    loss.backward()
    optimizer.step()
    return loss.item()

# Main function
def main():
    env = gym.make('CartPole-v1')
    policy_net = PolicyNetwork()
    optimizer = optim.Adam(policy_net.parameters(), lr=0.01)

    print(env.action_space)
    print(env.observation_space)
    pygame.init()
    screen = pygame.display.set_mode((600, 400))
    clock = pygame.time.Clock()

    rewards_one_episode = []
    for episode in range(10000):
        state = env.reset()
        done = False
        trajectories = []
        state = state[0]
        step = 0
        torch.save(policy_net, 'policy_net_full.pth')
        while not done:
            state_tensor = torch.tensor(state).float().unsqueeze(0)
            probs = policy_net(state_tensor)
            action = torch.distributions.Categorical(probs).sample().item()
            log_prob = torch.log(probs.squeeze(0)[action])
            next_state, reward, done, _, _ = env.step(action)

            # print(episode)
            trajectories.append({"state": state, "action": action, "reward": reward, "log_prob": log_prob})
            state = next_state

            for event in pygame.event.get():
                if event.type == pygame.QUIT:
                    pygame.quit()
                    sys.exit()
            step += 1
            # draw the environment state once the policy is doing reasonably well
            if rewards_one_episode and rewards_one_episode[-1] > 99:
                screen.fill((255, 255, 255))
                cart_x = int(state[0] * 100 + 300)
                pygame.draw.rect(screen, (0, 0, 255), (cart_x, 300, 50, 30))
                # print(state)
                pygame.draw.line(screen, (255, 0, 0), (cart_x + 25, 300),
                                 (cart_x + 25 - int(50 * torch.sin(torch.tensor(state[2]))),
                                  300 - int(50 * torch.cos(torch.tensor(state[2])))), 2)
                pygame.display.flip()
                clock.tick(200)

        print(f"Episode {episode}: failed after {step} steps")
        # accumulate discounted returns for the policy gradient
        returns = 0
        for traj in reversed(trajectories):
            returns = traj["reward"] + 0.99 * returns
            traj["returns"] = returns
            if rewards_one_episode:
                # print(rewards_one_episode[:10])
                traj["step_mean_reward"] = np.mean(rewards_one_episode[-10:])
            else:
                traj["step_mean_reward"] = 0
        rewards_one_episode.append(returns)
        # print(rewards_one_episode[:10])
        train(policy_net, optimizer, trajectories)

def play():
    env = gym.make('CartPole-v1')
    policy_net = PolicyNetwork()
    pygame.init()
    screen = pygame.display.set_mode((600, 400))
    clock = pygame.time.Clock()

    state = env.reset()
    done = False
    trajectories = deque()
    state = state[0]
    step = 0
    policy_net = torch.load('policy_net_full.pth')
    while not done:
        state_tensor = torch.tensor(state).float().unsqueeze(0)
        probs = policy_net(state_tensor)
        action = torch.distributions.Categorical(probs).sample().item()
        log_prob = torch.log(probs.squeeze(0)[action])
        next_state, reward, done, _, _ = env.step(action)

        trajectories.append({"state": state, "action": action, "reward": reward, "log_prob": log_prob})
        state = next_state

        for event in pygame.event.get():
            if event.type == pygame.QUIT:
                pygame.quit()
                sys.exit()

        # draw the environment state
        screen.fill((255, 255, 255))
        cart_x = int(state[0] * 100 + 300)
        pygame.draw.rect(screen, (0, 0, 255), (cart_x, 300, 50, 30))
        # print(state)
        pygame.draw.line(screen, (255, 0, 0), (cart_x + 25, 300),
                         (cart_x + 25 - int(50 * torch.sin(torch.tensor(state[2]))),
                          300 - int(50 * torch.cos(torch.tensor(state[2])))), 2)
        pygame.display.flip()
        clock.tick(60)
        step += 1

    print(f"Failed after {step} steps")


if __name__ == '__main__':
    main()    # training
    # play()  # inference
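To see what the backward recursion at the end of main() produces, here is a tiny standalone check with a made-up three-step reward sequence (CartPole pays +1 per surviving step):

# Toy check of the discounted-return recursion used in main()
rewards = [1.0, 1.0, 1.0]
gamma = 0.99
G = 0.0
returns = []
for r in reversed(rewards):
    G = r + gamma * G
    returns.append(G)
returns.reverse()
print(returns)   # [2.9701, 1.99, 1.0]

Each step's return is its own reward plus the discounted return of everything after it, which is exactly the value stored in traj["returns"].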

Results: training is not very stable. Sometimes the policy fails to learn even after many episodes; other times it converges in just a few dozen episodes.
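One common way to make REINFORCE training less erratic is to standardize the returns within each episode instead of (or on top of) subtracting the moving-average baseline. The sketch below is not the original author's code, just an assumed variant of train(); trajectories is assumed to have the same layout as above.

import torch

# Sketch: per-episode return standardization (assumed variant of train(), not the original)
def train_normalized(policy_net, optimizer, trajectories):
    optimizer.zero_grad()
    returns = torch.tensor([t["returns"] for t in trajectories]).float()
    returns = (returns - returns.mean()) / (returns.std() + 1e-8)  # zero mean, unit variance
    log_probs = torch.stack([t["log_prob"] for t in trajectories])
    loss = -(log_probs * returns).sum()
    loss.backward()
    optimizer.step()
    return loss.item()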

 

Original post: https://www.cnblogs.com/LiuXinyu12378/p/18187947
