
Transformer-Rankers

Documentation: https://guzpenha.github.io/transformer_rankers/

A library to conduct ranking experiments with transformers.

Setup

The following clones the repo, optionally creates a virtual environment, and installs the library with its requirements.

#Clone the repo
git clone https://github.com/Guzpenha/transformer_rankers.git
cd transformer_rankers    

#Optionally use a virtual environment
python3 -m venv env
source env/bin/activate

#Install the library and its requirements
pip install -e .
pip install -r requirements.txt

Example: BERT-ranker for dialogue

The following example uses BERT for the task of conversation response ranking on the MANtIS corpus. We can download the data as follows:

from transformer_rankers.datasets import downloader

#Download the data with DataDownloader
data_folder = "data"
dataDownloader = downloader.DataDownloader("mantis", data_folder)
dataDownloader.download_and_preprocess()
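
After the download finishes, the preprocessed data is stored as TSV files under the data folder (assuming the paths used in the training snippet below, data/mantis/train.tsv and data/mantis/valid.tsv). A quick sanity check with pandas:

import pandas as pd

#Peek at the preprocessed training data (paths taken from the snippet below)
print(pd.read_csv("data/mantis/train.tsv", sep="\t").head())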

And train BERT for pointwise learning to rank, with randomly sampled negatives:

import logging
import pandas as pd
from transformers import BertTokenizer

from transformer_rankers.models import pointwise_bert
from transformer_rankers.trainers import transformer_trainer
from transformer_rankers.datasets import dataset, preprocess_crr
from transformer_rankers.negative_samplers import negative_sampling
from transformer_rankers.eval import results_analyses_tools

#Load the dataset
task = "mantis"
train = pd.read_csv(data_folder+"/"+task+"/train.tsv", sep="\t")
valid = pd.read_csv(data_folder+"/"+task+"/valid.tsv", sep="\t")

#Instantiate random negative samplers (1 negative for training, 9 candidates for evaluation);
#the library also supports BM25 and sentenceBERT negative samplers (sketched below).
ns_train = negative_sampling.RandomNegativeSampler(list(train["response"].values), 1)
ns_val = negative_sampling.RandomNegativeSampler(list(valid["response"].values) + \
    list(train["response"].values), 9)
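
#Sketch of the BM25 sampler mentioned above. The class name and argument
#order here are assumptions, so this is left commented out; check
#transformer_rankers.negative_samplers.negative_sampling for the exact
#signatures. The index and Anserini paths are hypothetical.
#ns_train_bm25 = negative_sampling.BM25NegativeSamplerPyserini(
#    list(train["response"].values), 1, data_folder+"/"+task+"/index", -1, "anserini/")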

tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
special_tokens_dict = {'additional_special_tokens': ['[UTTERANCE_SEP]', '[TURN_SEP]'] }
tokenizer.add_special_tokens(special_tokens_dict)
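
#Illustration with a made-up context: the added special tokens are kept
#as single units instead of being split into word pieces.
context = "my laptop will not boot [UTTERANCE_SEP] did you try safe mode? [TURN_SEP] yes, same error"
print(tokenizer.tokenize(context))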

#Create the loaders for the datasets, with the respective negative samplers        
dataloader = dataset.QueryDocumentDataLoader(train_df=train, val_df=valid, test_df=valid,
                                tokenizer=tokenizer, negative_sampler_train=ns_train, 
                                negative_sampler_val=ns_val, task_type='classification', 
                                train_batch_size=6, val_batch_size=6, max_seq_len=512, 
                                sample_data=-1, cache_path="{}/{}".format(data_folder, task))

train_loader, val_loader, test_loader = dataloader.get_pytorch_dataloaders()
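
#The loaders behave like standard PyTorch DataLoaders (as the method name
#suggests); the exact batch layout is defined by the library, so printing
#shapes is a safe way to inspect it.
batch = next(iter(train_loader))
for item in batch:
    print(type(item), getattr(item, "shape", None))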


model = pointwise_bert.BertForPointwiseLearning.from_pretrained('bert-base-cased')
# we added [UTTERANCE_SEP] and [TURN_SEP] to the vocabulary so we need to resize the token embeddings
model.resize_token_embeddings(len(dataloader.tokenizer)) 
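
#Sanity check: the embedding matrix now matches the extended vocabulary.
assert model.get_input_embeddings().num_embeddings == len(tokenizer)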

#Instantiate trainer that handles fitting.
trainer = transformer_trainer.TransformerTrainer(model=model,train_loader=train_loader,
                                val_loader=val_loader, test_loader=test_loader, 
                                num_ns_eval=9, task_type="classification", tokenizer=tokenizer,
                                validate_every_epoch=1, num_validation_batches=-1,
                                num_epochs=1, lr=0.0005, sacred_ex=None,
                                validate_every_steps=-1, num_training_instances=-1)

#Train the model
logging.info("Fitting BERT-ranker for MANtIS")
trainer.fit()

#Predict for test (in our example the validation set)
logging.info("Predicting")
preds, labels, _ = trainer.test()
res = results_analyses_tools.\
    evaluate_and_aggregate(preds, labels, ['ndcg_cut_10'])

for metric, v in res.items():
    logging.info("Test {} : {:.4f}".format(metric, v))
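
evaluate_and_aggregate takes a list of metric names; 'ndcg_cut_10' follows the pytrec_eval naming convention, so, assuming other pytrec_eval identifiers are accepted as well (an assumption, not confirmed by this README), several metrics can be computed in one call:

#Assumption: metric ids follow the pytrec_eval convention,
#as 'ndcg_cut_10' above suggests.
res = results_analyses_tools.evaluate_and_aggregate(
    preds, labels, ['ndcg_cut_10', 'recip_rank', 'map'])
for metric, v in res.items():
    logging.info("Test {} : {:.4f}".format(metric, v))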

Examples

Colab notebook: fine-tune pointwise BERT for Community Question Answering.

Wandb report: fine-tuning BERT for Conversation Response Ranking.

License

MIT License

