mshumer / gpt-llm-trainer



GPT-2 torch argmax issues

mlpro1122 opened this issue

```python
import torch
from transformers import GPT2Tokenizer, GPT2LMHeadModel

tokenizer_custom = GPT2Tokenizer.from_pretrained("gpt2")
model_custom = GPT2LMHeadModel.from_pretrained('gpt2')

generated_custom = tokenizer_custom.encode("The Manhattan bridge")
context_custom = torch.tensor([generated_custom])
past_custom = None

for j in range(100):
    print(j)
    output_custom, past_custom = model_custom(context_custom, past=past_custom)
    token_custom = torch.argmax(output_custom[..., -1, :])

    generated_custom += [token_custom.tolist()]
    context_custom = token_custom.unsqueeze(0)

sequence_custom = tokenizer_custom.decode(generated_custom)

print(sequence_custom)
```
Please help me correct this.
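For reference, here is a minimal sketch of the same greedy loop written against a recent `transformers` release (an assumption, since the issue doesn't say which version was used). In current versions the forward pass returns a `ModelOutput` object, which should not be unpacked directly into `(logits, past)`, and the cache keyword is `past_key_values` rather than `past`. `greedy_generate` is a hypothetical helper name, not part of `transformers`:

```python
import torch
from transformers import GPT2LMHeadModel


def greedy_generate(model: GPT2LMHeadModel, input_ids: torch.Tensor, steps: int = 100) -> list:
    """Hypothetical helper: greedy decoding with a key/value cache.

    `input_ids` has shape (1, prompt_len); returns the prompt ids followed by
    `steps` generated token ids.
    """
    model.eval()  # disable dropout (note: this changes the caller's model mode)
    generated = input_ids[0].tolist()
    context = input_ids
    past = None
    with torch.no_grad():
        for _ in range(steps):
            # Read fields from the ModelOutput instead of tuple-unpacking it;
            # the cache keyword is `past_key_values` in recent releases
            outputs = model(input_ids=context, past_key_values=past, use_cache=True)
            past = outputs.past_key_values
            # Highest-scoring token at the last position of the single sequence
            next_token = torch.argmax(outputs.logits[0, -1, :])
            generated.append(next_token.item())
            # Feed only the new token next step, shaped (batch=1, seq_len=1)
            context = next_token.view(1, 1)
    return generated
```

With the tokenizer and model from the question, `tokenizer_custom.decode(greedy_generate(model_custom, context_custom))` would decode the prompt plus 100 greedy tokens. Plain greedy decoding often gets stuck repeating phrases, which is one reason to prefer `model.generate` with `no_repeat_ngram_size` and sampling.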

Please use this code for text generation. Make sure you have `transformers` and `torch` installed:

```shell
pip install transformers torch
```

```python
import torch
from transformers import GPT2Tokenizer, GPT2LMHeadModel

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")
inputs = tokenizer("The Manhattan bridge", return_tensors="pt")
output_sequences = model.generate(
    input_ids=inputs["input_ids"],
    attention_mask=inputs["attention_mask"],
    max_length=150,  # total length in tokens, including the prompt
    num_return_sequences=1,
    no_repeat_ngram_size=2,  # never repeat the same 2-gram
    do_sample=True,  # needed for temperature/top_k/top_p to have any effect
    temperature=0.7,
    top_k=50,
    top_p=0.95,
    pad_token_id=tokenizer.eos_token_id,  # GPT-2 has no padding token; reuse EOS
)

text = tokenizer.decode(output_sequences[0], skip_special_tokens=True)
print(text)
```
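The `top_k` and `top_p` arguments restrict which tokens can be sampled at each step. As an illustration only, here is a sketch of that filtering on a single 1-D logits tensor, written from the usual definitions of top-k and nucleus (top-p) sampling rather than copied from `transformers`, which implements them in its own logits-processing classes. `filter_top_k_top_p` and `sample_next_token` are hypothetical names:

```python
import torch


def filter_top_k_top_p(logits: torch.Tensor, top_k: int = 50, top_p: float = 0.95) -> torch.Tensor:
    """Hypothetical helper: copy of 1-D `logits` with excluded tokens set to -inf."""
    logits = logits.clone()
    # Top-k: drop every token scoring below the k-th highest logit (ties are kept)
    if top_k > 0:
        k = min(top_k, logits.size(-1))
        kth_value = torch.topk(logits, k).values[-1]
        logits[logits < kth_value] = float("-inf")
    # Top-p: keep the smallest high-probability set whose total mass reaches top_p
    sorted_logits, sorted_idx = torch.sort(logits, descending=True)
    sorted_probs = torch.softmax(sorted_logits, dim=-1)
    # Probability mass of all tokens ranked strictly above each token
    mass_above = torch.cumsum(sorted_probs, dim=-1) - sorted_probs
    # Drop a token once the tokens above it already reach top_p (top token always survives)
    logits[sorted_idx[mass_above >= top_p]] = float("-inf")
    return logits


def sample_next_token(logits: torch.Tensor, top_k: int = 50, top_p: float = 0.95) -> int:
    """Hypothetical helper: sample one token id from the filtered distribution."""
    probs = torch.softmax(filter_top_k_top_p(logits, top_k, top_p), dim=-1)
    # Draw one id in proportion to the remaining probabilities
    return int(torch.multinomial(probs, num_samples=1).item())


# Made-up logits for a five-token vocabulary
example_logits = torch.tensor([4.0, 3.0, 2.0, 1.0, 0.0])
print(filter_top_k_top_p(example_logits, top_k=3, top_p=0.7))
```

Filtering happens on logits so the surviving tokens are simply renormalized by the final softmax.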

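Note: you can tune these parameters to get better output, especially `temperature`. Keep in mind that `temperature`, `top_k` and `top_p` only have an effect when sampling is enabled (`do_sample=True`).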
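`temperature` works by dividing the logits by its value before the softmax, so values below 1 sharpen the distribution toward the highest-scoring token and values above 1 flatten it. A small sketch with made-up logits (`softmax_with_temperature` is a hypothetical helper, not a `transformers` function):

```python
import torch


def softmax_with_temperature(logits: torch.Tensor, temperature: float) -> torch.Tensor:
    """Hypothetical helper: next-token probabilities after scaling logits by 1/temperature."""
    if temperature <= 0:
        raise ValueError("temperature must be positive")
    return torch.softmax(logits / temperature, dim=-1)


# Made-up logits for three candidate tokens
logits = torch.tensor([2.0, 1.0, 0.0])
for t in (0.5, 0.7, 1.0, 1.5):
    probs = softmax_with_temperature(logits, t)
    print(f"temperature={t}: {[round(p, 3) for p in probs.tolist()]}")
```

The ranking of tokens never changes with temperature; only how much probability the top tokens take from the rest.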

Thanks for the prompt reply and the help. Let me check.

It's working fine, thanks for the help. The results are not very promising, though.