# A3C_CartPole.py
# %%
import numpy as np
import gym
import torch
import torch.nn as nn
import torch.optim as optim
# torch.multiprocessing is a drop-in replacement for the standard
# multiprocessing module with support for sharing tensors across processes.
import torch.multiprocessing as mp
# %%
# Network definition: a two-headed actor-critic model.
class ActorCritic(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(ActorCritic, self).__init__()
        # Policy head: maps a state to a probability distribution over actions.
        self.actor = nn.Sequential(
            nn.Linear(input_dim, 128),
            nn.ReLU(),
            nn.Linear(128, output_dim),
            nn.Softmax(dim=-1)
        )
        # Value head: maps a state to a scalar state-value estimate V(s).
        self.critic = nn.Sequential(
            nn.Linear(input_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 1)
        )

    def forward(self, state):
        probs = self.actor(state)
        value = self.critic(state)
        return probs, value
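# %%
# A minimal sanity check (illustrative, not part of the original script):
# CartPole-v1 has a 4-dimensional observation and 2 discrete actions, so the
# policy head should emit a length-2 distribution and the value head a scalar.
if __name__ == "__main__":
    _model = ActorCritic(4, 2)
    _probs, _value = _model(torch.zeros(4))
    print(_probs.shape, _value.shape)  # torch.Size([2]) torch.Size([1])
    print(float(_probs.sum()))         # ~1.0, since Softmax normalizes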
# %%
# A3C update: one actor-critic gradient step from a single transition.
def train(global_model, optimizer, state, action, reward, next_state, done, gamma=0.99):
    state = torch.FloatTensor(state)
    next_state = torch.FloatTensor(next_state)
    reward = torch.FloatTensor([reward])
    action = torch.LongTensor([action])
    probs, value = global_model(state)
    _, next_value = global_model(next_state)
    # One-step TD target; next_value is detached so the critic gradient
    # flows only through value, not through the bootstrap target.
    td_target = reward + gamma * next_value.detach() * (1 - done)
    delta = td_target - value
    # Policy gradient: raise the log-probability of the taken action,
    # weighted by the (detached) TD error as the advantage estimate.
    actor_loss = -torch.log(probs[action]) * delta.detach()
    critic_loss = delta ** 2
    loss = actor_loss + critic_loss
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
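# %%
# A minimal usage sketch (illustrative): apply one update to a fresh model
# from a hand-made CartPole-like transition. The variable names here are
# hypothetical and exist only for this example.
if __name__ == "__main__":
    _model = ActorCritic(4, 2)
    _opt = optim.Adam(_model.parameters(), lr=0.001)
    _s = np.zeros(4, dtype=np.float32)
    _s2 = np.full(4, 0.01, dtype=np.float32)
    train(_model, _opt, _s, action=0, reward=1.0, next_state=_s2, done=False)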
# Worker process: each worker interacts with its own environment and applies
# Hogwild-style gradient updates directly to the shared global model.
def worker(global_model, optimizer, worker_id):
    env = gym.make("CartPole-v1")
    state, info = env.reset()
    while True:  # runs indefinitely; stop the script manually
        action_probs, _ = global_model(torch.FloatTensor(state))
        # Sample an action from the current policy distribution.
        action = np.random.choice(env.action_space.n, p=action_probs.detach().numpy())
        next_state, reward, terminated, truncated, info = env.step(action)
        done = terminated or truncated  # gym >= 0.26 splits the done flag
        train(global_model, optimizer, state, action, reward, next_state, done)
        state = next_state
        if done:
            state, _ = env.reset()
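# %%
# A hypothetical evaluation helper (not in the original script): roll out
# episodes with the greedy policy, no exploration, and report the return.
def evaluate(model, episodes=1):
    env = gym.make("CartPole-v1")
    for _ in range(episodes):
        state, _ = env.reset()
        total_reward, done = 0.0, False
        while not done:
            with torch.no_grad():
                probs, _value = model(torch.FloatTensor(state))
            action = int(torch.argmax(probs))  # greedy action
            state, reward, terminated, truncated, _ = env.step(action)
            total_reward += reward
            done = terminated or truncated
        print(f"episode return: {total_reward}")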
# %%
if __name__ == "__main__":
    global_model = ActorCritic(4, 2)
    global_model.share_memory()  # place parameters in shared memory so all processes update the same weights
    optimizer = optim.Adam(global_model.parameters(), lr=0.001)
    processes = []
    for i in range(mp.cpu_count()):  # one worker per available CPU core
        p = mp.Process(target=worker, args=(global_model, optimizer, i))
        p.start()
        processes.append(p)
    for p in processes:
        p.join()
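# A note on the design: only the model weights are truly shared here
# (Hogwild-style), since plain Adam keeps its moment estimates per process.
# Fuller A3C implementations typically use a "shared Adam" whose state
# tensors also live in shared memory, plus per-worker local models that are
# periodically synchronized with the global one.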