Paraphrase demo will not work with transformers version bigger than 4
hololeac opened this issue · comments
Running demo_paraphraser.py with transformers version 4+ throws an error after a sentence is entered. The problem does not occur with transformers version 3.4.0.
Enter your sentence, q to quit: So can a magazine survive by downright thumbing its nose at major advertisers?
Traceback (most recent call last):
  File "drive/MyDrive/style-transfer-paraphrase-master/demo_paraphraser.py", line 26, in <module>
    greedy_decoding = paraphraser.generate(input_sentence)
  File "/content/drive/MyDrive/style-transfer-paraphrase-master/style_paraphrase/inference_utils.py", line 129, in generate
    top_p=top_p)[0][0]
  File "/content/drive/MyDrive/style-transfer-paraphrase-master/style_paraphrase/inference_utils.py", line 102, in generate_batch
    top_p=top_p
  File "/content/drive/MyDrive/style-transfer-paraphrase-master/style_paraphrase/utils.py", line 173, in generate
    interpolation=interpolation
  File "/content/drive/MyDrive/style-transfer-paraphrase-master/style_paraphrase/utils.py", line 272, in sample_sequence
    next_token_logits = logits[:, -1, :] / (temperature if temperature > 0 else 1.)
TypeError: string indices must be integers
I was running into the same issue. Fortunately, the fix for this error is simple:
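It seems that the way the logits and past key values are accessed changed between transformers versions. Previously, the model output was a tuple accessed by index. In version 4+, the model returns a dictionary-like object by default, and its fields are accessed by key.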
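To illustrate why this produces the TypeError above, here is a minimal sketch that uses a plain OrderedDict as a stand-in for the dictionary-like model output. It assumes the old code tuple-unpacked the result of the model call, which is not shown in this thread; the values in `pred` are made up for the example.

```python
from collections import OrderedDict

# Stand-in for the dictionary-like model output (values are made up).
pred = OrderedDict(logits=[[0.1, 0.9]], past_key_values=None)

# Tuple-unpacking a dictionary iterates over its keys, not its values,
# so `logits` ends up being the string "logits".
logits, past = pred

try:
    # Same kind of indexing as in sample_sequence, now applied to a string.
    logits[:, -1]
except TypeError as err:
    message = str(err)

print(message)
```

Accessing the fields by key, as in `pred['logits']`, avoids the problem.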
For transformers version 4+, you can just change the get_logits function in style_paraphrase/utils.py to the following:
def get_logits(model, iteration, generated, segments, style_content_vectors, past):
    if iteration == 0:
        if style_content_vectors is None:
            pred = model(
                input_ids=generated,
                token_type_ids=segments
            )
        else:
            pred = model(
                input_ids=generated,
                token_type_ids=segments,
                prefix_input_vectors=style_content_vectors
            )
    else:
        # used the cached representations to speed up decoding
        pred = model(
            input_ids=generated[:, -1:],
            token_type_ids=segments[:, -1:],
            past_key_values=past
        )
    logits = pred['logits']
    past = pred['past_key_values']
    return logits, past
Hi @philno, this solution works great! Do you want to submit a PR to fix it?
Yes, I can submit a PR. I just need to make sure that I don't break the code for transformers version < 4.
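One way to keep both versions working might look like the sketch below. It is not necessarily what the PR does; `unpack_model_output` is a hypothetical helper name, and it assumes that older versions return a plain tuple with the logits first and the cached key/values second.

```python
def unpack_model_output(pred):
    """Return (logits, past) from a tuple or a dictionary-like model output."""
    if isinstance(pred, (tuple, list)):
        # Assumed older behaviour: plain tuple, logits first, cache second.
        return pred[0], pred[1]
    # Newer behaviour: dictionary-like output accessed by key.
    return pred["logits"], pred["past_key_values"]


# Usage with stand-in values instead of real tensors:
old_style = ("logits-tensor", "past-tensor")
new_style = {"logits": "logits-tensor", "past_key_values": "past-tensor"}
print(unpack_model_output(old_style) == unpack_model_output(new_style))
```

With a helper like this, get_logits could end with `return unpack_model_output(pred)` instead of the two key lookups.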
Fixed in #11.