This repo is the official code release for the ICML 2024 conference paper:
Rethinking Transformers in Solving POMDPs
In this work, we challenge the suitability of Transformers as sequence models in Partially Observable RL by leveraging regular language and circuit complexity theories. We advocate Linear RNNs as a promising alternative.
In the paper, we compare representative models including GPT, LSTM, and LRU on three different tasks to validate our theory empirically. This codebase reproduces the experimental results from the paper.
Run the following commands.
cd pomdp-discrete
conda create -n tfporl-discrete python=3.8
conda activate tfporl-discrete
pip install -r requirements.txt
Run the following commands.
cd pomdp-continuous
conda create -n tfporl-continuous python=3.8
conda activate tfporl-continuous
pip install -r requirements.txt
If you run into any problems, please refer to the JAX installation guide.
We can only guarantee reproducibility with the environment configured as described below.
First, download the file from this link and extract it with tar -xvf the_file_name
in the ~/.mujoco
folder. Then, run the following commands.
cd defog
conda create -n tfporl-defog python=3.8.17
conda activate tfporl-defog
After that, add the following lines to your ~/.bashrc file:
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/YOUR_PATH_TO_THIS/.mujoco/mujoco210/bin
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/lib/nvidia
Remember to source ~/.bashrc to make the changes take effect.
Install D4RL by following the official D4RL installation guide.
Downgrade the dm-control and mujoco packages:
pip install mujoco==2.3.7
pip install dm-control==1.0.14
pip install torch==1.13.1+cu117 torchvision==0.14.1+cu117 torchaudio==0.13.1 --extra-index-url https://download.pytorch.org/whl/cu117
pip install -r requirements.txt
To download the original D4RL data, run:
python download_d4rl_datasets.py
After installing packages, you can run the following script to reproduce results:
cd pomdp-discrete
# for regular language tasks
python main.py \
--config_env configs/envs/regular_parity.py \
--config_env.env_name 25 \
--config_rl configs/rl/dqn_default.py \
--train_episodes 40000 \
--config_seq configs/seq_models/gpt_default.py \
--config_seq.model.seq_model_config.n_layer {n_layer} \
--config_seq.sampled_seq_len -1 \
--config_seq.model.action_embedder.hidden_size=0 \
--config_rl.config_critic.hidden_dims="()"
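As a rough illustration of what the regular language (parity) task asks of the agent (a toy sketch, not the repo's actual environment code; `parity_episode` and `guess_fn` are hypothetical names): the agent observes one random bit per step and is rewarded at the end for reporting the parity of the whole stream, so the sequence model must carry a single bit of state across the entire episode.

```python
import random

def parity_episode(length, guess_fn):
    """Toy parity task: stream `length` random bits, then reward a
    correct guess of their parity (1 if the count of 1s is odd)."""
    bits = [random.randint(0, 1) for _ in range(length)]
    target = sum(bits) % 2          # ground-truth parity label
    guess = guess_fn(bits)          # the policy's final answer
    return 1.0 if guess == target else 0.0

# A policy with perfect memory solves the task exactly.
reward = parity_episode(25, lambda bits: sum(bits) % 2)
```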
# for Passive T-maze
python main.py \
--config_env configs/envs/tmaze_passive.py \
--config_env.env_name 50 \
--config_rl configs/rl/dqn_default.py \
--train_episodes 20000 \
--config_seq configs/seq_models/lstm_default.py \
--config_seq.sampled_seq_len -1
# for Passive Visual Match
python main.py \
--config_env configs/envs/visual_match.py \
--config_env.env_name 60 \
--config_rl configs/rl/sacd_default.py \
--shared_encoder --freeze_critic \
--train_episodes 40000 \
--config_seq configs/seq_models/gpt_cnn.py \
--config_seq.sampled_seq_len -1
In the scripts, env_name
is the maximum training length of the regular language task. You can try other regular language tasks in pomdp-discrete/configs/envs/
and other sequence models in pomdp-discrete/configs/seq_models/
.
Feel free to add other regular languages in pomdp-discrete/envs/regular.py
by specifying their DFAs.
After installing packages, you can run the following script to reproduce results:
python main.py \
--config_env configs/envs/pomdps/pybullet_p.py \
--config_env.env_name cheetah \
--config_rl configs/rl/td3_default.py \
--config_seq configs/seq_models/lstm_default.py \
--config_seq.sampled_seq_len 64 \
--train_episodes 1500 \
--shared_encoder --freeze_all
In the scripts, env_name
is the control task type: ant
, walker
, cheetah
, or hopper
. You can change the POMDP type by replacing pybullet_p
with pybullet_v
, and try other sequence models in pomdp-continuous/configs/seq_models/
.
After installing the packages and data, you can run the following script to reproduce results:
cd defog
python main.py env=hopper model=dt
You can replace hopper
with halfcheetah
or walker2d
. You can also replace dt
with dlstm
or dlru
to test more sequence models.
The code is largely based on prior works:
This work is licensed under the MIT license. See the LICENSE file for details.
If you find our work useful, please consider citing:
@article{Lu2024Rethink,
  title={Rethinking Transformers in Solving POMDPs},
  author={Chenhao Lu and Ruizhe Shi and Yuyao Liu and Kaizhe Hu and Simon S. Du and Huazhe Xu},
  journal={International Conference on Machine Learning},
  year={2024}
}