GPT-2 torch argmax issues
mlpro1122 opened this issue
```python
import torch
from transformers import GPT2Tokenizer, GPT2LMHeadModel

tokenizer_custom = GPT2Tokenizer.from_pretrained("gpt2")
model_custom = GPT2LMHeadModel.from_pretrained('gpt2')
generated_custom = tokenizer_custom.encode("The Manhattan bridge")
context_custom = torch.tensor([generated_custom])
past_custom = None
for j in range(100):
    print(j)
    output_custom, past_custom = model_custom(context_custom, past=past_custom)
    token_custom = torch.argmax(output_custom[..., -1, :])
    generated_custom += [token_custom.tolist()]
    context_custom = token_custom.unsqueeze(0)
sequence_custom = tokenizer_custom.decode(generated_custom)
print(sequence_custom)
```
Please help me correct this.
Please use this code for text generation. Make sure you have `transformers` and `torch` installed:

```shell
pip install transformers torch
```
```python
import torch
from transformers import GPT2Tokenizer, GPT2LMHeadModel

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")

# Encode the prompt as a PyTorch tensor of shape (1, seq_len)
generated = tokenizer.encode("The Manhattan bridge", return_tensors='pt')

output_sequences = model.generate(
    input_ids=generated,
    max_length=150,              # total length, prompt included
    num_return_sequences=1,
    no_repeat_ngram_size=2,      # never repeat the same 2-gram
    do_sample=True,              # needed for temperature/top_k/top_p to take effect
    temperature=0.7,
    top_k=50,
    top_p=0.95,
    pad_token_id=tokenizer.eos_token_id,  # GPT-2 has no dedicated pad token
)

text = tokenizer.decode(output_sequences[0], skip_special_tokens=True)
print(text)
```
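To see roughly what `temperature`, `top_k`, and `top_p` do to the model's scores, here is a minimal pure-Python sketch of one sampling step over a toy logits list. It is an illustration under simplifying assumptions (a single logits vector, filters applied in the order temperature, top-k, top-p), not the transformers implementation; `sample_next_token` and the example logits are hypothetical.

```python
import math
import random
from typing import List, Optional


def sample_next_token(logits: List[float], temperature: float = 1.0,
                      top_k: int = 0, top_p: float = 1.0,
                      rng: Optional[random.Random] = None) -> int:
    """Illustrative temperature / top-k / top-p sampling over one logits vector."""
    rng = rng or random.Random()
    # 1. Temperature: values < 1 sharpen the distribution, values > 1 flatten it.
    scaled = [x / temperature for x in logits]
    # 2. Softmax, subtracting the max for numerical stability.
    m = max(scaled)
    exps = [math.exp(x - m) for x in scaled]
    total = sum(exps)
    probs = [e / total for e in exps]
    # Rank token ids from most to least likely.
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    # 3. Top-k: keep only the k most likely tokens (0 disables the filter).
    if top_k > 0:
        order = order[:top_k]
    # 4. Top-p (nucleus): keep the smallest prefix of the survivors whose
    #    renormalised cumulative probability reaches top_p.
    mass = sum(probs[i] for i in order)
    kept, cumulative = [], 0.0
    for i in order:
        kept.append(i)
        cumulative += probs[i] / mass
        if cumulative >= top_p:
            break
    # 5. Draw one token from the survivors, weighted by probability.
    weights = [probs[i] for i in kept]
    return rng.choices(kept, weights=weights, k=1)[0]


# Usage: a seeded RNG makes the draw reproducible.
rng = random.Random(0)
print(sample_next_token([2.0, 1.0, 0.5, -1.0], temperature=0.7, top_k=3, top_p=0.9, rng=rng))
```

Lower `temperature` and smaller `top_k`/`top_p` concentrate the choice on the most likely tokens; with `top_k=1` the step reduces to greedy argmax.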
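Note: `temperature`, `top_k`, and `top_p` only take effect when sampling is enabled with `do_sample=True`. You can tune these parameters, especially `temperature`, to get better results.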
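The loop in the question is greedy decoding with a key/value cache: the prompt is fed once, then only the newest token is fed together with the cache returned by the previous call. Below is a minimal, library-free sketch of that structure; `greedy_decode`, `toy_step`, and the step-function interface are hypothetical stand-ins for the model. With GPT-2 in recent transformers releases, one step would roughly be `outputs = model(context, past_key_values=past, use_cache=True)`, reading `outputs.logits[0, -1]` and `outputs.past_key_values`: the keyword is `past_key_values` rather than `past`, the call returns an output object rather than a plain tuple, and since the model expects `(batch, seq_len)` input, `token.view(1, 1)` is a safe shape for the next context.

```python
from typing import Callable, List, Optional, Sequence, Tuple

# Hypothetical model interface: (new_tokens, cache) -> (last-position logits, new cache)
StepFn = Callable[[List[int], Optional[list]], Tuple[Sequence[float], list]]


def greedy_decode(step: StepFn, prompt: List[int], max_new_tokens: int,
                  eos_id: Optional[int] = None) -> List[int]:
    """Greedy decoding that reuses a cache between steps."""
    tokens = list(prompt)
    context = list(prompt)  # the first call sees the whole prompt
    past = None             # no cache yet
    for _ in range(max_new_tokens):
        logits, past = step(context, past)
        # argmax over the vocabulary at the last position
        next_id = max(range(len(logits)), key=logits.__getitem__)
        tokens.append(next_id)
        if next_id == eos_id:
            break
        context = [next_id]  # later calls only need the newest token
    return tokens


def toy_step(context: List[int], past: Optional[list]) -> Tuple[List[float], list]:
    """Hypothetical 5-token 'model': always prefers (last token + 1) % 5."""
    cache = (past or []) + context  # stand-in for the key/value cache
    logits = [0.0] * 5
    logits[(cache[-1] + 1) % 5] = 1.0
    return logits, cache


print(greedy_decode(toy_step, [0], 6))  # [0, 1, 2, 3, 4, 0, 1]
```

Greedy decoding always takes the single most likely token, so with a real model it tends to produce repetitive text; sampling and `no_repeat_ngram_size` are common ways to counter that.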
Thanks for the prompt reply and the help. Let me check.
It's working fine, thanks for the help. The results are not very promising, though.