OpenNMT / CTranslate2

Fast inference engine for Transformer models

Home Page: https://opennmt.net/CTranslate2

Proper way to change alignment_heads and alignment_layer when using return_attention=True for converted HF Seq2Seq Transformers?

jorirsan opened this issue

Hello,

I'm currently using CTranslate2 to quickly obtain attention scores for a downstream task, and I would like to find the optimal "alignment_heads" and "alignment_layer" values for converted HuggingFace Seq2Seq models. However, I can't quite figure out whether this is currently possible.

I've tried editing the CTranslate2 config.json of converted HF models (NLLB, MADLAD) to add "alignment_heads" and "alignment_layer" values, but this doesn't seem to change the "return_attention" results at all. Looking at the other converter backends (Fairseq, Marian), both options seem to be available, but for HF models they seem to be intended mainly for use with Whisper.

Am I correct to assume that these parameters are baked into the model.bin at conversion time? If so, is there currently a way to change them before converting the HF models? Would manually adding "alignment_heads" pairs to the HuggingFace "config.json" work? Or is there any way to modify the model specification parameters directly through the Translator object?

Modified MADLAD config.json:

{
  "add_source_bos": false,
  "add_source_eos": false,
  "bos_token": "<s>",
  "decoder_start_token": "<unk>",
  "eos_token": "</s>",
  "layer_norm_epsilon": null,
  "multi_query_attention": false,
  "unk_token": "<unk>",
  "alignment_heads": 4,
  "alignment_layer": 6
}
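To make the question concrete, here is a toy sketch of what the two settings would select if, as the integer values in the config above suggest, `alignment_layer` is a decoder layer index (negative values counting back from the last layer) and `alignment_heads` is a number of leading heads whose attention is averaged. These semantics, the `select_alignment` helper, and the nested-list layout `attn[layer][head][target][source]` are illustrative assumptions, not CTranslate2's implementation.

```python
from typing import List

# attn[layer][head][t][s]: weight from target step t to source position s.
# This layout is an assumption made for illustration only.
Matrix = List[List[float]]


def select_alignment(
    attn: List[List[Matrix]], alignment_layer: int, alignment_heads: int
) -> Matrix:
    """Hypothetical helper: average the first `alignment_heads` heads of one layer."""
    heads = attn[alignment_layer]  # Python indexing, so -1 picks the last layer
    if not 1 <= alignment_heads <= len(heads):
        raise ValueError("alignment_heads must be between 1 and the number of heads")
    chosen = heads[:alignment_heads]
    num_tgt = len(chosen[0])
    num_src = len(chosen[0][0])
    # Element-wise mean over the selected heads.
    return [
        [sum(h[t][s] for h in chosen) / alignment_heads for s in range(num_src)]
        for t in range(num_tgt)
    ]


# Toy tensor: 2 layers x 2 heads x 2 target steps x 2 source positions.
toy = [
    [[[1.0, 0.0], [0.0, 1.0]], [[0.5, 0.5], [0.5, 0.5]]],  # layer 0
    [[[0.0, 1.0], [1.0, 0.0]], [[1.0, 0.0], [0.0, 1.0]]],  # layer 1
]
print(select_alignment(toy, alignment_layer=-1, alignment_heads=1))  # [[0.0, 1.0], [1.0, 0.0]]
print(select_alignment(toy, alignment_layer=0, alignment_heads=2))   # [[0.75, 0.25], [0.25, 0.75]]
```

Under these assumptions, changing either value picks a different slice of the decoder's cross-attention, which is why the returned attention would be expected to change if the settings took effect.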

Test script:

import ctranslate2
import transformers

# Model converted with ct2-transformers-converter with no extra options

src = "<2es> This is a test translation."

model = "madlad400-3b-mt"
model_hf = "google/madlad400-3b-mt"
translator = ctranslate2.Translator(model)

tokenizer = transformers.AutoTokenizer.from_pretrained(model_hf)

source = tokenizer(src).tokens()
results = translator.translate_batch([source], return_attention=True, return_scores=True)

print(results)
Output, identical for the model with and without "alignment_heads" and "alignment_layer" added to the CTranslate2 config.json:

[TranslationResult(hypotheses=[['▁Esta', '▁es', '▁una', '▁traduc', 'ción', '▁de', '▁prueba', '.']], scores=[-0.24390581250190735], attention=[[
  [0.6991117596626282, 0.09015607833862305, 0.0745280459523201, 0.01642746664583683, 0.013331991620361805, 0.00797318760305643, 0.029338935390114784, 0.02415921539068222, 0.044973284006118774],
  [0.7786989212036133, 0.08919020742177963, 0.010596153326332569, 0.01059800200164318, 0.012276791967451572, 0.01970701850950718, 0.0521710179746151, 0.010949257761240005, 0.01581263355910778],
  [0.6620574593544006, 0.1619597226381302, 0.016302023082971573, 0.0311853289604187, 0.03139425441622734, 0.01271171122789383, 0.026976067572832108, 0.015252464450895786, 0.04216107353568077],
  [0.6986892223358154, 0.10730499029159546, 0.00985235907137394, 0.015372445806860924, 0.023194633424282074, 0.04971150681376457, 0.0636223703622818, 0.01495570968836546, 0.01729676127433777],
  [0.882076621055603, 0.0804339274764061, 0.004278178326785564, 0.006933772936463356, 0.004655494354665279, 0.005530696362257004, 0.006705072708427906, 0.0036677366588264704, 0.005718609318137169],
  [0.695141613483429, 0.10289674997329712, 0.014959929510951042, 0.0354330874979496, 0.03752530738711357, 0.027860794216394424, 0.020848261192440987, 0.0317409411072731, 0.03359336405992508],
  [0.6272260546684265, 0.07765787839889526, 0.018862776458263397, 0.030511993914842606, 0.02709764987230301, 0.021792955696582794, 0.09415297210216522, 0.05251908302307129, 0.05017860606312752],
  [0.759211003780365, 0.0940946713089943, 0.009989023208618164, 0.02316822111606598, 0.02662344090640545, 0.009856035001575947, 0.013097163289785385, 0.015567192807793617, 0.04839325696229935]]])]
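The returned attention has one row per target token (8 rows for the 8 hypothesis tokens above, each with 9 source-position weights). One simple way to compare the effect of different alignment settings is to reduce each row to a hard alignment by taking its argmax. The following is a minimal sketch; `hard_alignments` is a hypothetical helper, not part of CTranslate2, and argmax is only one of several possible reduction choices.

```python
from typing import List, Sequence, Tuple


def hard_alignments(
    target_tokens: Sequence[str], attention: Sequence[Sequence[float]]
) -> List[Tuple[int, str, int]]:
    """Map each target token to the source position with the highest weight.

    Assumes `attention` holds one row per target token and one column per
    source position, as in the output above.
    """
    if len(target_tokens) != len(attention):
        raise ValueError("expected one attention row per target token")
    alignments = []
    for t, (token, row) in enumerate(zip(target_tokens, attention)):
        # Argmax over source positions; max() returns the first maximum on ties.
        best_src = max(range(len(row)), key=lambda s: row[s])
        alignments.append((t, token, best_src))
    return alignments
```

With the test script, it could be called as `hard_alignments(results[0].hypotheses[0], results[0].attention[0])`. Applied to the output shown, every target token maps to source position 0, since each row puts most of its weight (between about 0.63 and 0.88) on the first column.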

Library Versions:

ctranslate2              4.2.1
transformers             4.41.1