MostAwesomeDude / retro

On the Generalization Ability of Retrieval-Enhanced Transformers


This is the official repository for the paper On the Generalization Ability of Retrieval-Enhanced Transformers. We release our RETRO implementation along with our trained model. Because of their size, we unfortunately cannot host the data and retrieval index, but we provide code for reproducing them from the raw Pile and RealNews datasets.

Environment

All code and commands in this repo should be executed within the provided Docker environment. To build and start the container in a terminal, run:

$ ./start.sh [--gpu]

...

docker-user@91711f9b80b8:/workspace$ 

It might take several minutes to build the Docker image the first time.

VS Code integration

Alternatively, you can use the Dev Containers extension for VS Code. Open this folder in VS Code and click "Reopen in Container"; VS Code will then build and start the container for you.

Model download

Download retro.zip and extract it into the data/model folder.
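If you prefer to script this step, the following is a minimal sketch using Python's standard zipfile module. It assumes retro.zip has already been downloaded to the repository root and that its files sit at the top level of the archive, so that data/model/retro.json and data/model/model.ckpt (the paths used in the Usage section) exist after extraction. The function name extract_model is hypothetical and not part of this repository.

```python
import zipfile
from pathlib import Path

# Files the generation command expects to find in data/model (see Usage).
EXPECTED_FILES = ("retro.json", "model.ckpt")


def extract_model(archive: str = "retro.zip", target: str = "data/model") -> Path:
    """Hypothetical helper: unpack the model archive and check the expected files."""
    target_dir = Path(target)
    target_dir.mkdir(parents=True, exist_ok=True)
    with zipfile.ZipFile(archive) as zf:
        zf.extractall(target_dir)
    # Assumption: the archive has no enclosing folder; if it does, files land in a
    # subdirectory and the Usage paths will not resolve, so fail loudly here.
    missing = [name for name in EXPECTED_FILES if not (target_dir / name).is_file()]
    if missing:
        raise FileNotFoundError(f"not found in {target_dir}: {missing}; check the archive layout")
    return target_dir
```

Raising an error instead of silently succeeding makes a mismatched archive layout visible before the longer generation step fails.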

Usage

To generate from RETRO, run:

$ cd src/
$ python generate_retro.py \
    --retro-config /workspace/data/model/retro.json \
    --checkpoint /workspace/data/model/model.ckpt \
    --prompt "A retrieval-enhanced language model is" \
    --num-neighbours 1 \
    --num-continuation-chunks 1

You will be prompted to enter the neighbour chunks as generation proceeds.
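RETRO retrieves neighbours per fixed-size chunk of the input rather than per token, which is why generation asks for neighbour chunks. The sketch below illustrates how a token sequence splits into chunks and how many tokens each neighbour would span for given --num-neighbours and --num-continuation-chunks values. Several assumptions apply: a chunk size of 64 tokens (the value used in the original RETRO paper; this repository reads its own setting from retro.json), whitespace splitting as a stand-in for the real tokenizer, and each neighbour consisting of one neighbour chunk followed by its continuation chunks. The helper names are hypothetical.

```python
from dataclasses import dataclass


@dataclass
class RetrievalShape:
    num_chunks: int            # chunks the input sequence is split into
    neighbours_per_chunk: int  # corresponds to --num-neighbours
    tokens_per_neighbour: int  # neighbour chunk plus its continuation chunks


def split_into_chunks(tokens: list, chunk_size: int) -> list:
    """Split a token sequence into consecutive chunks; the last one may be shorter."""
    return [tokens[i:i + chunk_size] for i in range(0, len(tokens), chunk_size)]


def retrieval_shape(tokens: list, chunk_size: int = 64, num_neighbours: int = 1,
                    num_continuation_chunks: int = 1) -> RetrievalShape:
    """Hypothetical helper describing how much retrieved text each chunk carries."""
    chunks = split_into_chunks(tokens, chunk_size)
    return RetrievalShape(
        num_chunks=len(chunks),
        neighbours_per_chunk=num_neighbours,
        tokens_per_neighbour=chunk_size * (1 + num_continuation_chunks),
    )


# Whitespace tokens stand in for real tokenizer output.
shape = retrieval_shape("A retrieval-enhanced language model is".split(), chunk_size=4)
print(shape)
```

With the flags from the command above (one neighbour, one continuation chunk) and 64-token chunks, each neighbour you supply would cover 128 tokens under these assumptions.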

Retrieval data

Instructions for creating a custom retrieval dataset or rebuilding MassiveOpenText are provided in data/datasets/README.md.

Training RETRO

Our RETRO model was trained with the following command on four nodes, each with four A100 40GB GPUs. You may need to modify the flags depending on the resources available to you.

$ cd src/
$ python train_retro.py \
    --training-dataset-spec ../data/datasets/MassiveOpenText/train_sentence_transformer_neighbours.spec.json \
    --validation-dataset-spec ../data/datasets/MassiveOpenText/val_sentence_transformer_neighbours.spec.json \
    --experiment-dir ../data/model/ \
    --num-neighbours 2 \
    --num-continuation-chunks 1 \
    --max-len 1024 \
    --retro-config ../data/model/retro.json \
    --batch-size 2 \
    --accumulate-grad-batches 4 \
    --gpus-per-node 4 \
    --num-nodes 4
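When adapting these flags to different hardware, it can help to keep the effective batch size constant. Assuming --batch-size is the per-GPU batch size (the usual convention in data-parallel training with gradient accumulation; not verified against train_retro.py), the command above processes 2 × 4 × 4 × 4 = 128 sequences per optimizer step, or at most 128 × 1024 = 131,072 tokens given --max-len 1024. The sketch below computes this and the --accumulate-grad-batches value needed on a smaller setup. Both function names are hypothetical.

```python
import math


def effective_batch_size(batch_size: int, accumulate: int,
                         gpus_per_node: int, num_nodes: int) -> int:
    """Sequences per optimizer step, assuming --batch-size is per GPU."""
    return batch_size * accumulate * gpus_per_node * num_nodes


def accumulation_for(target: int, batch_size: int,
                     gpus_per_node: int, num_nodes: int) -> int:
    """Smallest --accumulate-grad-batches that reaches at least `target` sequences per step."""
    per_step = batch_size * gpus_per_node * num_nodes
    return math.ceil(target / per_step)


# Flag values from the training command above.
reference = effective_batch_size(batch_size=2, accumulate=4, gpus_per_node=4, num_nodes=4)
print(reference, reference * 1024)  # 128 131072 (sequences, max tokens per step)

# Example: a single node with four GPUs would need 16 accumulation steps.
print(accumulation_for(reference, batch_size=2, gpus_per_node=4, num_nodes=1))  # 16
```

Rounding up means the effective batch can slightly exceed the target when it does not divide evenly; you may prefer to adjust --batch-size instead in that case.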

Running tests

To run the tests that validate our RETRO implementation, run:

$ cd src/
$ pytest

Citation

@misc{https://doi.org/10.48550/arxiv.2302.12128,
  doi = {10.48550/ARXIV.2302.12128},
  url = {https://arxiv.org/abs/2302.12128},
  author = {Norlund, Tobias and Doostmohammadi, Ehsan and Johansson, Richard and Kuhlmann, Marco},
  keywords = {Computation and Language (cs.CL), FOS: Computer and information sciences, FOS: Computer and information sciences},
  title = {On the Generalization Ability of Retrieval-Enhanced Transformers},
  publisher = {arXiv},
  year = {2023},
  copyright = {Creative Commons Attribution 4.0 International}
}


License

Apache License 2.0

