Skip to content

Commit

Permalink
adding data aug one hot conv experiments
Browse files Browse the repository at this point in the history
  • Loading branch information
RajGhugare19 committed Feb 20, 2023
1 parent a49c119 commit 62af37a
Show file tree
Hide file tree
Showing 4 changed files with 11 additions and 9 deletions.
14 changes: 8 additions & 6 deletions sac.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,13 @@
import utils

class NoiseAug(nn.Module):
    """Additive uniform-noise augmentation with an upper clamp at 1.

    Args:
        noise: Stored on the module as ``self.noise``.
            NOTE(review): ``forward`` never reads this value, so changing it
            currently has no effect on the output. Confirm whether it was
            meant to scale the sampled noise.
    """

    def __init__(self, noise=0.5):
        super().__init__()
        self.noise = noise

    def forward(self, x):
        """Return ``x`` plus fresh U[0, 1) noise, capped element-wise at 1.

        The noise tensor is sampled with the same shape and on the same
        device as ``x``, so inputs of any rank are accepted (the earlier
        2-D layout as well as the 3-D one).
        """
        # Sample from x.shape instead of unpacking a fixed number of
        # dimensions, so the module is not tied to one input rank.
        noisy = x + torch.rand(x.shape, device=x.device)
        # Only the upper bound is clamped; values below 1 pass through.
        return torch.clamp(noisy, max=1)

class Encoder(nn.Module):
Expand All @@ -29,9 +29,12 @@ def __init__(self, vocab_size, padding_id, device):
nn.ReLU(), nn.Conv1d(32, 32, kernel_size=9),
nn.ReLU())
self.apply(utils.weight_init)
self.aug = NoiseAug()

def forward(self, x):
def forward(self, x, aug=True):
x = self.embedding(x).view(-1, 70, 25)
if aug:
x = self.aug(x)
x = self.convnet(x)
x = x.view(x.shape[0], -1)
return x
Expand Down Expand Up @@ -94,7 +97,6 @@ def __init__(self, device, obs_dims, num_actions, vocab_size, padding_id,
self.policy_update_interval = policy_update_interval
self.target_update_interval = target_update_interval
self.batch_size = batch_size
self.aug = NoiseAug()

#exploration
self.entropy_coefficient = entropy_coefficient
Expand All @@ -110,7 +112,7 @@ def __init__(self, device, obs_dims, num_actions, vocab_size, padding_id,
def get_action(self, obs, step, eval=False):
with torch.no_grad():
obs = torch.LongTensor(obs).to(self.device)
obs = self.encoder(obs)
obs = self.encoder(obs, aug=False)
action_dist = self.actor(obs)
action = action_dist.sample()
if eval:
Expand Down
2 changes: 1 addition & 1 deletion scripts/cedar/1.sh
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,4 @@ rsync -a $HOME/projects/def-gberseth/$USER/RL4Chem/ $SLURM_TMPDIR/RL4Chem --excl
echo "moved code to slurm tmpdir"

singularity exec --nv --home $SLURM_TMPDIR --env WANDB_API_KEY="7da29c1c6b185d3ab2d67e399f738f0f0831abfc",REQUESTS_CA_BUNDLE="/usr/local/envs/rl4chem/lib/python3.11/site-packages/certifi/cacert.pem",HYDRA_FULL_ERROR=1 $SCRATCH/rl4chem.sif bash -c "source activate rl4chem && cd RL4Chem &&\
python train.py target=${array[SLURM_ARRAY_TASK_ID]} wandb_log=True wandb_run_name=onehotconv_max_len_25_1 seed=1 num_sub_proc=20"
python train.py target=${array[SLURM_ARRAY_TASK_ID]} wandb_log=True wandb_run_name=drqconv_max_len_25_1 seed=1 num_sub_proc=20"
2 changes: 1 addition & 1 deletion scripts/cedar/2.sh
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,4 @@ rsync -a $HOME/projects/def-gberseth/$USER/RL4Chem/ $SLURM_TMPDIR/RL4Chem --excl
echo "moved code to slurm tmpdir"

singularity exec --nv --home $SLURM_TMPDIR --env WANDB_API_KEY="7da29c1c6b185d3ab2d67e399f738f0f0831abfc",REQUESTS_CA_BUNDLE="/usr/local/envs/rl4chem/lib/python3.11/site-packages/certifi/cacert.pem",HYDRA_FULL_ERROR=1 $SCRATCH/rl4chem.sif bash -c "source activate rl4chem && cd RL4Chem &&\
python train.py target=${array[SLURM_ARRAY_TASK_ID]} wandb_log=True wandb_run_name=onehotconv_max_len_25_2 seed=2 num_sub_proc=20"
python train.py target=${array[SLURM_ARRAY_TASK_ID]} wandb_log=True wandb_run_name=drqconv_max_len_25_2 seed=2 num_sub_proc=20"
2 changes: 1 addition & 1 deletion scripts/cedar/3.sh
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,4 @@ rsync -a $HOME/projects/def-gberseth/$USER/RL4Chem/ $SLURM_TMPDIR/RL4Chem --excl
echo "moved code to slurm tmpdir"

singularity exec --nv --home $SLURM_TMPDIR --env WANDB_API_KEY="7da29c1c6b185d3ab2d67e399f738f0f0831abfc",REQUESTS_CA_BUNDLE="/usr/local/envs/rl4chem/lib/python3.11/site-packages/certifi/cacert.pem",HYDRA_FULL_ERROR=1 $SCRATCH/rl4chem.sif bash -c "source activate rl4chem && cd RL4Chem &&\
python train.py target=${array[SLURM_ARRAY_TASK_ID]} wandb_log=True wandb_run_name=onehotconv_max_len_25_3 seed=3 num_sub_proc=20"
python train.py target=${array[SLURM_ARRAY_TASK_ID]} wandb_log=True wandb_run_name=drqconv_max_len_25_3 seed=3 num_sub_proc=20"

0 comments on commit 62af37a

Please sign in to comment.