diff --git a/.gitignore b/.gitignore
index 97e4948..cbeaa88 100644
--- a/.gitignore
+++ b/.gitignore
@@ -13,7 +13,7 @@ MUJOCO_LOG.TXT
 saved/
 
 #virtual env
-env_rl4chem/
+env_*
 
 #notebooks
 .ipynb_checkpoints/
@@ -29,6 +29,7 @@ data/
 #package dependencies
 dockstring/
 openbabel.git
+oracle/
 
 #local experiments
 store_room/
diff --git a/cfgs/config.yaml b/cfgs/config.yaml
index 82074da..71c9766 100644
--- a/cfgs/config.yaml
+++ b/cfgs/config.yaml
@@ -1,4 +1,4 @@
-agent: 'sac'
+agent: 'dqn'
 
 #common
 id: 'docking'
@@ -30,13 +30,10 @@ action_dtype:
 
 #learning
 gamma: 0.99
-tau: 0.01
-target_update_interval: 1
-policy_update_interval: 2
-lr: {'actor':3e-4, 'critic':1e-3, 'alpha':1e-3}
-
-#exploration
-entropy_coefficient: 0.1
+tau: 1.0
+update_interval: 10
+target_update_interval: 500
+lr: {'q':2.5e-4}
 
 #hidden_dims and layers
 hidden_dims: 256
diff --git a/cfgs/mbac.yaml b/cfgs/mbac.yaml
deleted file mode 100644
index 21b5ec3..0000000
--- a/cfgs/mbac.yaml
+++ /dev/null
@@ -1,73 +0,0 @@
-#common
-id: 'docking'
-device: 'cuda'
-seed: 1
-
-#environment specific
-target: 'fa7'
-selfies_enc_type: 'index'
-max_selfie_length: 22
-vina_program: 'qvina2'
-temp_dir: 'tmp'
-exhaustiveness: 1
-num_sub_proc: 12
-num_cpu_dock: 1
-num_modes: 10
-timeout_gen3d: 30
-timeout_dock: 100
-
-#data
-num_train_steps: 100000
-env_buffer_size: 100000
-explore_molecules: 250
-parallel_molecules: 250
-batch_size: 256
-obs_dtype:
-action_dtype:
-
-#learning
-lr: {'actor':1e-4, 'reward':1e-4, 'alpha':1e-3}
-max_grad_norm: 0.5
-vocab_size:
-pad_idx:
-
-#Actor
-actor:
-  _target_: mbac.Actor
-  vocab_size: ${vocab_size}
-  pad_idx: ${pad_idx}
-  device: ${device}
-  output_size: ${vocab_size}
-  num_layers: 3
-  hidden_size: 256
-  embedding_size: 32
-
-#Reward
-reward:
-  _target_: mbac.Reward
-  vocab_size: ${vocab_size}
-  pad_idx: ${pad_idx}
-  device: ${device}
-  num_layers: 3
-  hidden_size: 256
-  embedding_size: 32
-
-#evaluation
-eval_episode_interval: 10000
-num_eval_episodes: 5
-
-#logging
-wandb_log: False
-wandb_entity: 'raj19'
-wandb_run_name: 'mbac-disc'
-agent_log_interval: 100
-
-#saving
-save_snapshot: False
-save_snapshot_interval: 100000
-
-hydra:
-  run:
-    dir: ./local_exp/${now:%Y.%m.%d}/${now:%H.%M.%S}_${seed}
-  job:
-    chdir: False
\ No newline at end of file
diff --git a/dqn.py b/dqn.py
new file mode 100644
index 0000000..6b980ec
--- /dev/null
+++ b/dqn.py
@@ -0,0 +1,108 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.distributions as td
+import numpy as np
+import random
+import wandb
+
+import utils
+
+class QNetwork(nn.Module):
+    def __init__(self, input_dims, hidden_dims, output_dims):
+        super().__init__()
+        self.network = nn.Sequential(
+            nn.Linear(input_dims, hidden_dims),
+            nn.ReLU(), nn.Linear(hidden_dims, hidden_dims), nn.LayerNorm(hidden_dims),
+            nn.ReLU(), nn.Linear(hidden_dims, output_dims))
+
+    def forward(self, x):
+        return self.network(x)
+
+class DQNAgent:
+    def __init__(self, device, obs_dims, num_actions,
+                 gamma, tau, update_interval, target_update_interval, lr, batch_size,
+                 hidden_dims, wandb_log, log_interval):
+
+        self.device = device
+
+        #learning
+        self.gamma = gamma
+        self.tau = tau
+        self.update_interval = update_interval
+        self.target_update_interval = target_update_interval
+        self.batch_size = batch_size
+
+        #logging
+        self.wandb_log = wandb_log
+        self.log_interval = log_interval
+
+        self._init_networks(obs_dims, num_actions, hidden_dims)
+        self._init_optims(lr)
+
+    def get_action(self, obs, eval=False):
+        with torch.no_grad():
+            obs = torch.tensor(obs, dtype=torch.float32, device=self.device)
+            q_values = self.q(obs)
+            action = torch.argmax(q_values)
+        return action.cpu().numpy()
+
+    def _init_networks(self, obs_dims, num_actions, hidden_dims):
+        self.q = QNetwork(obs_dims, hidden_dims, num_actions).to(self.device)
+        self.q_target = QNetwork(obs_dims, hidden_dims, num_actions).to(self.device)
+        utils.hard_update(self.q_target, self.q)
+
+    def _init_optims(self, lr):
+        self.q_opt = torch.optim.Adam(self.q.parameters(), lr=lr["q"])
+
+    def get_save_dict(self):
+        return {
+            "q": self.q.state_dict(),
+            "q_target":self.q_target.state_dict(),
+        }
+
+    def load_save_dict(self, saved_dict):
+        self.q.load_state_dict(saved_dict["q"])
+        self.q_target.load_state_dict(saved_dict["q_target"])
+
+    def update(self, buffer, step):
+        metrics = dict()
+        if step % self.log_interval == 0 and self.wandb_log:
+            log = True
+        else:
+            log = False
+
+        if step % self.update_interval == 0:
+            state_batch, action_batch, reward_batch, next_state_batch, done_batch, time_batch = buffer.sample(self.batch_size)
+            state_batch = torch.tensor(state_batch, dtype=torch.float32, device=self.device)
+            next_state_batch = torch.tensor(next_state_batch, dtype=torch.float32, device=self.device)
+            action_batch = torch.tensor(action_batch, dtype=torch.long, device=self.device)
+            reward_batch = torch.tensor(reward_batch, dtype=torch.float32, device=self.device)
+            done_batch = torch.tensor(done_batch, dtype=torch.float32, device=self.device)
+            discount_batch = self.gamma*(1-done_batch)
+
+            with torch.no_grad():
+                target_max, _ = self.q_target(next_state_batch).max(dim=1)
+                td_target = reward_batch + target_max * discount_batch  # discount_batch already applies gamma and the (1 - done) mask
+
+            old_val = self.q(state_batch).gather(1, action_batch).squeeze()
+
+            loss = F.mse_loss(td_target, old_val)
+            self.q_opt.zero_grad()
+            loss.backward()
+            self.q_opt.step()
+
+            if log:
+                metrics['mean_q_target'] = torch.mean(td_target).item()
+                metrics['max_reward'] = torch.max(reward_batch).item()
+                metrics['min_reward'] = torch.min(reward_batch).item()
+                metrics['variance_q_target'] = torch.var(td_target).item()
+                metrics['min_q_target'] = torch.min(td_target).item()
+                metrics['max_q_target'] = torch.max(td_target).item()
+                metrics['critic_loss'] = loss.item()
+
+        if step % self.target_update_interval == 0:
+            utils.soft_update(self.q_target, self.q, self.tau)
+
+        if log:
+            wandb.log(metrics, step=step)
\ No newline at end of file
diff --git a/train.py b/train.py
index 472f41e..2c7d8d8 100644
--- a/train.py
+++ b/train.py
@@ -41,6 +41,13 @@ def make_agent(env, device, cfg):
         agent = SacAgent(device, obs_dims, num_actions, cfg.gamma, cfg.tau,
                          cfg.policy_update_interval, cfg.target_update_interval, cfg.lr, cfg.batch_size,
                          cfg.entropy_coefficient, cfg.hidden_dims, cfg.wandb_log, cfg.agent_log_interval)
+
+    elif cfg.agent == 'dqn':
+        from dqn import DQNAgent
+        agent = DQNAgent(device, obs_dims, num_actions, cfg.gamma, cfg.tau,
+                         cfg.update_interval, cfg.target_update_interval, cfg.lr, cfg.batch_size,
+                         cfg.hidden_dims, cfg.wandb_log, cfg.agent_log_interval)
+
     else:
         raise NotImplementedError
     return agent, env_buffer, fresh_env_buffer, docking_buffer
@@ -97,14 +104,20 @@ def explore(cfg, train_env, env_buffer, fresh_env_buffer, docking_buffer):
 
     return parallel_reward_batch
 
-def collect_molecule(env, agent, fresh_env_buffer):
+def collect_molecule(env, agent, fresh_env_buffer, train_step, num_train_steps):
     state, done, t = env.reset(), False, 0
     while not done:
-        action = agent.get_action(state)
+        epsilon = utils.linear_schedule(1, 0.05, 0.5 * num_train_steps, train_step+t)
+        if random.random() < epsilon:
+            action = np.random.randint(env.num_actions)
+        else:
+            action = agent.get_action(state)
         next_state, reward, done, info = env.step(action)
         fresh_env_buffer.push((state, action, reward, next_state, done, t))
        t += 1
         state = next_state
+
+    info['episode']['last epsilon'] = epsilon
     return info['episode']
 
 def train(cfg):
@@ -128,8 +141,7 @@
     unique_molecule_counter = 0
 
     while train_step < cfg.num_train_steps:
-
-        episode_info = collect_molecule(train_env, agent, fresh_env_buffer)
+        episode_info = collect_molecule(train_env, agent, fresh_env_buffer, train_step, cfg.num_train_steps)
         molecule_counter += 1
         if docking_buffer[episode_info['smiles']] is not None:
             fresh_env_buffer.update_last_episode_reward(docking_buffer[episode_info['smiles']])
@@ -137,7 +149,7 @@
             docking_buffer[episode_info['smiles']] = 0
             train_env._add_smiles_to_batch(episode_info['smiles'])
             unique_molecule_counter += 1
-
+
         for _ in range(episode_info['l']):
             agent.update(env_buffer, train_step)
             train_step += 1
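
Note on helpers not shown in this patch: dqn.py and train.py call utils.hard_update, utils.soft_update and utils.linear_schedule, which live in the repository's utils.py and are not part of the diff. Below is a minimal sketch of what these helpers are assumed to look like, inferred only from the call sites above; the argument names are illustrative, not the repository's.

import torch

def hard_update(target_net, source_net):
    # Copy the online network's parameters into the target network (used once at init).
    target_net.load_state_dict(source_net.state_dict())

def soft_update(target_net, source_net, tau):
    # Polyak averaging of target parameters; with tau=1.0 (as set in cfgs/config.yaml)
    # this reduces to a hard copy every target_update_interval steps.
    with torch.no_grad():
        for tp, sp in zip(target_net.parameters(), source_net.parameters()):
            tp.data.mul_(1.0 - tau).add_(tau * sp.data)

def linear_schedule(start_e, end_e, duration, t):
    # Linearly anneal epsilon from start_e to end_e over `duration` steps, then hold at end_e.
    slope = (end_e - start_e) / duration
    return max(end_e, start_e + slope * t)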