Adding DQN as agent
RajGhugare19 committed Mar 10, 2023
1 parent e51c92a commit 9f2ccde
Showing 5 changed files with 132 additions and 87 deletions.
3 changes: 2 additions & 1 deletion .gitignore
@@ -13,7 +13,7 @@ MUJOCO_LOG.TXT
saved/

#virtual env
-env_rl4chem/
+env_*

#notebooks
.ipynb_checkpoints/
@@ -29,6 +29,7 @@ data/
#package dependencies
dockstring/
openbabel.git
+oracle/

#local experiments
store_room/
13 changes: 5 additions & 8 deletions cfgs/config.yaml
@@ -1,4 +1,4 @@
-agent: 'sac'
+agent: 'dqn'

#common
id: 'docking'
@@ -30,13 +30,10 @@ action_dtype:

#learning
gamma: 0.99
-tau: 0.01
-target_update_interval: 1
-policy_update_interval: 2
-lr: {'actor':3e-4, 'critic':1e-3, 'alpha':1e-3}
-
-#exploration
-entropy_coefficient: 0.1
+tau: 1.0
+update_interval: 10
+target_update_interval: 500
+lr: {'q':2.5e-4}

#hidden_dims and layers
hidden_dims: 256
73 changes: 0 additions & 73 deletions cfgs/mbac.yaml

This file was deleted.

108 changes: 108 additions & 0 deletions dqn.py
@@ -0,0 +1,108 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributions as td
import numpy as np
import random
import wandb

import utils

class QNetwork(nn.Module):
    def __init__(self, input_dims, hidden_dims, output_dims):
        super().__init__()
        self.network = nn.Sequential(
            nn.Linear(input_dims, hidden_dims),
            nn.ReLU(), nn.Linear(hidden_dims, hidden_dims), nn.LayerNorm(hidden_dims),
            nn.ReLU(), nn.Linear(hidden_dims, output_dims))

    def forward(self, x):
        return self.network(x)

class DQNAgent:
    def __init__(self, device, obs_dims, num_actions,
                 gamma, tau, update_interval, target_update_interval, lr, batch_size,
                 hidden_dims, wandb_log, log_interval):

        self.device = device

        #learning
        self.gamma = gamma
        self.tau = tau
        self.update_interval = update_interval
        self.target_update_interval = target_update_interval
        self.batch_size = batch_size

        #logging
        self.wandb_log = wandb_log
        self.log_interval = log_interval

        self._init_networks(obs_dims, num_actions, hidden_dims)
        self._init_optims(lr)

    def get_action(self, obs, eval=False):
        #greedy action selection; epsilon-greedy exploration is applied by the caller (see train.py)
        with torch.no_grad():
            obs = torch.tensor(obs, dtype=torch.float32, device=self.device)
            q_values = self.q(obs)
            action = torch.argmax(q_values)
        return action.cpu().numpy()

    def _init_networks(self, obs_dims, num_actions, hidden_dims):
        self.q = QNetwork(obs_dims, hidden_dims, num_actions).to(self.device)
        self.q_target = QNetwork(obs_dims, hidden_dims, num_actions).to(self.device)
        utils.hard_update(self.q_target, self.q)

    def _init_optims(self, lr):
        self.q_opt = torch.optim.Adam(self.q.parameters(), lr=lr["q"])

    def get_save_dict(self):
        return {
            "q": self.q.state_dict(),
            "q_target": self.q_target.state_dict(),
        }

    def load_save_dict(self, saved_dict):
        self.q.load_state_dict(saved_dict["q"])
        self.q_target.load_state_dict(saved_dict["q_target"])

    def update(self, buffer, step):
        metrics = dict()
        if step % self.log_interval == 0 and self.wandb_log:
            log = True
        else:
            log = False

        if step % self.update_interval == 0:
            state_batch, action_batch, reward_batch, next_state_batch, done_batch, time_batch = buffer.sample(self.batch_size)
            state_batch = torch.tensor(state_batch, dtype=torch.float32, device=self.device)
            next_state_batch = torch.tensor(next_state_batch, dtype=torch.float32, device=self.device)
            action_batch = torch.tensor(action_batch, dtype=torch.long, device=self.device)
            reward_batch = torch.tensor(reward_batch, dtype=torch.float32, device=self.device)
            done_batch = torch.tensor(done_batch, dtype=torch.float32, device=self.device)
            discount_batch = self.gamma*(1-done_batch)

            with torch.no_grad():
                target_max, _ = self.q_target(next_state_batch).max(dim=1)
                #discount_batch already folds in gamma and the (1 - done) mask, so gamma is not applied again here
                td_target = reward_batch + target_max * discount_batch

            #view(-1, 1) gives gather the (batch, 1) index shape it expects
            old_val = self.q(state_batch).gather(1, action_batch.view(-1, 1)).squeeze(1)

            loss = F.mse_loss(td_target, old_val)
            self.q_opt.zero_grad()
            loss.backward()
            self.q_opt.step()

            if log:
                metrics['mean_q_target'] = torch.mean(td_target).item()
                metrics['max_reward'] = torch.max(reward_batch).item()
                metrics['min_reward'] = torch.min(reward_batch).item()
                metrics['variance_q_target'] = torch.var(td_target).item()
                metrics['min_q_target'] = torch.min(td_target).item()
                metrics['max_q_target'] = torch.max(td_target).item()
                metrics['critic_loss'] = loss.item()

        if step % self.target_update_interval == 0:
            utils.soft_update(self.q_target, self.q, self.tau)

        if log:
            wandb.log(metrics, step=step)
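
The DQNAgent above calls utils.hard_update and utils.soft_update, which are not part of this commit. A minimal sketch of what these helpers conventionally look like (an assumption about the repository's utils module, not code from this diff):

def hard_update(target, source):
    #copy the online network's parameters into the target network
    target.load_state_dict(source.state_dict())

def soft_update(target, source, tau):
    #Polyak averaging: theta_target <- tau * theta_source + (1 - tau) * theta_target
    for target_param, source_param in zip(target.parameters(), source.parameters()):
        target_param.data.copy_(tau * source_param.data + (1.0 - tau) * target_param.data)

With tau: 1.0 and target_update_interval: 500 from cfgs/config.yaml, the soft update reduces to a full copy of the online Q-network into the target network every 500 steps.
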
22 changes: 17 additions & 5 deletions train.py
@@ -41,6 +41,13 @@ def make_agent(env, device, cfg):
        agent = SacAgent(device, obs_dims, num_actions, cfg.gamma, cfg.tau,
                         cfg.policy_update_interval, cfg.target_update_interval, cfg.lr, cfg.batch_size, cfg.entropy_coefficient,
                         cfg.hidden_dims, cfg.wandb_log, cfg.agent_log_interval)
+
+    elif cfg.agent == 'dqn':
+        from dqn import DQNAgent
+        agent = DQNAgent(device, obs_dims, num_actions, cfg.gamma, cfg.tau,
+                         cfg.update_interval, cfg.target_update_interval, cfg.lr, cfg.batch_size,
+                         cfg.hidden_dims, cfg.wandb_log, cfg.agent_log_interval)
+
    else:
        raise NotImplementedError
    return agent, env_buffer, fresh_env_buffer, docking_buffer
@@ -97,14 +104,20 @@ def explore(cfg, train_env, env_buffer, fresh_env_buffer, docking_buffer):

    return parallel_reward_batch

-def collect_molecule(env, agent, fresh_env_buffer):
+def collect_molecule(env, agent, fresh_env_buffer, train_step, num_train_steps):
    state, done, t = env.reset(), False, 0
    while not done:
-        action = agent.get_action(state)
+        epsilon = utils.linear_schedule(1, 0.05, 0.5 * num_train_steps, train_step+t)
+        if random.random() < epsilon:
+            action = np.random.randint(env.num_actions)
+        else:
+            action = agent.get_action(state)
        next_state, reward, done, info = env.step(action)
        fresh_env_buffer.push((state, action, reward, next_state, done, t))
        t += 1
        state = next_state
+
+    info['episode']['last epsilon'] = epsilon
    return info['episode']

def train(cfg):
@@ -128,16 +141,15 @@ def train(cfg):
    unique_molecule_counter = 0
    while train_step < cfg.num_train_steps:

-        episode_info = collect_molecule(train_env, agent, fresh_env_buffer)
+        episode_info = collect_molecule(train_env, agent, fresh_env_buffer, train_step, cfg.num_train_steps)
        molecule_counter += 1
        if docking_buffer[episode_info['smiles']] is not None:
            fresh_env_buffer.update_last_episode_reward(docking_buffer[episode_info['smiles']])
        else:
            docking_buffer[episode_info['smiles']] = 0
            train_env._add_smiles_to_batch(episode_info['smiles'])
            unique_molecule_counter += 1

        for _ in range(episode_info['l']):
            agent.update(env_buffer, train_step)
            train_step += 1
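
collect_molecule now relies on utils.linear_schedule for epsilon-greedy exploration, and that helper is also not shown in this diff. A plausible sketch, assuming the common linear-annealing form implied by the call linear_schedule(1, 0.05, 0.5 * num_train_steps, train_step + t):

def linear_schedule(start_e, end_e, duration, t):
    #anneal epsilon linearly from start_e to end_e over `duration` steps, then hold it constant
    slope = (end_e - start_e) / duration
    return max(end_e, start_e + slope * t)

Under this assumption, epsilon decays from 1.0 to 0.05 over the first half of training and stays at 0.05 afterwards, so later molecules are collected almost greedily.
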
