Likely bug: log_prob is not affected by sampling_temperature
YuchenLi01 opened this issue
Context

In language model generation, the hyperparameter `sampling_temperature` adjusts the probability distribution over the next token. A smaller `sampling_temperature` sharpens the distribution, whereas a larger `sampling_temperature` flattens it toward a uniform distribution. A brief interactive blog post on `sampling_temperature` is at https://lukesalamone.github.io/posts/what-is-temperature/.
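To make the effect concrete, temperature scaling divides the logits by the temperature before applying the softmax. A minimal standalone sketch in plain Python (not CTranslate2 code):

```python
import math

def softmax_with_temperature(logits, temperature):
    """Divide logits by the temperature, then apply a numerically stable softmax."""
    scaled = [x / temperature for x in logits]
    m = max(scaled)  # subtract the max to avoid overflow in exp()
    exps = [math.exp(x - m) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]

logits = [2.0, 1.0, 0.1]
print(softmax_with_temperature(logits, 0.1))   # sharp, close to one-hot
print(softmax_with_temperature(logits, 1.0))   # standard softmax
print(softmax_with_temperature(logits, 10.0))  # close to uniform
```

Given this definition, one would expect the log probability of a token to change whenever the temperature changes.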
Bug

Consequently, it is natural to expect `sampling_temperature` to make a difference in the model-predicted `log_prob` of the next token. However, it seems that in the CTranslate2 implementation, varying `sampling_temperature` does not change the returned `log_prob`.
How to reproduce

Assuming `generator` and `tokenizer` are properly loaded, run the following (basically: starting from the same prompt, sample only one token, with two different temperatures):
```python
prompt = "Test:"
prompt_tokens = tokenizer.convert_ids_to_tokens(tokenizer.encode(prompt, truncation=False))
print('prompt:', prompt_tokens)

print('sampling_temperature = 0.8')
step_results = generator.generate_tokens(
    prompt_tokens,
    sampling_temperature=0.8,
    sampling_topk=1,
    max_length=1,
    return_log_prob=True,
)
for step_result in step_results:
    print('generated token', step_result.token)
    print('log_prob', step_result.log_prob)

print('sampling_temperature = 0.1')
step_results = generator.generate_tokens(
    prompt_tokens,
    sampling_temperature=0.1,
    sampling_topk=1,
    max_length=1,
    return_log_prob=True,
)
for step_result in step_results:
    print('generated token', step_result.token)
    print('log_prob', step_result.log_prob)
```
Output I got:

```
prompt: ['<s>', '▁Test', ':']
sampling_temperature = 0.8
generated token ▁
log_prob -2.76171875
sampling_temperature = 0.1
generated token ▁
log_prob -2.76171875
```
What's wrong: although the temperatures are different, the returned `log_prob` did not change.
You are using `sampling_topk = 1`; in this case the best (greedy) sampler is used, and `sampling_temperature` is not applied (the random sampler, used when `sampling_topk > 1`, is the one affected by `sampling_temperature`). Try to increase the `sampling_topk` value.
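The distinction can be illustrated outside CTranslate2. This is a hypothetical sketch of the two sampler behaviors, not the library's implementation: greedy selection takes the argmax, which is invariant to temperature because dividing all logits by the same positive constant preserves their ordering, whereas sampling from the tempered distribution is temperature-sensitive.

```python
import math
import random

def tempered_probs(logits, temperature):
    """Softmax over logits scaled by 1/temperature."""
    scaled = [x / temperature for x in logits]
    m = max(scaled)
    exps = [math.exp(x - m) for x in scaled]
    z = sum(exps)
    return [e / z for e in exps]

logits = [2.0, 1.5, 0.5]

# Greedy ("best") selection: the argmax index is the same at every
# temperature, since scaling does not change the ordering of the logits.
for t in (0.1, 0.8, 10.0):
    probs = tempered_probs(logits, t)
    print(t, probs.index(max(probs)))  # always index 0

# Random sampling: temperature changes how often non-argmax tokens
# are drawn from the distribution.
random.seed(0)

def sample(probs):
    """Draw one index from a categorical distribution via its CDF."""
    r = random.random()
    acc = 0.0
    for i, p in enumerate(probs):
        acc += p
        if r < acc:
            return i
    return len(probs) - 1

for t in (0.1, 10.0):
    probs = tempered_probs(logits, t)
    draws = [sample(probs) for _ in range(1000)]
    print(t, 'fraction of argmax picks:', draws.count(0) / 1000)
```

At a low temperature almost every draw is the argmax token; at a high temperature the draws spread toward uniform, which is why temperature matters only once a random sampler is in play.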