Wandb integration #25

Open · wants to merge 10 commits into master

3 changes: 3 additions & 0 deletions RED/configs/example/Figure_3_RT3D_chemostat.yaml
@@ -3,6 +3,9 @@ defaults:
- /model: RT3D_agent
- _self_

wandb_project_name: figure3-example
wandb_team: rl-oed

policy_delay: 2
initial_explore_rate: 1
explore_rate_mul: 1
3 changes: 3 additions & 0 deletions RED/configs/example/Figure_4_RT3D_chemostat.yaml
@@ -3,6 +3,9 @@ defaults:
- /model: RT3D_agent
- _self_

wandb_project_name: figure4-example
wandb_team: rl-oed

policy_delay: 2
initial_explore_rate: 1
explore_rate_mul: 1
33 changes: 33 additions & 0 deletions WANDB_LOGIN.md
@@ -0,0 +1,33 @@
## 1. Get your W&B API key
The Weights & Biases API key can be found by logging into the rl-oed team here: https://stability.wandb.io/.
You will need access to the Stability cluster first; message NeythenT on Discord for help with this.

## 2. Set the WANDB_API_KEY environment variable
Set the WANDB_API_KEY environment variable to your API key by running
```
$ export WANDB_API_KEY=<YOUR API KEY>
```
from the command line (RECOMMENDED), or
```python
os.environ["WANDB_API_KEY"] = "<YOUR API KEY>"
```
from Python.

## 3. Log in to W&B
To log in from the command line (RECOMMENDED):
```
$ wandb login --host=https://stability.wandb.io
```
or from a Python script:
```python
wandb.login(host='https://stability.wandb.io', relogin=False)
```
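
Putting the two steps together, a script can authenticate and start a run roughly the way the training scripts in this PR do. This is a minimal sketch, assuming WANDB_API_KEY is already set; the project and team names mirror the example configs, and the logged metric is a placeholder:
```python
import os

import wandb

# Assumes WANDB_API_KEY is already exported in the shell (recommended);
# otherwise uncomment the next line and paste your key in.
# os.environ["WANDB_API_KEY"] = "<YOUR API KEY>"

wandb.login(host="https://stability.wandb.io")

# Project and team names mirror the example configs in this PR;
# the logged metric is illustrative only.
run = wandb.init(project="figure3-example", entity="rl-oed")
run.log({"example_metric": 1.0})
run.finish()
```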

## 4. Running automated SLURM jobs
I suggest we add the following lines to the job script that gets pushed to GitHub, and people just copy their
API keys in. A full job script is sketched after the snippet.
```
$ export WANDB_API_KEY=<YOUR API KEY>
$ wandb login --host=https://stability.wandb.io
```
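
For reference, a complete job script with those lines added might look like the sketch below. The job name, resource requests, and training script path are illustrative placeholders, not part of this PR:
```
#!/bin/bash
#SBATCH --job-name=rt3d-train
#SBATCH --time=24:00:00
#SBATCH --cpus-per-task=8

# Paste your personal API key here before submitting.
export WANDB_API_KEY=<YOUR API KEY>
wandb login --host=https://stability.wandb.io

python examples/Figure_3_RT3D_chemostat/train_RT3D.py
```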
37 changes: 27 additions & 10 deletions examples/Figure_3_RT3D_chemostat/train_RT3D.py
@@ -1,5 +1,7 @@


import json

import math
import os
import sys
@@ -15,7 +17,7 @@
from casadi import *
from hydra.utils import instantiate
from omegaconf import DictConfig, OmegaConf

import wandb
from RED.agents.continuous_agents.rt3d import RT3D_agent
from RED.environments.chemostat.xdot_chemostat import xdot
from RED.environments.OED_env import OED_env
@@ -26,7 +28,7 @@


@hydra.main(version_base=None, config_path="../../RED/configs", config_name="example/Figure_3_RT3D_chemostat")
def train_RT3D(cfg : DictConfig):
def train_RT3D(cfg: DictConfig):
### config setup
cfg = cfg.example
print(
@@ -36,6 +38,9 @@ def train_RT3D(cfg : DictConfig):
sep="\n\n"
)

# start a new wandb run to track this script
wandb.init(project=cfg.wandb_project_name, entity=cfg.wandb_team, config=dict(cfg))

### prepare save path
os.makedirs(cfg.save_path, exist_ok=True)
print("Results will be saved in: ", cfg.save_path)
@@ -84,7 +89,7 @@ def train_RT3D(cfg : DictConfig):
size=(cfg.environment.n_parallel_experiments, n_params)
)
env.param_guesses = DM(actual_params)

### episode buffers for agent
states = [env.get_initial_RL_state_parallel() for i in range(cfg.environment.n_parallel_experiments)]
trajectories = [[] for _ in range(cfg.environment.n_parallel_experiments)]
@@ -110,20 +115,21 @@
if episode < skip_first_n_episodes:
actions = agent.get_actions(inputs, explore_rate=1, test_episode=cfg.test_episode, recurrent=True)
else:
actions = agent.get_actions(inputs, explore_rate=explore_rate, test_episode=cfg.test_episode, recurrent=True)
actions = agent.get_actions(inputs, explore_rate=explore_rate, test_episode=cfg.test_episode,
recurrent=True)
e_actions.append(actions)

### step env
outputs = env.map_parallel_step(actions.T, actual_params, continuous=True)
next_states = []
for i, obs in enumerate(outputs):
state, action = states[i], actions[i]
next_state, reward, done, _, u = obs
next_state, reward, done, _, u = obs

### set done flag
if control_interval == cfg.environment.N_control_intervals - 1 \
or np.all(np.abs(next_state) >= 1) \
or math.isnan(np.sum(next_state)):
or np.all(np.abs(next_state) >= 1) \
or math.isnan(np.sum(next_state)):
done = True

### memorize transition
@@ -134,8 +140,11 @@ def train_RT3D(cfg : DictConfig):
### log episode data
e_us[i].append(u.tolist())
next_states.append(next_state)


e_rewards[i].append(reward)
e_returns[i] += reward

states = next_states

### do not memorize the test trajectory (the last one)
@@ -146,7 +155,7 @@ def train_RT3D(cfg : DictConfig):
for trajectory in trajectories:
# check for instability
if np.all([np.all(np.abs(trajectory[i][0]) <= 1) for i in range(len(trajectory))]) \
and not math.isnan(np.sum(trajectory[-1][0])):
and not math.isnan(np.sum(trajectory[-1][0])):
agent.memory.append(trajectory)

### train agent
@@ -173,6 +182,11 @@ def train_RT3D(cfg : DictConfig):
history["us"].extend(e_us)
history["explore_rate"].append(explore_rate)

### log results to w and b
for i in range(len(e_returns)):
wandb.log({"returns": e_returns[i], "actions": np.array(e_actions).transpose(1, 0, 2)[i],
"us": e_us[i], "explore_rate": explore_rate})

print(
f"\nEPISODE: [{episode}/{total_episodes}] ({episode * cfg.environment.n_parallel_experiments} experiments)",
f"explore rate:\t{explore_rate:.2f}",
@@ -211,6 +225,8 @@ def train_RT3D(cfg : DictConfig):
conv_window=25,
)

wandb.finish()


def setup_env(cfg):
n_cores = multiprocessing.cpu_count()
@@ -219,12 +235,13 @@ def setup_env(cfg):
n_params = actual_params.size()[0]
param_guesses = actual_params
args = cfg.environment.y0, xdot, param_guesses, actual_params, cfg.environment.n_observed_variables, \
cfg.environment.n_controlled_inputs, cfg.environment.num_inputs, cfg.environment.input_bounds, \
cfg.environment.dt, cfg.environment.control_interval_time, normaliser
cfg.environment.n_controlled_inputs, cfg.environment.num_inputs, cfg.environment.input_bounds, \
cfg.environment.dt, cfg.environment.control_interval_time, normaliser
env = OED_env(*args)
env.mapped_trajectory_solver = env.CI_solver.map(cfg.environment.n_parallel_experiments, "thread", n_cores)
return env, n_params


if __name__ == '__main__':

train_RT3D()
14 changes: 12 additions & 2 deletions examples/Figure_4_RT3D_chemostat/train_RT3D.py
@@ -8,7 +8,7 @@
sys.path.append(IMPORT_PATH)

import multiprocessing

import wandb
import hydra
import numpy as np
from casadi import *
@@ -35,6 +35,9 @@ def train_RT3D(cfg : DictConfig):
sep="\n\n"
)

# start a new wandb run to track this script
wandb.init(project=cfg.wandb_project_name, entity=cfg.wandb_team, config=dict(cfg))

### prepare save path
os.makedirs(cfg.save_path, exist_ok=True)
print("Results will be saved in: ", cfg.save_path)
@@ -81,7 +84,7 @@ def train_RT3D(cfg : DictConfig):
actual_params = np.random.uniform(
low=cfg.environment.lb,
high=cfg.environment.ub,
size=(cfg.environment.n_parallel_experiments, 3)
size=(cfg.environment.n_parallel_experiments, n_params)
)
env.param_guesses = DM(actual_params)

@@ -173,6 +176,11 @@ def train_RT3D(cfg : DictConfig):
history["us"].extend(e_us)
history["explore_rate"].append(explore_rate)

### log results to w and b
for i in range(len(e_returns)):
wandb.log({"returns": e_returns[i], "actions": np.array(e_actions).transpose(1, 0, 2)[i],
"us": e_us[i], "explore_rate": explore_rate})

print(
f"\nEPISODE: [{episode}/{total_episodes}] ({episode * cfg.environment.n_parallel_experiments} experiments)",
f"explore rate:\t{explore_rate:.2f}",
@@ -211,6 +219,8 @@ def train_RT3D(cfg : DictConfig):
conv_window=25,
)

wandb.finish()


def setup_env(cfg):
n_cores = multiprocessing.cpu_count()