hari-sikchi / SMORe

Official repository for the paper "Score Models for Offline Goal-Conditioned Reinforcement Learning" (ICLR 2024)

Home Page: https://hari-sikchi.github.io/smore/


SMORe

Official JAX codebase for the ICLR 2024 paper SMORe: Score Models for Offline Goal-Conditioned Reinforcement Learning

Harshit Sikchi¹, Rohan Chitnis², Ahmed Touati², Alborz Geramifard², Amy Zhang¹,², Scott Niekum³

¹UT Austin

²Meta AI

³UMass Amherst


Paper

How to run the code

Install dependencies

Create an empty conda environment and follow the commands below.

conda create -n smore python=3.9

conda activate smore

conda install -c conda-forge cudnn

pip install --upgrade pip

# Install ONE of the JAX versions below, depending on your CUDA version
## 1. CUDA 12 installation
pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

## 2. CUDA 11 installation
pip install --upgrade "jax[cuda11_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html


pip install -r requirements.txt
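After installing, a quick sanity check can confirm that Python and JAX are set up as expected. The sketch below is not part of the SMORe codebase; the helper name `check_environment` is hypothetical. It only uses `importlib.util.find_spec` and `jax.devices()`, and a CUDA-enabled JAX install should list GPU devices rather than only the CPU.

```python
import importlib.util
import sys


def check_environment():
    """Hypothetical post-install check: Python version plus the devices JAX can see."""
    report = {
        "python": f"{sys.version_info.major}.{sys.version_info.minor}",
        "jax_installed": importlib.util.find_spec("jax") is not None,
        "devices": [],
    }
    if report["jax_installed"]:
        import jax  # imported lazily so the check still runs when JAX is missing

        # jax.devices() lists the backends JAX will use; with a working CUDA
        # build this should include GPU devices, not only the CPU.
        report["devices"] = [str(d) for d in jax.devices()]
    return report


if __name__ == "__main__":
    info = check_environment()
    print(f"Python {info['python']} (the setup above uses 3.9)")
    if not info["jax_installed"]:
        print("JAX not found; rerun the pip install step matching your CUDA version.")
    else:
        print("JAX devices:", info["devices"])
```

If only a CPU device shows up, the JAX wheel most likely does not match the installed CUDA version.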

Offline data

The offline datasets can be downloaded from the Google Drive link "WGCSL offline data"; they were released by the prior work WGCSL. Extract the offline data into root-folder/offline_data/*.
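To confirm the extraction landed where the training script expects it, a small check like the one below can help. It assumes "root-folder" means the repository root and makes no assumption about the dataset file format; `list_offline_data` is a hypothetical helper, not a function from this codebase.

```python
from pathlib import Path


def list_offline_data(repo_root="."):
    """Hypothetical helper: sorted file paths under <repo_root>/offline_data."""
    data_dir = Path(repo_root) / "offline_data"
    if not data_dir.is_dir():
        raise FileNotFoundError(
            f"{data_dir} not found; extract the WGCSL offline data there first."
        )
    # rglob("*") also walks nested folders, since the archive layout may vary.
    return sorted(
        p.relative_to(data_dir).as_posix() for p in data_dir.rglob("*") if p.is_file()
    )


if __name__ == "__main__":
    try:
        files = list_offline_data()
        print(f"Found {len(files)} files in offline_data/")
    except FileNotFoundError as err:
        print(err)
```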

Example training code

Locomotion

python train_offline_smore.py --double=True --env_name=halfcheetah-medium-v2 --config=configs/gcrl_config.py --eval_episodes=10 --eval_interval=5000 --beta=0.8 --loss_type=<'smore_stable'/'smore'> --exp_name=<exp_name>

Manipulation

python train_offline_smore.py --double=True --env_name=SawyerReach --config=configs/gcrl_config.py --eval_episodes=10 --eval_interval=5000 --beta=0.8 --loss_type=<'smore_stable'/'smore'> --exp_name=<exp_name>
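Both example commands share the same flags and differ only in `--env_name`, `--loss_type` and `--exp_name`. The sketch below builds those commands for several environments and both loss types, printing them by default (`dry_run=True`) and running them sequentially otherwise. The helper names and the `exp_name` scheme are hypothetical; the flag values are copied from the examples above, not tuned recommendations.

```python
import shlex
import subprocess

LOSS_TYPES = ("smore_stable", "smore")


def build_command(env_name, loss_type, exp_name, beta=0.8):
    """Hypothetical helper: argv for one run, mirroring the example commands' flags."""
    if loss_type not in LOSS_TYPES:
        raise ValueError(f"unknown loss_type: {loss_type!r}")
    flags = {
        "double": "True",
        "env_name": env_name,
        "config": "configs/gcrl_config.py",
        "eval_episodes": 10,
        "eval_interval": 5000,
        "beta": beta,
        "loss_type": loss_type,
        "exp_name": exp_name,
    }
    return ["python", "train_offline_smore.py"] + [f"--{k}={v}" for k, v in flags.items()]


def run_sweep(envs, loss_types=LOSS_TYPES, dry_run=True):
    """Print (dry_run=True) or run every env x loss_type combination, one after another."""
    commands = []
    for env_name in envs:
        for loss_type in loss_types:
            # Hypothetical naming scheme; the README leaves exp_name up to you.
            cmd = build_command(env_name, loss_type, exp_name=f"{env_name}_{loss_type}")
            commands.append(cmd)
            if dry_run:
                print(shlex.join(cmd))
            else:
                subprocess.run(cmd, check=True)  # stop the sweep if a run fails
    return commands


if __name__ == "__main__":
    run_sweep(["halfcheetah-medium-v2", "SawyerReach"])
```

Running sequentially keeps the sketch simple; on a multi-GPU machine one would more likely launch runs in parallel with a job scheduler.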

Acknowledgement and Reference

This codebase builds upon the Extreme Q-Learning and Implicit Q-Learning codebases.


License: MIT License

