martiansideofthemoon / style-transfer-paraphrase

Official code and data repository for our EMNLP 2020 long paper "Reformulating Unsupervised Style Transfer as Paraphrase Generation" (https://arxiv.org/abs/2010.05700).

Home Page: http://style.cs.umass.edu

Paraphrase demo does not work with transformers version 4 or higher

hololeac opened this issue

Running demo_paraphraser.py with transformers library version 4+ throws an error after entering a sentence. The problem is not present with transformers version 3.4.0.

Enter your sentence, q to quit: So can a magazine survive by downright thumbing its nose at major advertisers?
Traceback (most recent call last):
  File "drive/MyDrive/style-transfer-paraphrase-master/demo_paraphraser.py", line 26, in <module>
    greedy_decoding = paraphraser.generate(input_sentence)
  File "/content/drive/MyDrive/style-transfer-paraphrase-master/style_paraphrase/inference_utils.py", line 129, in generate
    top_p=top_p)[0][0]
  File "/content/drive/MyDrive/style-transfer-paraphrase-master/style_paraphrase/inference_utils.py", line 102, in generate_batch
    top_p=top_p
  File "/content/drive/MyDrive/style-transfer-paraphrase-master/style_paraphrase/utils.py", line 173, in generate
    interpolation=interpolation
  File "/content/drive/MyDrive/style-transfer-paraphrase-master/style_paraphrase/utils.py", line 272, in sample_sequence
    next_token_logits = logits[:, -1, :] / (temperature if temperature > 0 else 1.)
TypeError: string indices must be integers
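
Until the repository is updated, the simplest workaround is to pin the library to the version the demo was tested with (pip install transformers==3.4.0). Alternatively, a small guard at the top of demo_paraphraser.py can make the failure mode clearer; this is only a suggestion and not part of the repository:

# Hypothetical guard (not in the repository): fail fast with a clear message
# if an incompatible transformers version is installed.
import transformers

major_version = int(transformers.__version__.split(".")[0])
if major_version >= 4:
    raise RuntimeError(
        f"transformers {transformers.__version__} detected; this demo was written "
        "against 3.4.0. Downgrade with pip install transformers==3.4.0 until the "
        "code supports version 4+."
    )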

I was running into the same issue. Fortunately, the fix for this error is simple.
It seems the way the logits and cached past key values are returned changed between transformers versions: older versions returned a plain tuple that was accessed by index, while version 4+ returns a dict-like output that is accessed by key.
For transformers version 4+, you can change the get_logits method in style_paraphrase/utils.py to the following:

def get_logits(model, iteration, generated, segments, style_content_vectors, past):
    if iteration == 0:
        # first decoding step: feed the full context (plus the style/content prefix vectors, if any)
        if style_content_vectors is None:
            pred = model(
                input_ids=generated,
                token_type_ids=segments
            )
        else:
            pred = model(
                input_ids=generated,
                token_type_ids=segments,
                prefix_input_vectors=style_content_vectors
            )
    else:
        # use the cached representations to speed up decoding
        pred = model(
            input_ids=generated[:, -1:],
            token_type_ids=segments[:, -1:],
            past_key_values=past
        )
    # transformers 4+ returns a dict-like ModelOutput, so access the fields by key
    logits = pred['logits']
    past = pred['past_key_values']
    return logits, past
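
For anyone who wants to sanity-check the new return format outside this repository, here is a minimal, self-contained example using the stock GPT-2 model from the Hugging Face hub (it does not use the repo's custom model class or prefix_input_vectors, so it is purely illustrative):

import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")
model.eval()

input_ids = tokenizer("So can a magazine survive", return_tensors="pt").input_ids
with torch.no_grad():
    pred = model(input_ids=input_ids, use_cache=True)

# On transformers 4+ the forward pass returns a dict-like ModelOutput,
# so the logits and cache are pulled out by key rather than by position.
logits = pred["logits"]                 # (batch, seq_len, vocab_size)
past = pred["past_key_values"]          # cache reused for the next decoding step
next_token_logits = logits[:, -1, :]    # the slicing that failed in the traceback
print(next_token_logits.shape)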

Hi @philno, this solution works great! Do you want to submit a PR to fix it?

Yes, I can submit a PR. I just need to make sure that I don't break the code for transformers version < 4.
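
One possible way to keep both major versions working is to normalize the model output in a small helper and call it from get_logits. This is only a sketch (the helper name is made up, and it assumes the forward pass returns a tuple on transformers < 4 and a dict-like ModelOutput on 4+); the actual PR may handle it differently:

def unpack_prediction(pred):
    # Hypothetical helper: normalize the model output across transformers versions
    # before pulling out the logits and the cached key/value states.
    if isinstance(pred, tuple):
        # transformers < 4: the forward pass returns (logits, past_key_values, ...)
        return pred[0], pred[1]
    # transformers >= 4: the forward pass returns a dict-like ModelOutput
    return pred["logits"], pred["past_key_values"]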

Fixed in #11.