dqn.py
import torch
import torch.nn as nn

import layers


class DQN(nn.Module):
    """Q-network: maps a concatenated (observation, action) vector to a scalar Q-value."""

    def __init__(self, obs_space_size, action_space_size,
                 num_hidden, num_layers, activation='ReLU'):
        super().__init__()
        layer_list = []
        for i_layer in range(num_layers):
            if len(layer_list) == 0:
                # First layer consumes the concatenated observation-action vector.
                layer_list.append(nn.Linear(obs_space_size + action_space_size, num_hidden))
                layer_list.append(layers.activation_dict[activation]())
            else:
                # Remaining hidden layers keep the same width.
                layer_list.append(nn.Linear(num_hidden, num_hidden))
                layer_list.append(layers.activation_dict[activation]())
        # Output head: a single Q-value per input.
        layer_list.append(nn.Linear(num_hidden, 1))
        self.linear_activation_stack = nn.Sequential(*layer_list)

    def forward(self, x):
        # Flatten the (batch, 1) output into a (batch,) vector of Q-values.
        logits = self.linear_activation_stack(x).flatten()
        return logits
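

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original file). It assumes the local
# `layers` module exposes an `activation_dict` with a 'ReLU' entry, as the
# constructor above expects, and uses made-up sizes purely for illustration.
if __name__ == '__main__':
    obs_dim, act_dim = 4, 2          # hypothetical environment sizes
    net = DQN(obs_space_size=obs_dim, action_space_size=act_dim,
              num_hidden=64, num_layers=2)
    # The network scores a batch of concatenated (observation, action) vectors.
    batch = torch.randn(8, obs_dim + act_dim)
    q_values = net(batch)            # shape: (8,)
    print(q_values.shape)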