chaoweihuang / knn-models

A retrieval augmented sequence modeling toolkit implemented based on Fairseq

kNN-models

Implemented Papers | Requirements and Installation | Getting Started | Benchmarks | Acknowledgements

What's New

  • 2022/10/04 kNN-models is publicly available

Overview

kNN-models is a k-nearest neighbor augmented sequence modeling toolkit built on Fairseq. It enhances pre-trained neural sequence-to-sequence models by retrieving from an external memory, without expensive retraining.

Main features:

  • Fast and memory-efficient (see Benchmarks for details)
  • Provides reference implementations of various kNN-augmented sequence modeling papers (see Implemented Papers for details)
  • Compatible with most pre-trained models in Fairseq (so far only the Transformer model has been thoroughly tested; we plan to experiment with other models in the future)
  • Supports similarity search with Faiss and Elasticsearch (retrieval with Elasticsearch is an upcoming feature: it is still under development on the es branch and will be merged into the main branch in the foreseeable future)
  • The Faiss index can be placed on a GPU other than the one occupied by the model, and can be sharded across multiple GPUs to avoid running out of memory
  • The module whose intermediate hidden states serve as datastore keys can be configured through command-line arguments (by default it is the last layer of the decoder; see BaseKnnConfig for details)
  • Flexible configuration based on Hydra
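Sharding works because nearest-neighbor search decomposes over disjoint parts of the datastore: each shard returns its local top-k, and the global top-k is recovered by merging those lists. The following is a toy, CPU-only sketch of that idea in plain Python. It is not how kNN-models or Faiss implement it (the toolkit relies on Faiss for this), and the function names are hypothetical.

```python
import heapq
import math
from typing import List, Sequence, Tuple

Vector = Sequence[float]


def squared_l2(a: Vector, b: Vector) -> float:
    """Squared Euclidean distance between two vectors."""
    return sum((x - y) ** 2 for x, y in zip(a, b))


def search_shard(shard: Sequence[Vector], offset: int, query: Vector, k: int) -> List[Tuple[float, int]]:
    """Brute-force top-k inside one shard; ids are shifted by `offset` so they stay global."""
    scored = ((squared_l2(key, query), offset + i) for i, key in enumerate(shard))
    return heapq.nsmallest(k, scored)


def sharded_search(keys: Sequence[Vector], num_shards: int, query: Vector, k: int) -> List[Tuple[float, int]]:
    """Search contiguous shards independently, then merge their local top-k lists.

    The global top-k is always contained in the union of the per-shard top-k
    lists, so merging them gives the same answer as searching everything at once.
    """
    shard_size = math.ceil(len(keys) / num_shards)
    candidates: List[Tuple[float, int]] = []
    for start in range(0, len(keys), shard_size):
        shard = keys[start:start + shard_size]
        candidates.extend(search_shard(shard, start, query, k))
    return heapq.nsmallest(k, candidates)


# Toy datastore keys (2-dimensional for readability).
keys = [[0.0, 0.0], [1.0, 0.0], [0.0, 2.0], [3.0, 3.0], [0.5, 0.5]]
query = [0.4, 0.4]
sharded = sharded_search(keys, num_shards=2, query=query, k=2)
single = sharded_search(keys, num_shards=1, query=query, k=2)
print([idx for _, idx in sharded])  # [4, 0]
```

Searching with one shard or two returns the same neighbors; only the memory footprint per device changes, which is what makes splitting a large index across GPUs viable.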

Implemented Papers

The repository contains reference implementations of the following papers (sorted by publication date):

Detailed READMEs on how to reproduce these papers with kNN-models can be found in the examples folder.

Requirements and Installation

The repository was developed and tested with Python 3.10, PyTorch 1.10.0, Fairseq 0.12.1, and Faiss-gpu 1.7.2. We recommend using the same versions of these packages to avoid compatibility issues, although other versions may also work.

To install kNN-models and develop locally:

git clone https://github.com/cordercorder/knn-models
cd knn-models
pip install -e ./

Note that pip install -e ./ checks the packages in the Python environment to resolve the dependencies specified in requirements.txt. However, pip cannot detect Faiss installed through conda, which results in a redundant installation of Faiss from PyPI. If you are sure that all the packages required by this repository are already installed, you can run python setup.py develop to install kNN-models instead.
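To see which case applies before choosing between pip install -e ./ and python setup.py develop, a small check can compare whether a package is importable with whether pip-style distribution metadata is visible for it. This is a minimal sketch; the distribution names passed in (for example faiss-gpu) are assumptions and should be matched to the names used in requirements.txt.

```python
import importlib.metadata
import importlib.util
from typing import Dict, Optional, Union


def describe_package(import_name: str, dist_name: str) -> Dict[str, Union[bool, Optional[str]]]:
    """Compare "can Python import it?" with "does pip-style metadata exist for it?".

    A package can be importable while no distribution metadata is registered
    under the name pip checks, which is how a conda-installed Faiss can end up
    being installed a second time from PyPI.
    """
    importable = importlib.util.find_spec(import_name) is not None
    try:
        version: Optional[str] = importlib.metadata.version(dist_name)
    except importlib.metadata.PackageNotFoundError:
        version = None
    return {"importable": importable, "pip_visible_version": version}


# Assumed (import name, distribution name) pairs; match them to requirements.txt.
for import_name, dist_name in [("torch", "torch"), ("fairseq", "fairseq"), ("faiss", "faiss-gpu")]:
    print(import_name, describe_package(import_name, dist_name))
```

If faiss reports importable as True but pip_visible_version as None, pip would not see the existing installation, and python setup.py develop avoids the duplicate download.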

Getting Started

We aim to keep the implementation independent of the model architecture. To that end, we extend the Fairseq task, rather than the model, with the ability to perform similarity search. Since a task can be combined with different model architectures, various pre-trained models can be enhanced with an external memory without modifying the official Fairseq code. For example, kNN-MT can be implemented in just a few lines of code:

from functools import partial
from dataclasses import dataclass
from fairseq.tasks.translation import (
    TranslationTask,
    TranslationConfig,
)
from fairseq.tasks import register_task
from fairseq.dataclass import FairseqDataclass
from knn_models.dataclass import KnnConfig
from knn_models.hook_utils import ForwardHook
from knn_models.knn_utils import (
    KnnSearch,
    get_captured_module,
    get_normalized_probs,
)


@dataclass
class TranslationKnnConfig(TranslationConfig):
    """config for nearest neighbor machine translation"""
    knn_config: KnnConfig = KnnConfig()


@register_task("translation_knn", dataclass=TranslationKnnConfig)
class TranslationKnnTask(TranslationTask):
    """task for nearest neighbor machine translation"""
    def __init__(self, cfg: TranslationKnnConfig, src_dict, tgt_dict):
        super().__init__(cfg, src_dict, tgt_dict)
        self.knn_search = KnnSearch(cfg.knn_config)
        self.forward_hook = ForwardHook()

    def build_model(self, cfg: FairseqDataclass, from_checkpoint=False):
        model = super().build_model(cfg, from_checkpoint)

        assert hasattr(model, "decoder"), \
            "TranslationKnnTask only supports the model with decoder! " \
            f"There is no decoder in {model.__class__.__name__}."
        
        # collect outputs from the specified module in decoder as the datastore keys
        captured_module_name = self.cfg.knn_config.module_to_capture
        captured_module = get_captured_module(model.decoder, captured_module_name)
        captured_module.register_forward_hook(self.forward_hook.forward_hook_function)

        # rewrite `get_normalized_probs` function to support kNN augmented NMT
        model.get_normalized_probs = partial(get_normalized_probs, self, model)
        return model
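The last statement of build_model relies on a plain Python idiom: assigning a functools.partial object to an instance attribute shadows the class's method for that one model object, while pre-binding the task and the model as the first two arguments, so any code that later calls model.get_normalized_probs(net_output, log_probs, sample) reaches the kNN-aware function unchanged. The stdlib-only toy below illustrates only that idiom; ToyModel, ToyTask and knn_get_normalized_probs are hypothetical stand-ins, not the actual kNN-models or Fairseq classes.

```python
from functools import partial


class ToyModel:
    def get_normalized_probs(self, net_output, log_probs, sample=None):
        return "model's own probabilities"


class ToyTask:
    name = "toy_task"


def knn_get_normalized_probs(task, model, net_output, log_probs, sample=None):
    # Hypothetical stand-in for knn_models.knn_utils.get_normalized_probs:
    # it receives the task (which would own the datastore search) and the model.
    return f"probabilities computed by {task.name} for {type(model).__name__}"


task, model = ToyTask(), ToyModel()
# The instance attribute shadows the class method for this object only;
# partial pre-binds `task` and `model`, so callers keep the original signature.
model.get_normalized_probs = partial(knn_get_normalized_probs, task, model)

print(model.get_normalized_probs(None, True))       # probabilities computed by toy_task for ToyModel
print(ToyModel().get_normalized_probs(None, True))  # model's own probabilities
```

Because the override lives on the instance, other instances of the same class keep their original behavior, and the model's source code never has to change.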

Benchmarks

We measured the generation speed and GPU memory consumption during inference to evaluate the performance of kNN-models. We ran experiments with kNN-MT and Adaptive kNN-MT, as they are the dominant approaches to retrieval-augmented MT.

Following common practice, we used the multi-domain dataset (Koehn & Knowles, 2017) as re-split by Aharoni & Goldberg (2020), and adopted the winning model of the WMT’19 German-English news translation task (Ng et al., 2019) as the pre-trained NMT model. For kNN-MT, we tuned the hyperparameters (num_neighbors, lambda, temperature) on the validation sets according to BLEU. The hyperparameters for Adaptive kNN-MT were inherited from kNN-MT, except for lambda, which is inferred by the Meta-k-Network of Adaptive kNN-MT. We used beam search with a beam size of 5 and a length penalty of 1.0 during decoding. Note that only one GPU was used throughout the benchmark experiments, and the Faiss index was placed on that GPU to speed up search.

The datastore size and the hyperparameters for each domain are presented below:

                 Medical   Law        IT        Koran    Subtitles
datastore size   6501418   18857646   3449918   519897   6209620
num_neighbors    8         8          16        16       16
lambda           0.7       0.7        0.6       0.7      0.5
temperature      5         5          5         20       20
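For reference, kNN-MT weights each of the num_neighbors retrieved entries by exp(-d / temperature), where d is the distance between the query and the stored key, sums the weights per target token, normalizes them into a kNN distribution, and mixes it with the model's distribution as lambda * p_kNN + (1 - lambda) * p_MT. The sketch below shows that computation on plain Python dictionaries under these assumptions. It is not the toolkit's implementation, and the neighbors, distances and tokens are made-up toy values (3 neighbors instead of the 8 or 16 used above).

```python
import math
from collections import defaultdict
from typing import Dict, List, Tuple


def knn_distribution(neighbors: List[Tuple[float, str]], temperature: float) -> Dict[str, float]:
    """Turn retrieved (distance, target_token) pairs into a distribution over tokens.

    Each neighbor is weighted by exp(-distance / temperature); weights of neighbors
    sharing the same target token are summed and the result is normalized.
    """
    # Shifting all distances by the smallest one avoids underflow and cancels
    # out after normalization.
    d_min = min(distance for distance, _ in neighbors)
    scores: Dict[str, float] = defaultdict(float)
    for distance, token in neighbors:
        scores[token] += math.exp(-(distance - d_min) / temperature)
    total = sum(scores.values())
    return {token: score / total for token, score in scores.items()}


def interpolate(p_mt: Dict[str, float], p_knn: Dict[str, float], lam: float) -> Dict[str, float]:
    """Mix the two distributions: lam * p_kNN + (1 - lam) * p_MT."""
    vocab = set(p_mt) | set(p_knn)
    return {t: lam * p_knn.get(t, 0.0) + (1.0 - lam) * p_mt.get(t, 0.0) for t in vocab}


# Toy values: 3 retrieved neighbors as (squared L2 distance, target token).
neighbors = [(2.0, "house"), (3.0, "house"), (9.0, "building")]
p_knn = knn_distribution(neighbors, temperature=5.0)
p_mt = {"house": 0.4, "building": 0.5, "home": 0.1}
p = interpolate(p_mt, p_knn, lam=0.7)
print(max(p, key=p.get))  # house
```

A higher temperature flattens the kNN distribution across neighbors, and a higher lambda gives the retrieved neighbors more say than the base model; the tuned values in the table trade these off per domain.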

The BLEU scores of the pre-trained NMT model (Base MT), kNN-MT, and Adaptive kNN-MT on the test set of each domain are presented below:

                  Medical   Law     IT      Koran   Subtitles
Base MT           41.87     45.96   38.52   17.07   29.39
kNN-MT            57.08     62.48   47.1    22.54   30.55
Adaptive kNN-MT   58.17     63.32   48.33   22.03   30.45

Generation Speed

Because generation speed varies between runs and depends heavily on the hardware, we ran each experiment 5 times and report the mean and standard deviation of the generation speed on each of two different servers.
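As a concrete reading of the "mean±std" entries in the speed tables, the following sketch converts per-run token counts and wall-clock times into tokens per second and formats the summary. It assumes the sample standard deviation (n - 1 in the denominator); the README does not state which variant was used, and the run numbers in the example are hypothetical.

```python
import statistics
from typing import List, Tuple


def tokens_per_second(num_generated_tokens: int, elapsed_seconds: float) -> float:
    """Generation speed of a single run."""
    return num_generated_tokens / elapsed_seconds


def summarize(speeds: List[float]) -> str:
    """Format repeated measurements as 'mean±std' with two decimals.

    Uses the sample standard deviation (n - 1 in the denominator), which is an
    assumption: the README does not say which variant was reported.
    """
    return f"{statistics.mean(speeds):.2f}±{statistics.stdev(speeds):.2f}"


# Hypothetical (generated tokens, seconds) for 5 repeated runs, not real measurements.
runs: List[Tuple[int, float]] = [(50000, 40.1), (50000, 39.2), (50000, 41.0), (50000, 40.5), (50000, 39.8)]
print(summarize([tokens_per_second(n, t) for n, t in runs]))
```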

The generation speed (tokens/s) of kNN-models on a server with 8 NVIDIA Tesla P100 GPUs (16GB), 2 Intel Xeon Gold 6240 CPUs, and 256 GB of RAM is presented below. Because the test sets of the Medical and Law domains contain sentences longer than 400 tokens, no generation speed is reported for these domains with a batch size of 400 tokens:

Batch Size    Model            Medical        Law            IT             Koran          Subtitles
400 tokens    Base MT          N/A            N/A            593.67±12.92   577.60±14.76   1005.69±44.67
              kNN-MT           N/A            N/A            492.66±21.24   488.79±20.47   858.08±29.71
              Adaptive kNN-MT  N/A            N/A            470.20±20.02   455.39±16.95   806.94±24.71
800 tokens    Base MT          761.39±29.74   705.84±7.99    869.02±36.63   830.49±34.10   1502.55±29.31
              kNN-MT           625.08±24.04   542.48±21.85   738.49±31.51   689.17±36.21   1240.48±21.99
              Adaptive kNN-MT  591.90±16.39   521.86±12.26   710.79±17.69   642.82±20.04   1190.69±15.46
1600 tokens   Base MT          1033.93±30.34  1000.80±34.31  1195.03±41.52  1138.84±41.03  1859.79±10.62
              kNN-MT           829.28±22.33   743.36±23.23   993.22±22.14   960.69±27.82   1467.16±4.67
              Adaptive kNN-MT  812.92±13.07   715.14±18.86   924.22±22.44   903.87±16.43   1408.14±16.42
3200 tokens   Base MT          1335.80±20.57  1294.52±15.47  1445.16±20.55  1497.09±16.30  2047.57±19.40
              kNN-MT           1046.16±16.05  940.59±9.40    1197.04±18.48  1247.45±17.36  1586.45±10.99
              Adaptive kNN-MT  1036.07±3.97   917.63±10.08   1189.73±5.70   1203.48±9.22   1577.00±12.18
6400 tokens   Base MT          1563.36±11.48  1522.87±11.01  1613.63±17.39  1716.00±11.16  2126.56±19.66
              kNN-MT           1226.55±3.98   1072.35±5.72   1323.60±14.69  1447.19±13.10  1660.31±15.97
              Adaptive kNN-MT  1193.37±13.58  1043.77±6.62   1293.78±11.54  1408.91±7.27   1648.06±17.63
12800 tokens  Base MT          1675.49±9.45   1633.76±9.67   1647.95±12.20  1803.01±10.18  2197.24±13.67
              kNN-MT           1300.68±6.27   1140.59±3.88   1334.90±2.23   1532.65±8.40   1694.99±7.50
              Adaptive kNN-MT  1275.62±10.28  1125.35±5.66   1323.47±9.31   1500.19±10.48  1699.80±10.55

The generation speed (tokens/s) of kNN-models on a server with 8 NVIDIA GeForce GTX TITAN GPUs (24GB), 2 Intel Xeon E5-2680 CPUs, and 256 GB of RAM is presented below:

Batch Size    Model            Medical        Law            IT             Koran          Subtitles
400 tokens    Base MT          N/A            N/A            435.83±15.51   432.85±16.09   844.25±57.33
              kNN-MT           N/A            N/A            408.02±21.15   403.94±16.99   759.71±51.01
              Adaptive kNN-MT  N/A            N/A            393.35±25.35   371.31±29.31   724.04±42.07
800 tokens    Base MT          634.81±15.64   588.01±14.00   743.54±42.92   682.80±19.63   1507.27±54.44
              kNN-MT           542.13±11.21   481.48±8.66    651.12±31.04   618.70±11.19   1261.36±44.09
              Adaptive kNN-MT  526.43±33.34   436.25±21.67   633.04±29.44   556.48±35.99   1244.21±69.26
1600 tokens   Base MT          967.79±14.60   983.15±9.54    1110.93±25.45  1088.76±41.47  2182.40±74.34
              kNN-MT           761.56±33.66   726.35±25.67   1040.71±17.07  919.17±31.14   1664.39±55.27
              Adaptive kNN-MT  745.29±21.61   719.38±27.49   969.04±46.21   915.46±52.70   1601.80±38.00
3200 tokens   Base MT          1526.37±43.21  1488.71±78.56  1665.54±66.93  1885.99±13.26  2645.62±80.18
              kNN-MT           1168.07±20.86  1051.21±30.82  1395.36±63.48  1547.67±60.08  2040.28±29.90
              Adaptive kNN-MT  1135.30±63.46  1037.96±54.62  1335.45±60.56  1442.43±52.53  2032.88±47.17
6400 tokens   Base MT          2078.05±14.57  2038.81±60.04  2078.64±55.91  2397.98±11.12  2838.64±12.76
              kNN-MT           1541.41±31.89  1337.22±5.74   1698.17±46.67  1965.55±43.59  2176.18±26.11
              Adaptive kNN-MT  1494.57±22.87  1326.34±24.34  1695.56±42.75  1902.53±45.91  2173.67±25.10
12800 tokens  Base MT          2377.90±20.36  2374.11±6.77   2158.86±21.50  2589.23±40.78  2986.30±31.20
              kNN-MT           1752.04±11.44  1493.63±5.76   1772.20±51.73  2175.42±40.24  2314.58±6.86
              Adaptive kNN-MT  1719.02±36.40  1476.38±13.23  1765.07±47.39  2117.49±45.74  2313.21±44.98

GPU Memory Consumption

It is nontrivial to measure precisely the minimum amount of GPU memory needed for inference because of the complex GPU memory management of PyTorch and Faiss. Nevertheless, to report an approximate minimum GPU memory requirement, we disabled PyTorch's memory caching by setting the environment variable PYTORCH_NO_CUDA_MEMORY_CACHING to 1 and sampled the amount of used GPU memory every 10 milliseconds, recording the maximum. Following the default setting of Fairseq, we set the batch size to 12000 tokens.
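The README does not say which tool sampled the GPU memory. A minimal sketch of the described procedure, assuming a background thread that calls a user-supplied reader every 10 ms and keeps the maximum, is shown below. read_used_mb is a hypothetical callable that could, for example, parse the output of nvidia-smi --query-gpu=memory.used --format=csv,noheader,nounits; PYTORCH_NO_CUDA_MEMORY_CACHING=1 would have to be set in the environment of the process being measured before it starts.

```python
import threading
import time
from typing import Callable


class PeakSampler:
    """Call `read_used_mb` every `interval_s` seconds in a background thread and keep the maximum.

    `read_used_mb` is a hypothetical zero-argument callable returning the GPU
    memory currently in use, in MB.
    """

    def __init__(self, read_used_mb: Callable[[], float], interval_s: float = 0.01) -> None:
        self.read_used_mb = read_used_mb
        self.interval_s = interval_s
        self.peak = 0.0
        self._stop = threading.Event()
        self._thread = threading.Thread(target=self._poll, daemon=True)

    def _poll(self) -> None:
        while not self._stop.is_set():
            self.peak = max(self.peak, self.read_used_mb())
            # wait() returns early as soon as a stop is requested
            self._stop.wait(self.interval_s)

    def __enter__(self) -> "PeakSampler":
        self._thread.start()
        return self

    def __exit__(self, *exc_info) -> None:
        self._stop.set()
        self._thread.join()
        # One final sample so the state at exit is also taken into account.
        self.peak = max(self.peak, self.read_used_mb())


# Fake reader standing in for a real GPU query: yields a few values, then stays at 7000 MB.
readings = iter([6000.0, 8300.0, 7900.0])


def fake_reader() -> float:
    return next(readings, 7000.0)


with PeakSampler(fake_reader, interval_s=0.001) as sampler:
    time.sleep(0.05)  # the workload being measured would run here
print(sampler.peak)
```

Launching nvidia-smi on every sample typically takes longer than 10 ms, so a reader built on NVML is a better fit when the sampling interval matters.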

The observed maximum GPU memory consumption of kNN-models during inference is presented below:

Batch Size     Model             Medical   Law       IT        Koran     Subtitles
12000 tokens   Base MT           6363 MB   6519 MB   6509 MB   6575 MB   6349 MB
               kNN-MT            8391 MB   9383 MB   8255 MB   8155 MB   8367 MB
               Adaptive kNN-MT   8379 MB   9403 MB   8265 MB   8153 MB   8375 MB

Acknowledgements

We are extremely grateful to the research community for their incredible work on retrieval-augmented sequence modeling. This repository would not have been possible without them. We would also like to thank wls for his generous help and valuable suggestions in replicating PCKMT.

License: Apache License 2.0

