Sentence Transformers in Burn

This library provides an implementation of the Sentence Transformers framework for computing text representations as vector embeddings in Rust. Specifically, it uses the Burn deep learning library to implement the BERT model. Because Burn abstracts over its compute backends, the model can be paired with any supported backend for fast, cross-platform inference on CPUs and GPUs. ST-Burn supports any model that implements the BERT architecture.

Currently inference-only; training and fine-tuning are not yet supported.

Features

  • Import models via safetensors (using Candle). 📦
  • Code structure replicates the official Huggingface BertModel implementation. 🚀
  • Flexible inference backend using Burn. 🔧

Installation

sentence-transformers-burn can be installed from source:

cargo add --git https://github.com/tvergho/sentence-transformers-burn.git sentence_transformers

Run cargo build to make sure everything compiles correctly:

cargo build

Note that building the burn-tch dependency may require manually linking Libtorch. After installing PyTorch via pip:

export LIBTORCH=$(python3 -c 'import torch; from pathlib import Path; print(Path(torch.__file__).parent)')
# prints the Libtorch location, e.g. /path/to/torch

export DYLD_LIBRARY_PATH=/path/to/torch/lib  # macOS; use LD_LIBRARY_PATH on Linux

Python dependencies (for running the scripts in scripts/) should also be installed:

python3 -m venv venv
source venv/bin/activate
pip install -r requirements.txt

Usage

A BertModel can be loaded and initialized from a file, as in the example below:

use sentence_transformers::bert_loader::{load_model_from_safetensors, load_config_from_json};
use sentence_transformers::model::{
  bert_embeddings::BertEmbeddingsInferenceBatch,
  bert_model::BertModel,
};
use burn_tch::{TchBackend, TchDevice};
use burn::tensor::Tensor;

const BATCH_SIZE: usize = 64;

let device = TchDevice::Cpu;
let config = load_config_from_json("model/bert_config.json");
let model: BertModel<_> = load_model_from_safetensors::<TchBackend<f32>>("model/bert_model.safetensors", &device, config);

// Dummy input: token ids and attention mask, both of shape [BATCH_SIZE, 256].
let batch = BertEmbeddingsInferenceBatch {
  tokens: Tensor::zeros(vec![BATCH_SIZE, 256]).to_device(&device),
  mask_attn: Some(Tensor::ones(vec![BATCH_SIZE, 256]).to_device(&device)),
};

let output = model.forward(batch); // [batch_size, seq_len, n_dims]
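
The forward pass returns one embedding per token. To reduce these to a single fixed-size embedding per sentence, as the Sentence Transformers framework does, pool over the sequence dimension. Below is a minimal sketch using unweighted mean pooling; the hidden size of 768 (BERT-base) is an assumption, as are the exact tensor method names, which may differ slightly between Burn versions. With padded batches, you would weight the sum by the attention mask instead.

// Sketch: mean-pool token embeddings [BATCH_SIZE, 256, 768] into
// sentence embeddings [BATCH_SIZE, 768]. The hidden size 768 is assumed
// (BERT-base); in real code, read it from bert_config.json.
let pooled = output.mean_dim(1);                             // [BATCH_SIZE, 1, 768]
let sentence_embeddings = pooled.reshape([BATCH_SIZE, 768]); // [BATCH_SIZE, 768]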

sentence-transformers-burn also comes with a built-in inference server. To start it, run:

cargo run --release --bin server -- path/to/model/dir

The model directory should contain bert_model.safetensors and bert_config.json files. Once the server is running, inference can be initiated via a POST request:

POST http://localhost:3030/embed

{
  "input_ids": [[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]],
  "attention_mask": [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]
}

This returns a 3D array of floats with shape [batch_size, seq_len, n_dims].
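
For reference, a small Rust client for this endpoint might look like the following. This is a sketch, not part of the crate: it assumes the reqwest (with the blocking and json features) and serde_json crates, and that the server responds with the raw 3D JSON array described above.

use serde_json::json;

fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Same payload as the raw POST example above.
    let body = json!({
        "input_ids": [[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]],
        "attention_mask": [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]
    });

    // POST to the server and decode the [batch_size, seq_len, n_dims] response.
    let embeddings: Vec<Vec<Vec<f32>>> = reqwest::blocking::Client::new()
        .post("http://localhost:3030/embed")
        .json(&body)
        .send()?
        .json()?;

    println!(
        "{} x {} x {}",
        embeddings.len(),
        embeddings[0].len(),
        embeddings[0][0].len()
    );
    Ok(())
}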

Testing

Tests can be run to verify that the Rust model output matches a comparable Huggingface model. To save a model for use during testing, run python scripts/prepare_test.py. Then, simply:

cargo test

To Do

  • Cleaner model import (directly from safetensors/config.json)
  • Proper documentation and more testing
  • More model usage options (e.g. classification, NER, question answering heads)
  • GGML backend/quantization


License

Licensed under the MIT License.

