ranking-utils

This repository provides miscellaneous utilities for ranking models.

Installation

Clone this repository and run:

python -m pip install .

Usage

Dataset Pre-Processing

The dataset pre-processing script reads a dataset and creates training, dev, and test sets (in HDF5 format) that can be used by the ranking models. Run the script as follows to see the available options:

python -m ranking_utils.scripts.create_h5_data

Several datasets are currently supported; the example below uses ANTIQUE.

The script uses Hydra. Refer to the Hydra documentation for detailed instructions on how to configure arguments.

Example

The following pre-processes the ANTIQUE dataset:

python -m ranking_utils.scripts.create_h5_data \
    dataset=antique \
    dataset.root_dir=/path/to/antique/files \
    hydra.run.dir=/path/to/output/files

To see all available options for a dataset, run:

python -m ranking_utils.scripts.create_h5_data \
    dataset=antique \
    --help

Ranking

Implementing a ranker requires two components (a minimal sketch of both follows this list):

  1. A data processor (specific to your model) that subclasses ranking_utils.model.data.DataProcessor and implements the following methods:
    • get_model_input(self, query: str, doc: str) -> ModelInput: Transforms a query-document pair into an input that is suitable for the model.
    • get_model_batch(self, inputs: Iterable[ModelInput]) -> ModelBatch: Creates a model batch from multiple inputs.
  2. The ranking model itself, which subclasses ranking_utils.model.Ranker and implements the following methods:
    • forward(self, batch: ModelBatch) -> torch.Tensor: Computes query-document scores; output shape (batch_size, 1).
    • configure_optimizers(self) -> Tuple[List[Any], List[Any]]: Configures optimizers (and, optionally, schedulers). Refer to the PyTorch Lightning documentation.
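
For illustration, here is a minimal sketch of both components. Everything model-specific in it is a placeholder: the MyInput/MyBatch aliases stand in for the ModelInput/ModelBatch placeholders above, the length-based features and learning rate are arbitrary, and the Ranker base class may expect additional constructor arguments.

from typing import Any, Iterable, List, Tuple

import torch
from ranking_utils.model import Ranker
from ranking_utils.model.data import DataProcessor

# Placeholder input/batch types for this sketch; a real model would use
# tokenizer outputs or feature tensors instead of raw strings.
MyInput = Tuple[str, str]
MyBatch = Tuple[List[str], List[str]]


class MyDataProcessor(DataProcessor):
    def get_model_input(self, query: str, doc: str) -> MyInput:
        # Keep the raw strings; a real processor would tokenize here.
        return query, doc

    def get_model_batch(self, inputs: Iterable[MyInput]) -> MyBatch:
        queries, docs = zip(*inputs)
        return list(queries), list(docs)


class MyRanker(Ranker):
    def __init__(self, lr: float = 1e-3) -> None:
        super().__init__()  # the base class may require further arguments
        self.lr = lr
        # Toy scoring layer over two hand-crafted features.
        self.score = torch.nn.Linear(2, 1)

    def forward(self, batch: MyBatch) -> torch.Tensor:
        queries, docs = batch
        # Toy features: query and document lengths, shape (batch_size, 2).
        features = torch.tensor(
            [[len(q), len(d)] for q, d in zip(queries, docs)],
            dtype=torch.float,
            device=self.device,
        )
        return self.score(features)  # shape (batch_size, 1)

    def configure_optimizers(self) -> Tuple[List[Any], List[Any]]:
        return [torch.optim.AdamW(self.parameters(), lr=self.lr)], []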

You can then train your model using the usual PyTorch Lightning setup. For example:

from pathlib import Path
from pytorch_lightning import Trainer
from ranking_utils.model.data import H5DataModule
from ranking_utils.model import TrainingMode
from my_ranker import MyRanker, MyDataProcessor

data_module = H5DataModule(
    data_processor=MyDataProcessor(...),
    data_dir=Path(...),
    fold_name="fold_0",
    batch_size=32
)
model = MyRanker(...)
data_module.training_mode = model.training_mode = TrainingMode.PAIRWISE
model.pairwise_loss_margin = 0.2
Trainer(...).fit(model=model, datamodule=data_module)

Validation

After each epoch, the ranker automatically computes the following ranking metrics on the validation set:

  • val_RetrievalMAP: mean average precision
  • val_RetrievalMRR: mean reciprocal rank
  • val_RetrievalNormalizedDCG: nDCG

These metrics can be used in combination with callbacks, e.g. early stopping, as sketched below.
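
For example, a minimal early-stopping setup using PyTorch Lightning's built-in callback could monitor one of these metrics (the patience value here is arbitrary):

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping

# Stop training once val_RetrievalMAP stops improving on the validation set.
early_stopping = EarlyStopping(monitor="val_RetrievalMAP", mode="max", patience=3)
Trainer(callbacks=[early_stopping]).fit(model=model, datamodule=data_module)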

Examples

Example implementations of various models using this library can be found here.

License

MIT License