Can't disable sampling
joselpart opened this issue · comments
JP commented
Currently, the `generate()` method doesn't seem to allow disabling sampling. The `forward()` method in the `Sampler` class performs greedy search if the `temperatures` argument is `None`, but `GemmaForCausalLM`'s `generate()` method doesn't allow setting the `temperature` argument to `None` because of this line: https://github.com/google/gemma_pytorch/blob/main/gemma/model.py#L508. Also, setting the `temperature` to `0` fails with the following error: `RuntimeError: probability tensor contains either inf, nan or element < 0`.
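
To illustrate the behavior being requested, here is a minimal sketch in plain Python (standard library only, not the gemma_pytorch code). The names `softmax_with_temperature` and `pick_token` are hypothetical and exist only for this example. The sketch treats a temperature of `None` or `0` as "sampling disabled" and falls back to greedy argmax before any division happens. The `temperature=0` failure presumably arises because the logits are divided by the temperature before the softmax, which leaves non-finite values in the probability tensor; that is my assumption about the cause, not something confirmed from the source.

```python
import math
import random
from typing import Optional, Sequence


def softmax_with_temperature(logits: Sequence[float], temperature: float) -> list:
    """Turn logits into probabilities after scaling by 1 / temperature."""
    scaled = [x / temperature for x in logits]
    peak = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(s - peak) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def pick_token(
    logits: Sequence[float],
    temperature: Optional[float] = None,
    rng: Optional[random.Random] = None,
) -> int:
    """Hypothetical helper: greedy when temperature is None or 0, else sample."""
    if temperature is None or temperature == 0:
        # Sampling disabled: return the index of the largest logit.
        # This branch runs before any division by the temperature.
        return max(range(len(logits)), key=lambda i: logits[i])
    if temperature < 0:
        raise ValueError("temperature must be non-negative")
    probs = softmax_with_temperature(logits, temperature)
    rng = rng or random.Random()
    return rng.choices(range(len(logits)), weights=probs, k=1)[0]


if __name__ == "__main__":
    example_logits = [0.1, 2.0, -1.0]
    print(pick_token(example_logits, temperature=None))  # greedy
    print(pick_token(example_logits, temperature=0))     # greedy as well
    print(pick_token(example_logits, temperature=0.8, rng=random.Random(0)))
```

With a guard like this, passing `temperature=None` or `temperature=0` would both select the highest-scoring token deterministically, which is the behavior I would expect from `generate()`.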