import gym
import torch
import torch.nn as nn
import torch.optim as optim
import pygame
import sys
from collections import deque
import numpy as np
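
# This script trains a REINFORCE (vanilla policy-gradient) agent on
# CartPole-v1 and renders episodes with pygame: main() rolls out one
# episode at a time, computes discounted per-step returns, subtracts a
# moving-average baseline, and takes one gradient step per episode;
# play() reloads the saved model and runs a single rendered episode.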

# Policy network definition
class PolicyNetwork(nn.Module):
    def __init__(self):
        super(PolicyNetwork, self).__init__()
        self.fc = nn.Sequential(
            nn.Linear(4, 10),    # 4 state inputs, 10 hidden units
            nn.Tanh(),
            nn.Linear(10, 2),    # probabilities for the 2 actions
            nn.Softmax(dim=-1)
        )

    def forward(self, x):
        # State layout: cart position, cart velocity, pole angle, pole angular velocity.
        # All four state dimensions are used (the explicit indexing is effectively a no-op).
        selected_values = x[:, [0, 1, 2, 3]]
        return self.fc(selected_values)

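# REINFORCE loss used in train(): for each step t of an episode,
#     loss += -log pi(a_t | s_t) * (G_t - b)
# where G_t is the discounted return from step t and b is a baseline
# (here, the mean return of the last 10 episodes). Subtracting a baseline
# keeps the gradient estimate unbiased while reducing its variance.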
# Training function
def train(policy_net, optimizer, trajectories):
    optimizer.zero_grad()
    loss = 0
    for trajectory in trajectories:
        # Baseline-subtracted return (an advantage estimate) reduces gradient variance.
        advantage = trajectory["returns"] - trajectory["step_mean_reward"]
        loss += -trajectory["log_prob"] * advantage  # REINFORCE loss term
    loss.backward()
    optimizer.step()
    return loss.item()
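
# A common variant (not used above) standardizes the returns within an
# episode instead of subtracting a moving-average baseline; a minimal
# sketch, assuming each entry of `trajectories` carries a "returns" value
# as in main():
def standardized_returns(trajectories, eps=1e-8):
    g = torch.tensor([t["returns"] for t in trajectories]).float()
    return (g - g.mean()) / (g.std() + eps)  # zero mean, unit variance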

# Main function: training loop
def main():
    env = gym.make('CartPole-v1')
    policy_net = PolicyNetwork()
    optimizer = optim.Adam(policy_net.parameters(), lr=0.01)

    print(env.action_space)
    print(env.observation_space)
    pygame.init()
    screen = pygame.display.set_mode((600, 400))
    clock = pygame.time.Clock()

    rewards_one_episode = []  # discounted return of each finished episode
    for episode in range(10000):
        state, _ = env.reset()  # reset() returns (observation, info)
        done = False
        trajectories = []
        step = 0
        torch.save(policy_net, 'policy_net_full.pth')  # checkpoint the full model each episode
        while not done:
            state_tensor = torch.tensor(state).float().unsqueeze(0)
            probs = policy_net(state_tensor)
            # Sample an action from the categorical policy and keep its
            # log-probability for the gradient update.
            action = torch.distributions.Categorical(probs).sample().item()
            log_prob = torch.log(probs.squeeze(0)[action])
            next_state, reward, terminated, truncated, _ = env.step(action)
            done = terminated or truncated  # stop on failure or on the time limit

            trajectories.append({"state": state, "action": action, "reward": reward, "log_prob": log_prob})
            state = next_state

            for event in pygame.event.get():
                if event.type == pygame.QUIT:
                    pygame.quit()
                    sys.exit()
            step += 1

            # Draw the environment, but only once the previous episode's
            # return exceeded 99, so early (bad) episodes run unrendered.
            if rewards_one_episode and rewards_one_episode[-1] > 99:
                screen.fill((255, 255, 255))
                # Map the cart position to screen x; the pole endpoint is
                # offset by sin/cos of the pole angle.
                cart_x = int(state[0] * 100 + 300)
                pygame.draw.rect(screen, (0, 0, 255), (cart_x, 300, 50, 30))
                pole_tip = (cart_x + 25 - int(50 * np.sin(state[2])),
                            300 - int(50 * np.cos(state[2])))
                pygame.draw.line(screen, (255, 0, 0), (cart_x + 25, 300), pole_tip, 2)
                pygame.display.flip()
                clock.tick(200)  # cap rendering at 200 FPS during training

        print(f"Episode {episode}: terminated after {step} steps")
        # Compute the discounted return of every step by walking the
        # trajectory backwards: G_t = r_t + 0.99 * G_{t+1}.
        returns = 0
        baseline = np.mean(rewards_one_episode[-10:]) if rewards_one_episode else 0
        for traj in reversed(trajectories):
            returns = traj["reward"] + 0.99 * returns
            traj["returns"] = returns
            traj["step_mean_reward"] = baseline  # moving-average baseline for train()
        rewards_one_episode.append(returns)  # returns now holds G_0, the episode return
        train(policy_net, optimizer, trajectories)
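
# Worked example of the return recursion in main(): with gamma = 0.99 and
# an episode of three reward-1 steps, the per-step returns are
#     G_2 = 1
#     G_1 = 1 + 0.99 * G_2 = 1.99
#     G_0 = 1 + 0.99 * G_1 = 2.9701
# so earlier steps are credited with the discounted future reward they led to.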

# Inference: load the saved model and run one rendered episode
def play():
    env = gym.make('CartPole-v1')
    pygame.init()
    screen = pygame.display.set_mode((600, 400))
    clock = pygame.time.Clock()

    state, _ = env.reset()
    done = False
    step = 0
    # Load the full-model checkpoint written by main(). On PyTorch >= 2.6,
    # loading a pickled module requires weights_only=False.
    policy_net = torch.load('policy_net_full.pth', weights_only=False)
    while not done:
        state_tensor = torch.tensor(state).float().unsqueeze(0)
        probs = policy_net(state_tensor)
        action = torch.distributions.Categorical(probs).sample().item()
        next_state, reward, terminated, truncated, _ = env.step(action)
        done = terminated or truncated
        state = next_state

        for event in pygame.event.get():
            if event.type == pygame.QUIT:
                pygame.quit()
                sys.exit()

        # Draw the environment state
        screen.fill((255, 255, 255))
        cart_x = int(state[0] * 100 + 300)
        pygame.draw.rect(screen, (0, 0, 255), (cart_x, 300, 50, 30))
        pole_tip = (cart_x + 25 - int(50 * np.sin(state[2])),
                    300 - int(50 * np.cos(state[2])))
        pygame.draw.line(screen, (255, 0, 0), (cart_x + 25, 300), pole_tip, 2)
        pygame.display.flip()
        clock.tick(60)  # render at 60 FPS
        step += 1

    print(f"Episode ended after {step} steps")


if __name__ == '__main__':
    main()    # training
    # play()  # inference