xingdi-eric-yuan / gata

GATA replication

Graph-aided Transformer Agent (GATA) Replication

Dependency Installation

Install Python 3.7 or above and run the following commands.

$ pip install -r requirements.txt
# if developing, also install the development dependencies
$ pip install -r requirements-dev.txt
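Before installing, you may want to confirm that the interpreter meets the Python 3.7 requirement. This small check is a sketch and not part of the repository; the name python_is_supported is hypothetical.

```python
import sys

# Minimum interpreter version stated in this README.
MIN_PYTHON = (3, 7)


def python_is_supported(version_info=sys.version_info) -> bool:
    """Return True if the (major, minor) part of version_info meets MIN_PYTHON."""
    return tuple(version_info[:2]) >= MIN_PYTHON


if __name__ == "__main__":
    if not python_is_supported():
        sys.exit(f"Python {MIN_PYTHON[0]}.{MIN_PYTHON[1]}+ is required, found {sys.version.split()[0]}")
    print("Python version OK")
```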

Observation Generation

First, download the training data by following the instructions here, then unzip it under data/obs_gen.0.1.
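The unzip step can also be done from Python. This is a minimal sketch assuming you run it from the repository root and that the download is a regular .zip archive; the helper name extract_dataset and the archive file name in the usage are hypothetical. The archive's internal layout is not described here, so check that the data files land directly under data/obs_gen.0.1 rather than in an extra nested folder.

```python
import zipfile
from pathlib import Path


def extract_dataset(archive, target="data/obs_gen.0.1"):
    """Unzip `archive` into `target` (created if needed) and return the archive's member names."""
    target_dir = Path(target)
    target_dir.mkdir(parents=True, exist_ok=True)
    with zipfile.ZipFile(archive) as zf:
        zf.extractall(target_dir)
        return zf.namelist()
```

For example, extract_dataset("obs_gen.zip") for this step, or extract_dataset("rl.zip", "data/rl.0.2") for the reinforcement learning data.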

# train graph updater via observation generation with one GPU
$ python +pl_trainer.gpus=1
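The arguments after the script look like Hydra command-line overrides (an assumption inferred from the syntax): key=value replaces a value that already exists in the config, while a leading + adds a key the config does not define yet, which is presumably why pl_trainer.gpus is prefixed with +. The hypothetical helper below only illustrates how dotted overrides map onto a nested config; it is not Hydra's parser and handles just a small subset of its value types.

```python
import copy


def _parse_value(raw: str):
    """Best-effort conversion of an override value: bool, int, float, else string."""
    lowered = raw.lower()
    if lowered in ("true", "false"):
        return lowered == "true"
    for cast in (int, float):
        try:
            return cast(raw)
        except ValueError:
            pass
    return raw


def apply_overrides(config: dict, overrides: list) -> dict:
    """Return a copy of `config` with dotted key=value overrides applied.

    "+a.b=v" adds a key that must not exist yet; "a.b=v" replaces a key that must exist.
    """
    result = copy.deepcopy(config)
    for item in overrides:
        append = item.startswith("+")
        key, _, raw = (item[1:] if append else item).partition("=")
        *parents, leaf = key.split(".")
        node = result
        for part in parents:
            # Adding may create intermediate groups; overriding requires them to exist.
            node = node.setdefault(part, {}) if append else node[part]
        if append and leaf in node:
            raise KeyError(f"{key} already exists; drop the '+' to override it")
        if not append and leaf not in node:
            raise KeyError(f"{key} is not in the config; prefix it with '+' to add it")
        node[leaf] = _parse_value(raw)
    return result
```

For example, apply_overrides({}, ["+pl_trainer.gpus=1"]) returns {"pl_trainer": {"gpus": 1}}, whereas "pl_trainer.gpus=1" without the + would raise a KeyError on an empty config.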

Reinforcement Learning

Download the training data by following the instructions here, then unzip it under data/rl.0.2.

# train GATA at difficulty level 5 and training size 100 with one GPU
$ python +pl_trainer.gpus=1

# train GATA at difficulty level 3 and training size 20 with one GPU
$ python +pl_trainer.gpus=1 data.difficulty_level=3 data.train_data_size=20
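To launch several of the configurations shown above, a hypothetical helper can generate one override list per (difficulty level, training size) pair. The training script's name is not given in this README, so it is left out; prepend your python invocation and script to each list. Which difficulty levels and training sizes the data actually provides is also not stated here, so the usage only reuses the values from the examples.

```python
from itertools import product


def gata_overrides(difficulty_level: int, train_data_size: int, gpus: int = 1) -> list:
    """Command-line overrides for a single GATA training run."""
    return [
        f"+pl_trainer.gpus={gpus}",
        f"data.difficulty_level={difficulty_level}",
        f"data.train_data_size={train_data_size}",
    ]


def sweep(levels, sizes, gpus: int = 1) -> list:
    """One override list per (difficulty level, training size) combination."""
    return [gata_overrides(level, size, gpus) for level, size in product(levels, sizes)]
```

For example, sweep([3, 5], [20, 100]) yields four override lists, the first being ["+pl_trainer.gpus=1", "data.difficulty_level=3", "data.train_data_size=20"].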


You can run the following command to have an agent play a game.

$ python /path/to/game.z8 /path/to/gata.ckpt

Pretrained Weights

Pretrained weights are provided under /weights: one for the graph updater, one for GATA trained at difficulty level 5 with 20 training games, and one for GATA trained at difficulty level 5 with 100 training games.


(We no longer use Fairscale, but these notes are kept for posterity.)

If pip install fairscale fails, retry it with the --no-build-isolation flag. If it then fails with "unsupported GNU version! gcc versions later than 7 are not supported!", run the following commands so that nvcc uses gcc 7:

$ sudo ln -s /usr/bin/gcc-7 /usr/local/cuda/bin/gcc
$ sudo ln -s /usr/bin/g++-7 /usr/local/cuda/bin/g++

If it fails with "fatal error: cublas_v2.h: No such file or directory" or "fatal error: cublas_api.h: No such file or directory", check the include directories passed to the failing compiler command and symlink the missing headers into one of them. For example:

$ sudo ln -s /usr/local/cuda-10.2/targets/x86_64-linux/include/cublas_v2.h /usr/local/cuda/include/cublas_v2.h
$ sudo ln -s /usr/local/cuda-10.2/targets/x86_64-linux/include/cublas_api.h /usr/local/cuda/include/cublas_api.h
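Finding the right header to link can be scripted. This sketch assumes the /usr/local/cuda-<version>/targets/<arch>/include layout from the example commands and that /usr/local/cuda/include is the include directory the compiler searches; the function name and its parameters are hypothetical. It only returns candidate ln -s commands and changes nothing on disk, so review them before running anything with sudo.

```python
import glob
import os


def cublas_symlink_commands(cuda_root="/usr/local", cuda_version="10.2",
                            headers=("cublas_v2.h", "cublas_api.h")):
    """Return suggested `ln -s` commands for cuBLAS headers missing from <cuda_root>/cuda/include.

    Assumes the versioned toolkit keeps its headers under
    <cuda_root>/cuda-<version>/targets/<arch>/include. Nothing is created on disk.
    """
    include_dir = os.path.join(cuda_root, "cuda", "include")
    commands = []
    for header in headers:
        if os.path.exists(os.path.join(include_dir, header)):
            continue  # the compiler can already see this header
        pattern = os.path.join(cuda_root, f"cuda-{cuda_version}", "targets", "*", "include", header)
        matches = sorted(glob.glob(pattern))
        if matches:
            # Normally there is a single target directory such as x86_64-linux.
            commands.append(f"sudo ln -s {matches[0]} {os.path.join(include_dir, header)}")
    return commands


if __name__ == "__main__":
    for command in cublas_symlink_commands():
        print(command)
```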


License: MIT License