erdalsahan / tabi

Code release for Type-Aware Bi-Encoders for Open-Domain Entity Retrieval

TABi: Type-Aware Bi-Encoders for Open-Domain Entity Retrieval

This repo contains an implementation of TABi, a bi-encoder for entity retrieval that trains over knowledge graph types and unstructured text. TABi introduces a type-enforced contrastive loss to encourage query and entity embeddings to cluster by type in the embedding space. You can find more details in our paper.
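To make the idea of type-based clustering concrete, here is a minimal NumPy sketch of one way a type-aware contrastive loss can be set up. It is an assumption modeled on supervised contrastive learning, not the exact loss from the paper or this repo. Each query's positives are its gold entity plus any other in-batch query that shares a type. All remaining in-batch entities and queries act as negatives. The function name, the temperature value, and the set-of-type-labels input are hypothetical.

```python
import numpy as np

def l2_normalize(x):
    # Scale each row to unit length so dot products become cosine similarities.
    return x / np.linalg.norm(x, axis=1, keepdims=True)

def type_aware_contrastive_loss(q, e, query_types, temperature=0.05):
    """Hypothetical sketch, not the TABi implementation.

    q: (B, d) query embeddings; e: (B, d) gold entity embeddings, where row i
    is the gold entity of query i; query_types: list of B sets of type labels.
    """
    q, e = l2_normalize(q), l2_normalize(e)
    B = q.shape[0]
    # Candidates for each query: the B in-batch entities, then the B queries.
    logits = q @ np.concatenate([e, q]).T / temperature   # (B, 2B)
    # Never compare a query with itself.
    logits[np.arange(B), B + np.arange(B)] = -np.inf
    # Numerically stable log-softmax over each row.
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))

    # Positives: the gold entity plus other queries sharing at least one type.
    pos = np.zeros((B, 2 * B), dtype=bool)
    pos[np.arange(B), np.arange(B)] = True
    for i in range(B):
        for j in range(B):
            if i != j and query_types[i] & query_types[j]:
                pos[i, B + j] = True

    # Average negative log-likelihood over each query's positives, then the batch.
    per_query = -np.where(pos, log_probs, 0.0).sum(axis=1) / pos.sum(axis=1)
    return float(per_query.mean())

rng = np.random.default_rng(0)
q, e = rng.normal(size=(4, 8)), rng.normal(size=(4, 8))
types = [{"person"}, {"person"}, {"location"}, set()]
print(type_aware_contrastive_loss(q, e, types))
```

Pulling same-type queries toward each other is what encourages embeddings to group by type. The gold entity term keeps each query anchored to its own entity.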

This repo also includes pre-trained TABi models to retrieve Wikipedia pages from queries and training scripts to train TABi models on new datasets.

Setup

1. Installation

Our code is tested on Python 3.7. We recommend installing with a virtualenv.

pip install -r requirements.txt
pip install -e . 

If you are using NVIDIA A100 GPUs, you will need to install a version of PyTorch that supports the sm_80 CUDA architecture:

pip install torch==1.10.1+cu111 torchvision==0.11.2+cu111 torchaudio==0.10.1 -f https://download.pytorch.org/whl/torch_stable.html

2. Download TABi models

We provide the following pre-trained TABi models. We also provide the pre-computed entity embeddings over the KILT-E knowledge base. The pre-computed entity embeddings require 16GB of disk space to download.
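As a rough check of the 16GB figure, assume 768-dimensional float32 embeddings. That is the BERT-base hidden size, and both the dimension and the dtype are assumptions here. Under that assumption, 5.45 million entities take 5,450,000 × 768 × 4 bytes ≈ 16.7 GB (15.6 GiB). The sketch below computes that estimate. It also shows one way to open a large embedding file without reading it all into RAM, assuming embs.npy is a standard NumPy file written by np.save. The helper names are hypothetical.

```python
import os
import tempfile

import numpy as np

def embedding_bytes(num_entities, dim, bytes_per_value=4):
    # A dense matrix stores one value per entity per dimension.
    return num_entities * dim * bytes_per_value

def open_embeddings(path):
    # mmap_mode="r" maps the file read-only; pages are read from disk on access.
    return np.load(path, mmap_mode="r")

# Assumed sizes: 5.45M KILT-E entities, 768 dimensions, float32 (4 bytes).
print(embedding_bytes(5_450_000, 768) / 1e9)  # prints 16.7424

# Tiny stand-in file to demonstrate memory-mapped loading.
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "embs.npy")
    np.save(path, np.ones((3, 4), dtype=np.float32))
    embs = open_embeddings(path)
    shape, dtype = embs.shape, embs.dtype
    del embs  # release the mapping before the directory is removed
print(shape, dtype)  # prints (3, 4) float32
```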

We provide models that are trained with types using the type-enforced contrastive loss and models that are trained without types. Note that the models trained on the BLINK training data require mention boundaries (or mention detection, e.g. via flair) at test time. This is because all examples in the BLINK training data have mention boundaries.

Training Data | Trained with Types | Weights | Entity Embs
KILT          | Yes                | url     | url
KILT          | No                 | url     | url
BLINK         | Yes                | url     | url
BLINK         | No                 | url     | url

See our paper for hyperparameter settings for the pre-trained models.

3. Download the knowledge base (KILT-E)

We use a filtered version of the KILT knowledge base. We remove Wikimedia internal items (e.g. disambiguation pages, list articles) and add FIGER types to entities where available. The final knowledge base, KILT-Entity (KILT-E), has 5.45 million entities corresponding to English Wikipedia pages.

Download KILT-E:

Both formats can be used for entity_file in the following commands, but the pickle will load a bit faster.

Use TABi interactively

We support two modes to use TABi interactively. We recommend using the models trained on KILT for the interactive mode. The interactive mode does not currently support mention detection or providing mention boundaries.

Standard retrieval mode

To retrieve entities from a pre-computed entity index, run:

python scripts/demo.py --model_checkpoint best_model.pth --entity_emb_path embs.npy --entity_file entity.pkl

To control the number of retrieved entities, use the flag --top_k. By default, the top 10 entities will be returned.
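Standard retrieval mode amounts to scoring the query embedding against every pre-computed entity embedding and keeping the top_k best. Below is a brute-force NumPy sketch of that step. It assumes both sides are L2-normalized, so the dot product equals cosine similarity. The function name is hypothetical, and the demo's actual search code may differ.

```python
import numpy as np

def top_k_entities(query_emb, entity_embs, k=10):
    """Return (indices, scores) of the k highest-scoring entities, best first.

    Assumes query_emb (d,) and entity_embs (N, d) are already L2-normalized.
    """
    scores = entity_embs @ query_emb          # (N,) cosine similarities
    k = min(k, scores.shape[0])
    # argpartition finds the top k in O(N) without sorting all N scores...
    top = np.argpartition(-scores, k - 1)[:k]
    # ...then only those k are sorted by descending score.
    top = top[np.argsort(-scores[top])]
    return top, scores[top]

entity_embs = np.eye(4)                       # 4 toy unit-length entities
query = np.array([0.0, 0.6, 0.8, 0.0])        # unit length: 0.36 + 0.64 = 1
idx, sc = top_k_entities(query, entity_embs, k=2)
print(idx, sc)                                # prints [2 1] [0.8 0.6]
```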

Entity-input mode

To input your own entities (title and description) and get a score between the query and your entity, simply provide the model checkpoint:

python scripts/demo.py --model_checkpoint best_model.pth

The scores provided are cosine similarities and will be between -1 and 1 (1 is most similar). The demo will continue to prompt you for entities. To enter a new query and entities, type exit.
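The score is bounded by -1 and 1 because cosine similarity is the dot product of the two vectors after each is scaled to unit length. A small sketch with toy vectors standing in for encoder outputs (these are not real TABi embeddings):

```python
import numpy as np

def cosine_similarity(a, b):
    # Dot product divided by both norms; the result always lies in [-1, 1].
    return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))

query = np.array([3.0, 4.0, 0.0])
print(cosine_similarity(query, np.array([6.0, 8.0, 0.0])))    # prints 1.0 (same direction)
print(cosine_similarity(query, np.array([-3.0, -4.0, 0.0])))  # prints -1.0 (opposite)
print(cosine_similarity(query, np.array([0.0, 0.0, 2.0])))    # prints 0.0 (orthogonal)
```

Because only direction matters, scaling an entity embedding up or down does not change its score.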

Prepare data for TABi

The Datasets section below provides AmbER and KILT datasets for evaluation and BLINK and KILT datasets for training, already in the TABi data format. If you plan to use these provided datasets, you can skip ahead to Evaluate TABi and Train your own TABi model.

TABi data format

We require that the input to TABi be in the following format:

{
    "id": # unique example id  
    "text": # question or sentence 
    "label_id": # list of gold knowledge base ids if available, otherwise use [-1]
    "alt_label_id": # list of lists of alternate gold knowledge base ids, if none use [[]] 
    "mentions": # list of character spans of mention boundaries if available, otherwise [] 
}

Example (from Natural Questions):

{
    "id": "-143054837169120955",
    "text": "where are the giant redwoods located in california", 
    "label_id": [903760], 
    "alt_label_id": [[4683290, 2526048, 242069]], 
    "mentions": []
}

Note that if providing mention spans, TABi currently only supports disambiguating one mention at a time and will run separate evaluation queries on the model for each mention span in the list.
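The sketch below checks a record against this format before it is written to a jsonlines file. It also splits a multi-mention record into one query per mention, as the note above describes. The helper names are hypothetical. Two details are assumptions: mentions are [start, end) character offsets, and label_id[i] and alt_label_id[i] belong to mentions[i].

```python
import json

REQUIRED_FIELDS = ("id", "text", "label_id", "alt_label_id", "mentions")

def validate_tabi_example(ex):
    """Raise ValueError if ex does not follow the TABi data format."""
    missing = [f for f in REQUIRED_FIELDS if f not in ex]
    if missing:
        raise ValueError(f"missing fields: {missing}")
    if not ex["label_id"]:
        raise ValueError("label_id must be non-empty; use [-1] if unknown")
    for start, end in ex["mentions"]:
        # Assumed [start, end) character offsets into the text.
        if not 0 <= start < end <= len(ex["text"]):
            raise ValueError(f"mention span {(start, end)} is outside the text")
    return ex

def per_mention_queries(ex):
    """Hypothetical helper: one query per mention span (assumes aligned labels)."""
    if not ex["mentions"]:
        return [ex]
    return [
        {**ex, "mentions": [span], "label_id": [lab], "alt_label_id": [alts]}
        for span, lab, alts in zip(ex["mentions"], ex["label_id"], ex["alt_label_id"])
    ]

nq = {
    "id": "-143054837169120955",
    "text": "where are the giant redwoods located in california",
    "label_id": [903760],
    "alt_label_id": [[4683290, 2526048, 242069]],
    "mentions": [],
}
line = json.dumps(validate_tabi_example(nq))  # one line of the jsonlines file
print(len(per_mention_queries(nq)))           # prints 1
```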

Preprocessing

To convert a jsonlines file in the KILT data format to the TABi data format, run:

python scripts/preprocess_kilt.py --entity_file entity.pkl --input_file nq-dev-kilt.jsonl --output_file nq-dev-tabi.jsonl

To convert a directory of KILT-formatted files to the TABi format, run:

python scripts/preprocess_kilt.py --entity_file entity.pkl --input_dir kilt_dev --output_dir kilt_dev_tabi

Evaluate TABi

The evaluation script runs the model eval, reports accuracy@1 and accuracy@10, and saves the predictions in KILT-formatted files.

To evaluate a TABi model, run:

python tabi/eval.py --test_data_file nq-dev-kilt.jsonl --entity_file entity.pkl --model_checkpoint best_model.pth --entity_emb_path embs.npy --mode eval --log_dir logs
  • log_dir specifies where the log file and prediction file are written.
  • You can also specify the name for the prediction file with --pred_file. For instance:
python tabi/eval.py --test_data_file nq-dev-kilt.jsonl --entity_file entity.pkl --model_checkpoint best_model.pth --entity_emb_path embs.npy --mode eval --log_dir logs --pred_file nq-dev-preds.jsonl

For benchmarks, we use the evaluation scripts provided by AmbER and KILT to report final numbers.
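For intuition, accuracy@k is the fraction of queries whose gold entity appears among the top k predictions. The sketch below assumes a prediction counts as correct if any id in label_id or in its alt_label_id list is retrieved. The function name is hypothetical, and how the repo's script and the official AmbER and KILT scripts treat alternates may differ.

```python
def accuracy_at_k(predictions, examples, k):
    """Fraction of examples with a gold or alternate id in the top-k predictions.

    predictions[i] is a ranked list of entity ids for examples[i].
    """
    hits = 0
    for preds, ex in zip(predictions, examples):
        gold = set(ex["label_id"])
        for alts in ex["alt_label_id"]:
            gold.update(alts)  # assumption: alternates count as correct
        if gold & set(preds[:k]):
            hits += 1
    return hits / len(examples)

examples = [
    {"label_id": [5], "alt_label_id": [[]]},
    {"label_id": [7], "alt_label_id": [[8]]},
    {"label_id": [9], "alt_label_id": [[]]},
]
predictions = [[5, 1, 2], [1, 8, 3], [1, 2, 3]]
print(accuracy_at_k(predictions, examples, 1))   # 1 of 3 correct at k=1
print(accuracy_at_k(predictions, examples, 10))  # 2 of 3 correct at k=10
```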

Train your own TABi model

Training is a four-step procedure:

  1. Train with local in-batch negatives.
  2. Extract entity embeddings.
  3. Extract hard negatives using nearest neighbor search with optional hard negative filtering.
  4. Train with in-batch negatives and hard negatives.
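Step 3 can be pictured with a brute-force NumPy sketch. Score every training query against the entity embeddings from step 2, then keep the highest-scoring entities that are not gold. The function name, num_negatives, and the use of row indices as entity ids are assumptions. At KILT-E scale, a real implementation would use a dedicated nearest neighbor search library rather than a dense score matrix.

```python
import numpy as np

def mine_hard_negatives(query_embs, entity_embs, gold_ids, num_negatives=2):
    """Return, per query, the highest-scoring entity ids that are not gold.

    query_embs: (Q, d); entity_embs: (N, d); gold_ids: list of Q sets of ids,
    where an entity's id is assumed to be its row index in entity_embs.
    """
    scores = query_embs @ entity_embs.T            # (Q, N)
    ranked = np.argsort(-scores, axis=1)           # best-scoring entities first
    negatives = []
    for row, gold in zip(ranked, gold_ids):
        # Skip gold entities; the rest are "hard" because they score highly.
        negs = [int(i) for i in row if int(i) not in gold][:num_negatives]
        negatives.append(negs)
    return negatives

entity_embs = np.eye(4)
query_embs = np.array([[0.9, 0.5, 0.3, 0.0], [0.0, 0.1, 0.2, 1.0]])
print(mine_hard_negatives(query_embs, entity_embs, [{0}, {3}]))  # prints [[1, 2], [2, 1]]
```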

An example script is in scripts/run_sample.py. To run with the small sample data in the repo on a GPU:

python scripts/run_sample.py

To run with the small sample data in the repo on a CPU:

python scripts/run_sample_cpu.py

Train on a new dataset

To train a new TABi model on your own dataset, make sure to format your training, eval, and test datasets in the TABi data format and modify data_dir, train_file, dev_file, and test_file in the example script.

To use a new entity knowledge base, each entity in the knowledge base (jsonlines file) should have the following format:

{
    "label_id": # unique id of the entity (optional, if not provided, row in knowledge base is assigned as the id)
    "title": # title of the entity 
    "text": # description of the entity 
    "types": # list of types ([] if none) 
    "wikipedia_page_id": # wikipedia page id (can exclude if not linking to Wikipedia) 
}

See KILT-E knowledge base for an example of the expected format. The type-enforced contrastive loss uses query types, which are assigned as the types associated with the gold entity for the query. It is important that the "types" are not all empty in the knowledge base in order to see benefits from the type-enforced contrastive loss. Make sure to update entity_file in the example script to use your new knowledge base.
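The sketch below parses a knowledge base in this format and applies the row-number fallback for label_id. It then reports what fraction of entities carry types and derives a query's types from its gold entity. The function names are hypothetical and the entities are toy examples; the repo's own loader may differ.

```python
import json

def load_kb(lines):
    """Parse jsonlines KB entries; use the row number when label_id is absent."""
    kb = {}
    for row, line in enumerate(lines):
        entity = json.loads(line)
        entity.setdefault("types", [])
        kb[entity.get("label_id", row)] = entity
    return kb

def typed_fraction(kb):
    # Share of entities with at least one type; a value near 0 means the
    # type-enforced contrastive loss has few types to learn from.
    return sum(bool(e["types"]) for e in kb.values()) / len(kb)

def query_types(example, kb):
    # A query's types are the types of its gold entity (or entities).
    types = set()
    for label in example["label_id"]:
        if label in kb:  # skips -1 and ids missing from the KB
            types.update(kb[label]["types"])
    return types

kb = load_kb([
    '{"title": "Paris", "text": "Capital of France.", "types": ["location", "city"]}',
    '{"title": "Dune", "text": "A 1965 science fiction novel.", "types": []}',
])
print(typed_fraction(kb))                          # prints 0.5
print(sorted(query_types({"label_id": [0]}, kb)))  # prints ['city', 'location']
```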

Train on KILT or BLINK datasets

We provide example scripts to train a new TABi model on the BLINK and KILT datasets. The datasets for training can be downloaded below. The provided pre-trained models were trained on 16 A100 GPUs for four epochs, which took approximately 9 and 11 hours total for the BLINK and KILT datasets, respectively.

Distributed training

We support DistributedDataParallel training on a single node with multiple GPUs. See the example scripts above for training on BLINK and KILT data using distributed training. You may need to increase the ulimit (number of open files) on your machine for large datasets using ulimit -n 100000.
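If changing the shell's limit is inconvenient, a process can raise its own soft limit up to the hard limit with Python's standard resource module (Unix only). This is an alternative to running ulimit. The training scripts are not known to do this themselves, and the helper name is hypothetical.

```python
import resource

def raise_open_file_limit(target=100_000):
    """Raise the soft RLIMIT_NOFILE toward target without exceeding the hard limit."""
    soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
    # An unprivileged process may not raise its soft limit above the hard limit.
    new_soft = target if hard == resource.RLIM_INFINITY else min(target, hard)
    if soft != resource.RLIM_INFINITY and new_soft > soft:
        try:
            resource.setrlimit(resource.RLIMIT_NOFILE, (new_soft, hard))
        except (ValueError, OSError):
            pass  # e.g. some systems cap open files per process below the hard limit
    return resource.getrlimit(resource.RLIMIT_NOFILE)

print(raise_open_file_limit())  # (soft, hard) limits after the attempt
```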

Filtering hard negatives by frequency

We support filtering hard negatives, following the procedure described in Botha et al. The goal is to balance how often an entity occurs as a hard negative against how often it occurs as a gold entity in the training dataset. Filtering can help reduce the proportion of hard negatives that are rare entities. To use filtering, pass the --filter_negatives flag. We only recommend this frequency-based filtering procedure for large training datasets (e.g. BLINK or KILT). On small training datasets, most entities may have very low or zero counts, which leads to aggressive filtering.
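A toy sketch of the frequency-balancing idea follows. It is not the rule from Botha et al. or the repo's --filter_negatives implementation. Each entity is kept as a hard negative at most as many times as it appears as a gold entity in the training data. The cap rule, the min_count floor, and the function name are hypothetical. The sketch also shows why small datasets are filtered aggressively: an entity with zero gold count is always dropped unless min_count is above 0.

```python
from collections import Counter

def filter_hard_negatives(hard_negatives, gold_ids, min_count=0):
    """Hypothetical rule: keep entity e as a hard negative at most
    max(gold_count[e], min_count) times across the whole dataset.

    hard_negatives: per-query lists of entity ids; gold_ids: all gold ids
    that appear in the training data.
    """
    gold_count = Counter(gold_ids)
    used = Counter()
    filtered = []
    for negs in hard_negatives:
        kept = []
        for ent in negs:
            if used[ent] < max(gold_count[ent], min_count):
                used[ent] += 1
                kept.append(ent)
        filtered.append(kept)
    return filtered

gold_ids = [1, 1, 2]
hard = [[1, 3], [1, 2], [1, 2]]
print(filter_hard_negatives(hard, gold_ids))  # prints [[1], [1, 2], []]
```

Because the cap is applied greedily in data order, earlier queries get the first claim on each entity's budget. A real implementation might subsample randomly instead.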

Datasets

Evaluation

We provide evaluation files in the TABi data format for the AmbER and KILT benchmarks. For KILT, we include the 8 open-domain tasks:

The ids for the dev/test splits we use for AmbER in our paper can be found here.

Training

We provide training and validation files in the TABi data format for:

Citation

If you find this code useful, please cite the following paper:

@inproceedings{leszczynski-etal-2022-tabi,
    title={{TAB}i: {T}ype-Aware Bi-Encoders for Open-Domain Entity Retrieval}, 
    author={Megan Leszczynski and Daniel Y. Fu and Mayee F. Chen and Christopher R\'e}, 
    booktitle={Findings of the Association for Computational Linguistics: ACL 2022}, 
    year={2022}
}

Acknowledgments

Our work was inspired by the following repos:


License: Apache License 2.0

