Refactor rl examples - updated README #223

Closed · wants to merge 12 commits
16 changes: 12 additions & 4 deletions README.md
@@ -10,10 +10,12 @@ https://lab.mlpack.org/.

### 0. Contents

1. [Overview](#1-overview)
2. [Building the examples and usage](#2-Building-the-examples-and-usage)
3. [List of examples](#3-List-of-examples)
4. [Datasets](#4-datasets)
- [0. Contents](#0-contents)
- [1. Overview](#1-overview)
- [2. Building the examples and usage](#2-building-the-examples-and-usage)
- [3. List of examples](#3-list-of-examples)
- [4. Datasets](#4-datasets)
- [5. Setup](#5-setup)

### 1. Overview

@@ -93,3 +95,9 @@ extract all the necessary dataset in order for examples to work perfectly:
cd tools/
./download_data_set.py
```

### 5. Setup
To set up a local Jupyter environment that works with C++ using xeus-cling, run the following command:
Member commented:
But I don't think the intention was for users to run that script directly. It would be better to just use Binderhub or similar.

It's true that you could run this script, but it has a number of assumptions that may not be true for users:

  1. Users may not be using conda.
  2. Users may not be interested in the C++ notebook examples at all, but might be using the Makefile-built examples.
  3. Users may not even be interested in C++ at all and may be focusing on other languages.

So I don't think that I would want to include this in the general README; users will then attempt to run the command, and may encounter problems that may not even be relevant if they're not looking to use Jupyterlab.

I think as an alternative it may be more reasonable to comment that script a little bit better. Or, if we restructured the examples in the repository to organize them by language, it might make more sense to keep this documentation in a directory specific to the C++ notebook examples.

```sh
./script/jupyter-conda-setup.sh <environment_name>
```
32 changes: 17 additions & 15 deletions reinforcement_learning_gym/acrobot_dqn/acrobot_dqn.cpp
@@ -17,6 +17,10 @@
using namespace mlpack;
using namespace ens;

// Set up the state and action space.
constexpr size_t stateDimension = 6;
constexpr size_t actionSize = 3;

// Function to train the agent on the Acrobot-v1 gym environment.
template<typename EnvironmentType,
typename NetworkType,
@@ -49,7 +53,7 @@ void Train(gym::Environment& env,
arma::mat action = {double(agent.Action().action)};

env.step(action);
DiscreteActionEnv::State nextState;
DiscreteActionEnv<stateDimension, actionSize>::State nextState;
nextState.Data() = env.observation;

replayMethod.Store(
@@ -85,22 +89,22 @@
int main()
{
// Initializing the agent.
// Set up the state and action space.
DiscreteActionEnv::State::dimension = 6;
DiscreteActionEnv::Action::size = 3;

// Set up the network.
FFN<MeanSquaredError, GaussianInitialization> module(
MeanSquaredError(), GaussianInitialization(0, 1));
module.Add<Linear>(DiscreteActionEnv::State::dimension, 64);
module.Add<ReLULayer>();
module.Add<Linear>(64, DiscreteActionEnv::Action::size);
module.Add<Linear>(64);
module.Add<ReLU>();
module.Add<Linear>(actionSize);
SimpleDQN<> model(module);

// Set up the policy method.
GreedyPolicy<DiscreteActionEnv> policy(1.0, 1000, 0.1, 0.99);
GreedyPolicy<DiscreteActionEnv<stateDimension, actionSize>>
policy(1.0, 1000, 0.1, 0.99);

// To enable 3-step learning, we set the last parameter of the replay method as 3.
PrioritizedReplay<DiscreteActionEnv> replayMethod(64, 5000, 0.6, 3);
PrioritizedReplay<DiscreteActionEnv<stateDimension, actionSize>>
replayMethod(64, 5000, 0.6, 3);

// Set up training configurations.
TrainingConfig config;
@@ -111,7 +115,7 @@ int main()
config.DoubleQLearning() = true;

// Set up DQN agent.
QLearning<DiscreteActionEnv,
QLearning<DiscreteActionEnv<stateDimension, actionSize>,
decltype(model),
AdamUpdate,
decltype(policy),
@@ -120,7 +124,7 @@

// Preparation for training the agent
// Set up the gym training environment.
gym::Environment env("gym.kurg.org", "4040", "Acrobot-v1");
gym::Environment env("localhost", "4040", "Acrobot-v1");
Member commented:
I'm not sure of the status of gym.kurg.org, but I don't know if this is the right thing to do here, since we would now need to expect a user to be running the gym locally.

Author commented:
True, this would need the user to be running the gym locally. Running gym_tcp_api locally was the only way I could get it to work. I couldn't find any working examples using gym.kurg.org, so I assumed it's no longer functional. Also, the example in the gym_tcp_api directory uses localhost.
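
For readers following this thread: below is a minimal sketch of the client-side check these examples rely on, assuming a gym_tcp_api server is already running locally on port 4040. It uses only the `gym::Environment` calls that already appear in this diff; the include paths and the fixed test action are illustrative assumptions, not part of this PR.

```cpp
// Minimal connectivity check against a locally running gym_tcp_api server.
// Assumes the server is already listening on localhost:4040; the include
// paths and the hard-coded action are assumptions for illustration.
#include <iostream>
#include <armadillo>
#include "gym/environment.hpp"

int main()
{
  // Connect to the local server instead of gym.kurg.org.
  gym::Environment env("localhost", "4040", "CartPole-v0");

  // Reset the environment and take a single fixed step to confirm that
  // observations flow back over the TCP connection.
  env.reset();
  arma::mat action = {1.0};
  env.step(action);

  std::cout << "Observation elements: " << env.observation.n_elem
            << ", done: " << env.done << std::endl;

  env.close();
  return 0;
}
```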


// Initializing training variables.
std::vector<double> returnList;
@@ -144,7 +148,7 @@ int main()
agent.Deterministic() = true;

// Creating and setting up the gym environment for testing.
gym::Environment envTest("gym.kurg.org", "4040", "Acrobot-v1");
gym::Environment envTest("localhost", "4040", "Acrobot-v1");
envTest.monitor.start("./dummy/", true, true);

// Resets the environment.
@@ -182,7 +186,6 @@ int main()
// << totalReward << "\t Action taken: " << action;
}

envTest.close();
std::cout << envTest.url() << std::endl;

/**
@@ -194,10 +197,9 @@

// Creating and setting up the gym environment for testing.
envTest.monitor.start("./dummy/", true, true);

// Resets the environment.
envTest.reset();
envTest.render();

totalReward = 0;
totalSteps = 0;
@@ -19,6 +19,10 @@ using namespace mlpack;
using namespace ens;
using namespace gym;

// Set up the state and action space.
constexpr size_t stateDimension = 24;
constexpr size_t actionSize = 4;

template<typename EnvironmentType,
typename NetworkType,
typename UpdaterType,
@@ -48,7 +52,7 @@ void Train(gym::Environment& env,
arma::mat action = {agent.Action().action};

env.step(action);
ContinuousActionEnv::State nextState;
ContinuousActionEnv<stateDimension, actionSize>::State nextState;
nextState.Data() = env.observation;

replayMethod.Store(
@@ -94,10 +98,6 @@ void Train(gym::Environment& env,
int main()
{
// Initializing the agent.
// Set up the state and action space.
ContinuousActionEnv::State::dimension = 24;
ContinuousActionEnv::Action::size = 4;

bool usePreTrainedModel = true;

// Set up the actor and critic networks.
@@ -107,9 +107,8 @@ int main()
policyNetwork.Add<ReLU>();
policyNetwork.Add<Linear>(128);
policyNetwork.Add<ReLU>();
policyNetwork.Add<Linear>(ContinuousActionEnv::Action::size);
policyNetwork.Add<Linear>(actionSize);
policyNetwork.Add<TanH>();
policyNetwork.ResetParameters();

FFN<EmptyLoss, GaussianInitialization> qNetwork(
EmptyLoss(), GaussianInitialization(0, 0.01));
@@ -118,10 +117,11 @@
qNetwork.Add<Linear>(128);
qNetwork.Add<ReLU>();
qNetwork.Add<Linear>(1);
qNetwork.ResetParameters();


// Set up the replay method.
RandomReplay<ContinuousActionEnv> replayMethod(32, 10000);
RandomReplay<ContinuousActionEnv<stateDimension, actionSize>>
replayMethod(32, 10000);

// Set up training configurations.
TrainingConfig config;
@@ -148,14 +148,14 @@
* The default is to use the pre-trained model (usePreTrainedModel). Otherwise you can
* disable this by changing usePreTrainedModel to false and then recompiling this example.
*/
SAC<ContinuousActionEnv,
SAC<ContinuousActionEnv<stateDimension, actionSize>,
decltype(qNetwork),
decltype(policyNetwork),
AdamUpdate>
agent(config, qNetwork, policyNetwork, replayMethod);

const std::string environment = "BipedalWalker-v3";
const std::string host = "gym.kurg.org";
const std::string host = "127.0.0.1";
const std::string port = "4040";

Environment env(host, port, environment);
@@ -187,7 +187,7 @@ int main()
agent.Deterministic() = true;

// Creating and setting up the gym environment for testing.
gym::Environment envTest("gym.kurg.org", "4040", "BipedalWalker-v3");
gym::Environment envTest(host, port, environment);
envTest.monitor.start("./dummy/", true, true);

// Resets the environment.
@@ -218,7 +218,7 @@ int main()
if (envTest.done)
{
std::cout << " Total steps: " << totalSteps
<< "\\t Total reward: " << totalReward << std::endl;
<< "\t Total reward: " << totalReward << std::endl;
break;
}

48 changes: 29 additions & 19 deletions reinforcement_learning_gym/cartpole_dqn/cartpole_dqn.cpp
@@ -17,6 +17,10 @@
using namespace mlpack;
using namespace ens;

// Set up the state and action space.
constexpr size_t stateDimension = 4;
constexpr size_t actionSize = 2;

template<typename EnvironmentType,
typename NetworkType,
typename UpdaterType,
@@ -45,7 +49,7 @@ void Train(
arma::mat action = {double(agent.Action().action)};

env.step(action);
DiscreteActionEnv::State nextState;
DiscreteActionEnv<stateDimension, actionSize>::State nextState;
nextState.Data() = env.observation;

replayMethod.Store(
@@ -75,26 +79,30 @@ void Train(
{
std::cout << "Avg return in last " << consecutiveEpisodes
<< " episodes: " << averageReturn
<< "\\t Episode return: " << episodeReturn
<< "\\t Total steps: " << agent.TotalSteps() << std::endl;
<< "\t Episode return: " << episodeReturn
<< "\t Total steps: " << agent.TotalSteps() << std::endl;
}
}
}

int main()
{
// Initializing the agent.
// Set up the state and action space.
DiscreteActionEnv::State::dimension = 4;
DiscreteActionEnv::Action::size = 2;
// Set up the network.
SimpleDQN<> model(DiscreteActionEnv::State::dimension,
128,
32,
DiscreteActionEnv::Action::size);
FFN<MeanSquaredError, GaussianInitialization> network(
MeanSquaredError(), GaussianInitialization(0, 1));
network.Add<Linear>(128);
network.Add<ReLU>();
network.Add<Linear>(actionSize);

SimpleDQN<> model(network);

// Set up the policy and replay method.
GreedyPolicy<DiscreteActionEnv> policy(1.0, 1000, 0.1, 0.99);
RandomReplay<DiscreteActionEnv> replayMethod(32, 2000);
GreedyPolicy<DiscreteActionEnv<stateDimension, actionSize>>
policy(1.0, 1000, 0.1, 0.99);
RandomReplay<DiscreteActionEnv<stateDimension, actionSize>>
replayMethod(32, 2000);

// Set up training configurations.
TrainingConfig config;
config.StepSize() = 0.001;
@@ -103,12 +111,16 @@ int main()
config.ExplorationSteps() = 100;
config.DoubleQLearning() = false;
config.StepLimit() = 200;

// Set up DQN agent.
QLearning<DiscreteActionEnv, decltype(model), AdamUpdate, decltype(policy)>
QLearning<DiscreteActionEnv<stateDimension, actionSize>,
decltype(model),
AdamUpdate, decltype(policy)>
agent(config, model, policy, replayMethod);

// Preparation for training the agent.
// Set up the gym training environment.
gym::Environment env("gym.kurg.org", "4040", "CartPole-v0");
gym::Environment env("localhost", "4040", "CartPole-v0");

// Initializing training variables.
std::vector<double> returnList;
@@ -133,7 +145,7 @@ int main()
agent.Deterministic() = true;

// Creating and setting up the gym environment for testing.
gym::Environment envTest("gym.kurg.org", "4040", "CartPole-v0");
gym::Environment envTest("localhost", "4040", "CartPole-v0");
envTest.monitor.start("./dummy/", true, true);

// Resets the environment.
@@ -162,7 +174,7 @@ int main()
if (envTest.done)
{
std::cout << " Total steps: " << totalSteps
<< "\\t Total reward: " << totalReward << std::endl;
<< "\t Total reward: " << totalReward << std::endl;
break;
}

@@ -171,7 +183,6 @@ int main()
// << totalReward << "\\t Action taken: " << action;
}

envTest.close();
std::cout << envTest.url() << std::endl;

// A little more training...
@@ -191,7 +202,6 @@

// Resets the environment.
envTest.reset();
envTest.render();

totalReward = 0;
totalSteps = 0;
@@ -215,7 +225,7 @@
if (envTest.done)
{
std::cout << " Total steps: " << totalSteps
<< "\\t Total reward: " << totalReward << std::endl;
<< "\t Total reward: " << totalReward << std::endl;
break;
}
