azadsalam / clutr-public

The public repository for our ICML'23 paper CLUTR: Curriculum Learning via Unsupervised Task Representation Learning

Home Page: https://dl.acm.org/doi/10.5555/3618408.3618465

CLUTR: Curriculum Learning via Unsupervised Task Representation Learning

This codebase provides the implementation of CLUTR: Curriculum Learning via Unsupervised Task Representation Learning. The CLUTR algorithm is built on top of the PyTorch UED framework Dual Curriculum Design (DCD), which also includes PAIRED. The CLUTR Recurrent VAE (task_embed/clutr_RVAE) uses a PyTorch implementation of Samuel Bowman's Generating Sentences from a Continuous Space, slightly modified to use random embeddings instead of the default word embeddings.

Setup

To install the necessary dependencies, run the following commands:

conda create --name clutr python=3.8 -y
conda activate clutr
pip install six
pip install -r requirements.txt
git clone https://github.com/openai/baselines.git
cd baselines
pip install -e .
cd ..
pip install pyglet==1.5.11

Training

The scripts directory contains the scripts needed to train the VAE and the CLUTR algorithm. Descriptions of the arguments can be found in arguments.py.

Evaluating trained agents

eval.py evaluates agents on specific environments. The following command evaluates <model>.tar, located in the experiment results directory <xpid> under the base log output directory <log_dir>, for <num_episodes> episodes in each of the environments <env_name1>, <env_name2>, and <env_name3>, and writes the results as a .csv in <result_dir>.

python -m eval \
--base_path <log_dir> \
--xpid <xpid> \
--model_tar <model> \
--env_names <env_name1>,<env_name2>,<env_name3> \
--num_episodes <num_episodes> \
--result_path <result_dir>
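
When evaluating several checkpoints or environment sets, it can help to assemble this command programmatically. The following is a minimal sketch, not part of the repository: build_eval_command is a hypothetical helper, and the values in the usage section (logs, my_xpid, model, EnvA, EnvB, results) are placeholders. Only the flags shown in the command above are used.

```python
import shlex
import sys
from typing import List, Sequence


def build_eval_command(
    base_path: str,
    xpid: str,
    model_tar: str,
    env_names: Sequence[str],
    num_episodes: int,
    result_path: str,
) -> List[str]:
    """Assemble the argument list for `python -m eval` (hypothetical helper)."""
    if not env_names:
        raise ValueError("at least one environment name is required")
    return [
        sys.executable, "-m", "eval",
        "--base_path", base_path,
        "--xpid", xpid,
        "--model_tar", model_tar,
        # The command above passes the environments as one comma-separated string.
        "--env_names", ",".join(env_names),
        "--num_episodes", str(num_episodes),
        "--result_path", result_path,
    ]


if __name__ == "__main__":
    # Placeholder values; substitute your own log directory, xpid, and environments.
    cmd = build_eval_command("logs", "my_xpid", "model", ["EnvA", "EnvB"], 10, "results")
    # Print the shell-quoted command for inspection instead of running it.
    print(shlex.join(cmd))
```

The returned list can be passed to subprocess.run from the repository root once the printed command looks right.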

License:Other


Languages

Python 99.9%, Shell 0.1%