google / gemma_pytorch

The official PyTorch implementation of Google's Gemma models

Home Page: https://ai.google.dev/gemma

Can't disable sampling

joselpart opened this issue

Currently, the generate() method of GemmaForCausalLM doesn't seem to allow disabling sampling. The forward() method of the Sampler class performs greedy search when its temperatures argument is None, but generate() gives no way to pass None for the temperature because of this line: https://github.com/google/gemma_pytorch/blob/main/gemma/model.py#L508. Setting the temperature to 0 doesn't work either; it fails with the following error:

RuntimeError: probability tensor contains either inf, nan or element < 0
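
To make the request concrete, here is a minimal sketch of the behavior I'd expect, written in plain Python rather than PyTorch so it runs on its own. The names pick_next_token and softmax are hypothetical and are not part of gemma_pytorch. It assumes that a temperature of None or 0 should mean greedy decoding. The temperature-0 error is consistent with the logits being divided by the temperature before sampling, but that is my guess at the cause, not something I have confirmed in the code.

```python
import math
import random
from typing import List, Optional, Sequence


def softmax(logits: Sequence[float]) -> List[float]:
    """Numerically stable softmax over a list of logits."""
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def pick_next_token(
    logits: Sequence[float],
    temperature: Optional[float],
    rng: Optional[random.Random] = None,
) -> int:
    """Hypothetical helper: return the index of the next token.

    A temperature of None or 0 disables sampling and returns the argmax
    (greedy search). Any positive temperature scales the logits and
    samples from the resulting distribution.
    """
    if temperature is None or temperature == 0:
        # Greedy path: never divide by the temperature here.
        return max(range(len(logits)), key=lambda i: logits[i])
    if temperature < 0:
        raise ValueError("temperature must be None or non-negative")

    probs = softmax([x / temperature for x in logits])
    rng = rng or random.Random()
    return rng.choices(range(len(logits)), weights=probs, k=1)[0]


if __name__ == "__main__":
    example_logits = [0.1, 2.5, -1.0]
    print(pick_next_token(example_logits, None))  # greedy
    print(pick_next_token(example_logits, 0))     # also greedy
    print(pick_next_token(example_logits, 1.0, random.Random(0)))  # sampled
```

Treating 0 the same as None keeps the greedy path away from the division entirely, so the probabilities passed to the sampler can never contain inf or nan.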