AaronAnima / TarGF

Official Implementation of Learning Gradient Fields for Object Rearrangement

TarGF: Learning Target Gradient Field to Rearrange Objects without Explicit Goal Specification

[Website] [Arxiv]

The agent is given examples from a target distribution and aims to rearrange objects so as to increase their likelihood under that distribution. Our key idea is to learn, from these examples via score matching, a target gradient field that indicates the fastest direction in which to increase the likelihood. We further incorporate the target gradient field into reinforcement learning or a model-based planner to tackle this task in the model-free and model-based settings, respectively.
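For readers new to score matching, here is the generic denoising score matching (DSM) objective in NumPy for a single Gaussian noise level `sigma`. It is a minimal sketch, not this repo's training code: `perturb` and `dsm_loss` are hypothetical names, and the actual TarGF score network is a PyTorch model trained on environment states. For a perturbed sample x̃ = x + σz with z ~ N(0, I), the regression target is the score of the perturbation kernel, -(x̃ - x)/σ² = -z/σ.

```python
import numpy as np


def perturb(x, sigma, rng):
    """Add Gaussian noise of std sigma to clean states x; return (x_tilde, z)."""
    z = rng.standard_normal(x.shape)
    return x + sigma * z, z


def dsm_loss(score_fn, x, sigma, rng):
    """Denoising score matching loss for one noise level (hypothetical helper).

    score_fn(x_tilde, sigma) should approximate grad_x log p_sigma(x_tilde).
    The target is -(x_tilde - x) / sigma**2 = -z / sigma. Multiplying the
    residual by sigma (i.e. comparing sigma * score with -z) is the usual
    sigma**2 weighting that keeps losses comparable across noise levels.
    """
    x_tilde, z = perturb(x, sigma, rng)
    pred = score_fn(x_tilde, sigma)
    residual = sigma * pred + z
    # Squared error summed over state dimensions, averaged over samples.
    return float(np.mean(np.sum(residual ** 2, axis=-1)))
```

As a sanity check on the target's sign and scale: if every example is the same point `x0`, the exact score `lambda xt, s: -(xt - x0) / s**2` drives the loss to zero up to rounding, while a zero score gives a loss close to the state dimension.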

The environments used in this work are:

  • Circling
  • Clustering
  • Circling + Clustering
  • Room Rearrangement
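To make "examples from a target distribution" concrete, here is a toy generator for one plausible Circling-style example: objects evenly spaced on a circle around the origin, with a random radius and rotation. This is an assumption made purely for illustration; the EbOR environments define their own target distributions, and `sample_circle_arrangement` and its `radius_range` default are hypothetical.

```python
import numpy as np


def sample_circle_arrangement(n_objects, rng, radius_range=(0.1, 0.25)):
    """Toy Circling-style example: n_objects evenly spaced on a circle.

    Hypothetical stand-in for a target example. The centre is fixed at the
    origin; the radius and the rotation (phase) are sampled at random.
    Returns an array of shape (n_objects, 2) with x/y positions.
    """
    radius = rng.uniform(*radius_range)
    phase = rng.uniform(0.0, 2.0 * np.pi)
    angles = phase + 2.0 * np.pi * np.arange(n_objects) / n_objects
    return radius * np.stack([np.cos(angles), np.sin(angles)], axis=-1)


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    # A small dataset of target examples, e.g. for a score-matching sketch.
    examples = np.stack([sample_circle_arrangement(7, rng) for _ in range(4)])
    print(examples.shape)  # (4, 7, 2)
```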

This repo is the official implementation of TarGF. The code was rebuilt from the old branch and released on 2023/05/03. The contents of this repo are as follows:

Installation

Requirements

  • Ubuntu >= 18.04
  • Anaconda3
  • python >= 3.9
  • pytorch >= 1.11.0
  • pytorch_geometric >= 2.0.0
  • pybullet >= 3.1.0
  • tensorboard >= 2.6.0
  • pytorch_fid >= 0.2.0
  • imageio >= 2.9.0

Install Global Dependencies

git clone https://github.com/AaronAnima/TarGF

cd TarGF

conda create -n targf python=3.9

conda activate targf

conda install pytorch==1.11.0 torchvision==0.12.0 torchaudio==0.11.0 cudatoolkit=11.3 -c pytorch

conda install pyg==2.0.4 -c pyg

pip install opencv-python tensorboard pytorch_fid ipdb imageio 

Install Ball Rearrangement Environment

pip install gym pybullet

cd envs

git clone https://github.com/AaronAnima/EbOR # install Example-based Object Rearrangement (EbOR) environments

cd EbOR

pip install -e .

cd ../../
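After installation, a quick way to check the installed packages against the Requirements list is a small script built on the standard library's `importlib.metadata`. This is a hypothetical helper, not part of the repo. It assumes the pip distribution names shown in `MIN_VERSIONS` (for example, `torch` for pytorch) and uses a simplified version parser that ignores pre-release and local suffixes such as `+cu113`.

```python
import sys
from importlib import metadata

# Minimum versions from the Requirements list; the keys are assumed
# pip distribution names.
MIN_VERSIONS = {
    "torch": "1.11.0",
    "torch_geometric": "2.0.0",
    "pybullet": "3.1.0",
    "tensorboard": "2.6.0",
    "pytorch-fid": "0.2.0",
    "imageio": "2.9.0",
}


def parse_version(text):
    """Turn e.g. '1.11.0+cu113' into (1, 11, 0), padded to three parts."""
    parts = []
    for piece in text.split("+")[0].split("."):
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break  # stop at suffixes like 'rc1'
            digits += ch
        if not digits:
            break
        parts.append(int(digits))
    while len(parts) < 3:
        parts.append(0)  # so that '3.1' compares like '3.1.0'
    return tuple(parts)


def check_requirements(min_versions=MIN_VERSIONS):
    """Return {name: (installed_version_or_None, required, ok)}."""
    report = {}
    for name, required in min_versions.items():
        try:
            installed = metadata.version(name)
        except metadata.PackageNotFoundError:
            report[name] = (None, required, False)
            continue
        ok = parse_version(installed) >= parse_version(required)
        report[name] = (installed, required, ok)
    return report


if __name__ == "__main__":
    print(f"python {sys.version.split()[0]} ok={sys.version_info >= (3, 9)}")
    for name, (installed, required, ok) in check_requirements().items():
        status = "ok" if ok else "MISSING or TOO OLD"
        print(f"{name:16s} installed={installed} required>={required} {status}")
```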

Install Room Rearrangement Environment

Please follow the README in this page.

If you do not need to run this experiment, you can skip this procedure.

Training

We assign an argument --workdir $workdir to each experiment. Intermediate results are saved in ../logs/${workdir}.

Training the Target Score Network

Note: To reproduce the results in the paper, please change config.n_samples = 10000 to config.n_samples = 100000 in ./configs/targf_ball_base.py

Set --mode train_gf and pass the config file of the task you want to train on. Intermediate results and the model checkpoint are saved in ./logs/${workdir}. For example, for Room Rearrangement:

python main.py --config configs/targf_room.py --workdir Room_Score --mode train_gf

You can also visualise the in-process results via TensorBoard:

tensorboard --logdir ../logs/${workdir}/tb --port 10020

Learning to control with RL and TarGF (Optional)

TarGF provides reference actions and rewards for reinforcement learning (RL); here we use SAC as the RL algorithm. This step is optional, and you can skip it if you do not need this experiment: in the Ball Rearrangement environment you can instead use the ORCA planner to control the balls. Note that ORCA only supports the Ball Rearrangement environment; for Room Rearrangement, please use TarGF(SAC).

Before RL training, please specify config.score_exp in the config file. The trainer will load the score network's checkpoint saved in ./logs/${config.score_exp} for reward estimation and for computing the gradient-based action.

Note: To reproduce the results in the paper, please change config.residual_t0 = 0.01 to config.residual_t0 = 0.1 in the config file for all the ball rearrangement experiments.
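The two roles the score network plays in RL training, a gradient-based action and a reward estimate, can be sketched in NumPy. This is a minimal illustration under stated assumptions, not the trainer's code: `reference_action` and `first_order_reward` are hypothetical names, the action is assumed to be the joint gradient normalised and scaled to a maximum speed, and the reward is assumed to be the first-order estimate log p(s') - log p(s) ≈ ⟨∇ log p(s), s' - s⟩. In the repo, `score` would come from the network loaded from ./logs/${config.score_exp}.

```python
import numpy as np


def reference_action(score, max_speed):
    """Turn a score (gradient of log-density) into a velocity-style action.

    Hypothetical helper: normalises the joint gradient over all objects and
    rescales it to max_speed, i.e. moves along the fastest-ascent direction.
    """
    norm = np.linalg.norm(score)
    if norm < 1e-12:  # vanishing gradient: stay still
        return np.zeros_like(score)
    return max_speed * score / norm


def first_order_reward(score, state, next_state):
    """First-order estimate of log p(next_state) - log p(state).

    Uses the Taylor expansion log p(s') - log p(s) ~ <grad log p(s), s' - s>,
    with `score` evaluated at `state`.
    """
    return float(np.sum(score * (next_state - state)))
```

For a toy Gaussian target N(μ, I) the score is μ - s. With s = (0, 0), μ = (3, 4) and `max_speed=1.0`, the action is (0.6, 0.8), and a step of 0.1 times that action gives a first-order reward of 0.5, close to the exact log-density increase of 0.495.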

An example command for Circling:

python main.py --config configs/targf_Circle.py --workdir Circle_SAC --mode train_rl

Evaluation

By specifying --mode test_policy, we can use main.py for evaluation. As in training, the policy type (e.g., TarGF(SAC) or TarGF(ORCA)) and the environment are specified in the config file. For instance, setting config.policy_type = 'targf_orca' in ./configs/targf_circlecluster.py evaluates TarGF(ORCA) on the CircleCluster environment. To evaluate TarGF(SAC), also specify config.policy_exp = ${rl_workdir} in the config file. An example command to evaluate TarGF(ORCA):

python main.py --config configs/targf_circlecluster.py --workdir CircleCluster_SAC_eval --mode test_policy

The metrics dict, trajectories and visualisations will be saved in ../logs/${workdir}.

Note: To reproduce the results in the paper, please change config.residual_t0 = 0.01 to config.residual_t0 = 0.1 and config.is_decay_t0_orca = False to config.is_decay_t0_orca = True in ./configs/targf_ball_base.py
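The paper-reproduction notes in this README change three fields of the ball-rearrangement config. The sketch collects them in one place and applies them to a config object. `PAPER_OVERRIDES` and `apply_overrides` are hypothetical helpers; the documented route is still to edit ./configs/targf_ball_base.py directly, and the sketch only assumes that the config exposes these fields as attributes, as the `config.x = ...` notation suggests.

```python
from types import SimpleNamespace

# Paper values from the reproduction notes (defaults in comments).
PAPER_OVERRIDES = {
    "n_samples": 100000,       # default 10000
    "residual_t0": 0.1,        # default 0.01
    "is_decay_t0_orca": True,  # default False
}


def apply_overrides(config, overrides=PAPER_OVERRIDES):
    """Set each override on a config object, refusing unknown field names."""
    for key, value in overrides.items():
        if not hasattr(config, key):
            raise AttributeError(f"config has no field {key!r}")
        setattr(config, key, value)
    return config


if __name__ == "__main__":
    # Hypothetical stand-in for the config built in ./configs/targf_ball_base.py.
    cfg = SimpleNamespace(n_samples=10000, residual_t0=0.01, is_decay_t0_orca=False)
    print(apply_overrides(cfg))
```

Rejecting unknown keys guards against a typo silently adding a new field instead of changing the intended one.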

Citation

@inproceedings{
wu2022targf,
title={Tar{GF}: Learning Target Gradient Field for Object Rearrangement},
author={Mingdong Wu and Fangwei Zhong and Yulong Xia and Hao Dong},
booktitle={Advances in Neural Information Processing Systems},
editor={Alice H. Oh and Alekh Agarwal and Danielle Belgrave and Kyunghyun Cho},
year={2022},
url={https://openreview.net/forum?id=Euv1nXN98P3}
}

Contact

If you have any suggestions or questions, please get in touch at wmingd@pku.edu.cn or zfw@pku.edu.cn.

LICENSE

TarGF has an MIT license, as found in the LICENSE file.
