krishnap25 / mauve

Package to compute Mauve, a similarity score between neural text and human text. Install with `pip install mauve-text`.

Home Page: https://krishnap25.github.io/mauve/

Allowing other models for extracting features

jinyongyoo opened this issue

Hello!

First off, thanks for sharing the code. The paper says that MAUVE works with other embedding models, so I wanted to try models such as DialoGPT from Microsoft. However, the code restricts the model and tokenizer names to the "gpt2" family: any name that does not contain "gpt2" raises a ValueError. I think it would be better to remove this restriction, since others might also want to try other models.

If you want, I can make a PR to change this.

mauve/src/mauve/utils.py

Lines 25 to 39 in b3c01d5

def get_model(model_name, tokenizer, device_id):
    device = get_device_from_arg(device_id)
    if 'gpt2' in model_name:
        model = AutoModel.from_pretrained(model_name, pad_token_id=tokenizer.eos_token_id).to(device)
        model = model.eval()
    else:
        raise ValueError(f'Unknown model: {model_name}')
    return model

def get_tokenizer(model_name='gpt2'):
    if 'gpt2' in model_name:
        tokenizer = AutoTokenizer.from_pretrained(model_name)
    else:
        raise ValueError(f'Unknown model: {model_name}')
    return tokenizer
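
To make the proposal concrete, here is a rough sketch of how the two functions might look without the "gpt2" check. It is an illustration only, not the change that was merged. The helper `resolve_pad_token_id` is a hypothetical name introduced here, and `get_device_from_arg` below is a stand-in for the helper of the same name in `utils.py`, whose actual body is not shown in this issue. The sketch assumes any model name accepted by `AutoModel` / `AutoTokenizer` should be allowed, and that falling back to the EOS token for padding (as the gpt2 path already does) is acceptable when the tokenizer defines no pad token.

```python
def resolve_pad_token_id(tokenizer):
    """Hypothetical helper: prefer the tokenizer's own pad token,
    otherwise fall back to EOS as the existing gpt2 path does."""
    pad_id = getattr(tokenizer, "pad_token_id", None)
    if pad_id is not None:
        return pad_id
    return getattr(tokenizer, "eos_token_id", None)


def get_device_from_arg(device_id):
    # Stand-in for the helper in mauve/src/mauve/utils.py (assumed behavior):
    # use the requested GPU if one is available, otherwise the CPU.
    import torch
    if device_id is not None and device_id >= 0 and torch.cuda.is_available():
        return torch.device(f"cuda:{device_id}")
    return torch.device("cpu")


def get_tokenizer(model_name="gpt2"):
    # Imported lazily so the pure helper above can be used without transformers.
    from transformers import AutoTokenizer
    # No name check: let transformers raise its own error for unknown models.
    return AutoTokenizer.from_pretrained(model_name)


def get_model(model_name, tokenizer, device_id):
    from transformers import AutoModel
    device = get_device_from_arg(device_id)
    kwargs = {}
    pad_id = resolve_pad_token_id(tokenizer)
    if pad_id is not None:
        # Only override the config when a usable padding id exists.
        kwargs["pad_token_id"] = pad_id
    model = AutoModel.from_pretrained(model_name, **kwargs).to(device)
    return model.eval()
```

With something like this, `get_tokenizer("microsoft/DialoGPT-medium")` followed by `get_model(...)` would go through the same code path as "gpt2". Whether the rest of the feature-extraction code makes further gpt2-specific assumptions (for example, about which hidden state is used) would still need to be checked.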

Hi @jinyongyoo, a PR for this would be fantastic. Thanks!