yukw777 / GATA-public

GATA: Graph Aided Transformer Agent


Code for the NeurIPS 2020 paper "Learning Dynamic Belief Graphs to Generalize on Text-Based Games".

# Dependencies
conda create -p /tmp/gata python=3.6 numpy scipy ipython matplotlib cython nltk pillow
source activate /tmp/gata
pip install --upgrade pip
pip install numpy==1.16.2
pip install gym==0.15.4
pip install textworld
pip install -U "spacy<3"
python -m spacy download en
pip install tqdm pipreqs h5py pyyaml visdom
conda install pytorch torchvision cudatoolkit=9.2 -c pytorch

# Download FastText Word Embeddings
curl -L -o crawl-300d-2M.vec.h5 "https://bit.ly/2U3Mde2"
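
Because the embeddings are fetched through a shortened URL, a failed redirect can leave an HTML error page on disk under the expected file name. The sketch below is one way to sanity-check the download using only the Python standard library. It assumes the file is an HDF5 container, as its .h5 extension suggests, and that the HDF5 signature sits at byte offset 0, which is the common case. The function name looks_like_hdf5 is hypothetical and not part of this repository.

```python
# Assumption: crawl-300d-2M.vec.h5 is an HDF5 file whose signature is at offset 0.
# HDF5 files begin with this fixed 8-byte signature.
HDF5_SIGNATURE = b"\x89HDF\r\n\x1a\n"


def looks_like_hdf5(path):
    """Return True if the file at `path` starts with the HDF5 signature."""
    with open(path, "rb") as fh:
        return fh.read(len(HDF5_SIGNATURE)) == HDF5_SIGNATURE
```

Calling looks_like_hdf5("crawl-300d-2M.vec.h5") after the download gives a quick indication of whether the file is usable. It does not validate the contents beyond the first eight bytes.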

GATA

Pre-training Graph Updater by Observation Generation

# Download data for observation generation / contrastive observation classification
cd obs_gen.0.1 ; wget https://aka.ms/twkg/obs_gen.0.1.zip ; unzip obs_gen.0.1.zip ; cd ..
# Train
python train_obs_generation.py configs/pretrain_observation_generation.yaml

Pre-training Graph Updater by Contrastive Observation Classification

# Download data for observation generation / contrastive observation classification
cd obs_gen.0.1 ; wget https://aka.ms/twkg/obs_gen.0.1.zip ; unzip obs_gen.0.1.zip ; cd ..
# Train
python train_obs_infomax.py configs/pretrain_contrastive_observation_classification.yaml

Train Action Scorer with RL

# Download games
cd rl.0.2 ; wget https://aka.ms/twkg/rl.0.2.zip ; unzip rl.0.2.zip ; cd ..
# Modify configs/train_gata_rl.yaml
#   L30: True to load pre-trained graph encoder, False to randomly initialize.
#     L31:  'gata_pretrain_obs_gen_model' or 'gata_pretrain_obs_infomax_model'. Used when L30 is True.
#   L33:  'gata_pretrain_obs_gen_model' or 'gata_pretrain_obs_infomax_model'
#   L84:  3/7/5/9 correspond to the 1/2/3/4 in paper
#   L85:  1/20/100
#   L125: False/True
# To train
python train_rl_with_continuous_belief.py configs/train_gata_rl.yaml
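
The config comments above note that the values 3/7/5/9 on line 84 correspond to 1/2/3/4 in the paper, and the order is easy to mix up since 7 comes before 5. A small lookup like the sketch below can help when scripting runs. The names used here are hypothetical and not part of this repository. Only the mapping itself comes from the comments above.

```python
# Mapping copied from the README comment:
# config values 3/7/5/9 correspond to 1/2/3/4 in the paper.
CONFIG_VALUE_TO_PAPER_LEVEL = {3: 1, 7: 2, 5: 3, 9: 4}
PAPER_LEVEL_TO_CONFIG_VALUE = {
    paper: config for config, paper in CONFIG_VALUE_TO_PAPER_LEVEL.items()
}


def config_value_for(paper_level):
    """Return the config value to write for a level number used in the paper."""
    if paper_level not in PAPER_LEVEL_TO_CONFIG_VALUE:
        raise ValueError(
            "paper level must be one of %s" % sorted(PAPER_LEVEL_TO_CONFIG_VALUE)
        )
    return PAPER_LEVEL_TO_CONFIG_VALUE[paper_level]
```

For example, config_value_for(2) returns 7, which is the value to put on line 84 of the config for the paper's level 2.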

GATA-GTF

Pre-training Graph Encoder by Action Prediction

# Download data
cd ap.0.2 ; wget https://aka.ms/twkg/ap.0.2.zip ; unzip ap.0.2.zip ; cd ..
# Train
python train_action_prediction.py configs/pretrain_action_prediction_full.yaml

Pre-training Graph Encoder by State Prediction

# Download data
cd sp.0.2 ; wget https://aka.ms/twkg/sp.0.2.zip ; unzip sp.0.2.zip ; cd ..
# Train
python train_state_prediction.py configs/pretrain_state_prediction_full.yaml

Pre-training Graph Encoder by Deep Graph Infomax

# Download data
cd dgi.0.2 ; wget https://aka.ms/twkg/dgi.0.2.zip ; unzip dgi.0.2.zip ; cd ..
# Train
python train_deep_graph_infomax.py configs/pretrain_deep_graph_infomax_full.yaml

Train Action Scorer with RL

# Download games
cd rl.0.2 ; wget https://aka.ms/twkg/rl.0.2.zip ; unzip rl.0.2.zip ; cd ..
# Modify configs/train_gata_gtf_rl.yaml
#   L30: True to load pre-trained graph encoder, False to randomly initialize.
#     L31:  'gata_gtf_pretrain_ap_full_model', 'gata_gtf_pretrain_sp_full_model', or 'gata_gtf_pretrain_dgi_full_model'. When L30 is True.
#   L84:  3/7/5/9 correspond to the 1/2/3/4 in paper
#   L85:  1/20/100
#   L125: False/True
# To train
python train_rl_with_ground_truth_discrete_belief.py configs/train_gata_gtf_rl.yaml

GATA-GTP

Pre-training Graph Encoder by Action Prediction

# Download data
cd ap.0.2 ; wget https://aka.ms/twkg/ap.0.2.zip ; unzip ap.0.2.zip ; cd ..
# Train
python train_action_prediction.py configs/pretrain_action_prediction_seen.yaml

Pre-training Graph Encoder by State Prediction

# Download data
cd sp.0.2 ; wget https://aka.ms/twkg/sp.0.2.zip ; unzip sp.0.2.zip ; cd ..
# Train
python train_state_prediction.py configs/pretrain_state_prediction_seen.yaml

Pre-training Graph Encoder by Deep Graph Infomax

# Download data
cd dgi.0.2 ; wget https://aka.ms/twkg/dgi.0.2.zip ; unzip dgi.0.2.zip ; cd ..
# Train
python train_deep_graph_infomax.py configs/pretrain_deep_graph_infomax_seen.yaml

Pre-training Graph Updater by Command Generation

# Download data for command generation
cd cmd_gen.0.2 ; wget https://aka.ms/twkg/cmd_gen.0.2.zip ; unzip cmd_gen.0.2.zip ; cd ..
# Train
python train_command_generation.py configs/pretrain_command_generation.yaml

Train Action Scorer with RL

# Download games
cd rl.0.2 ; wget https://aka.ms/twkg/rl.0.2.zip ; unzip rl.0.2.zip ; cd ..
# Modify configs/train_gata_gtp_rl.yaml
#   L30: True to load pre-trained graph encoder, False to randomly initialize.
#     L31:  'gata_gtp_pretrain_ap_seen_model', 'gata_gtp_pretrain_sp_seen_model', or 'gata_gtp_pretrain_dgi_seen_model'. When L30 is True.
#   L84:  3/7/5/9 correspond to the 1/2/3/4 in paper
#   L85:  1/20/100
#   L125: False/True
# To train
python train_rl_with_discrete_belief.py configs/train_gata_gtp_rl.yaml

Monitoring training progress

To monitor training progress, set "visdom: True" under the general section of the config_***.yaml file you are training with, then start a Visdom server in another terminal by running the visdom command. Open the link that Visdom prints in your browser.
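
If training starts before the Visdom server is up, no plots will appear. The sketch below is one way to check, with the standard library only, that something is listening where Visdom is expected. It assumes the server runs on localhost at port 8097, which is Visdom's default; adjust the arguments if you started it differently. The function name visdom_is_reachable is hypothetical and not part of this repository.

```python
import socket


def visdom_is_reachable(host="localhost", port=8097, timeout=1.0):
    """Return True if a TCP connection to host:port succeeds.

    Assumption: 8097 is the port the Visdom server was started on (its default).
    This only checks that the port accepts connections, not that the
    listener is actually Visdom.
    """
    try:
        with socket.create_connection((host, port), timeout=timeout):
            return True
    except OSError:
        return False
```

Running visdom_is_reachable() before launching a training script gives an early warning if the server has not been started.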

Citation

Please cite this work using the following BibTeX entry:

@article{adhikari2020gata,
  title={Learning Dynamic Belief Graphs to Generalize on Text-Based Games},
  author={Adhikari, Ashutosh and Yuan, Xingdi and C\^ot\'{e}, Marc-Alexandre and Zelinka, Mikul\'{a}\v{s} and Rondeau, Marc-Antoine and Laroche, Romain and Poupart, Pascal and Tang, Jian and Trischler, Adam and Hamilton, William L.},
  journal={CoRR},
  volume={abs/2002.09127},
  year={2020},
  archivePrefix={arXiv},
  eprint={2002.09127}
}

License

MIT
