OpenNMT / CTranslate2

Fast inference engine for Transformer models

Home Page: https://opennmt.net/CTranslate2

Likely bug: log_prob is not affected by sampling_temperature

YuchenLi01 opened this issue

Context
In language model generation, the hyperparameter sampling_temperature adjusts the probability distribution over the next token: a smaller sampling_temperature sharpens the distribution, whereas a larger one flattens it toward a uniform distribution. A brief interactive blog post on sampling_temperature is at https://lukesalamone.github.io/posts/what-is-temperature/.
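
For concreteness, temperature scaling divides the logits by sampling_temperature before the softmax. The short NumPy sketch below is illustrative only (it is not CTranslate2 code) and shows how a smaller temperature sharpens the distribution:

import numpy as np

def softmax_with_temperature(logits, temperature):
    # Scale the logits by 1 / temperature, then normalize.
    scaled = np.asarray(logits, dtype=np.float64) / temperature
    scaled -= scaled.max()  # subtract the max for numerical stability
    probs = np.exp(scaled)
    return probs / probs.sum()

logits = [2.0, 1.0, 0.5]
print(softmax_with_temperature(logits, 0.8))  # moderately peaked
print(softmax_with_temperature(logits, 0.1))  # nearly one-hot on the largest logit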

Bug
Consequently, it is natural to expect sampling_temperature to affect the model-predicted log_prob of the next token. However, in the CTranslate2 implementation, varying sampling_temperature does not appear to change the returned log_prob.

How to reproduce
Assume generator and tokenizer are properly loaded; one possible setup is sketched below for reference.
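
The model directory and tokenizer name in this sketch are placeholders, not taken from the original report:

import ctranslate2
import transformers

# Placeholder paths/names for illustration only.
generator = ctranslate2.Generator("ct2_model_dir", device="cpu")
tokenizer = transformers.AutoTokenizer.from_pretrained("tokenizer-name-or-path")

With generator and tokenizer in place, run the following (starting from the same prompt, sample only one token, once with each of two different temperatures):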

prompt = "Test:"
prompt_tokens = tokenizer.convert_ids_to_tokens(tokenizer.encode(prompt, truncation=False))
print('prompt:', prompt_tokens)

print('sampling_temperature = 0.8')

step_results = generator.generate_tokens(
    prompt_tokens,
    sampling_temperature=0.8,
    sampling_topk=1,
    max_length=1,
    return_log_prob=True,
)

for step_result in step_results:
    print('generated token', step_result.token)
    print('log_prob', step_result.log_prob)


print('sampling_temperature = 0.1')

step_results = generator.generate_tokens(
    prompt_tokens,
    sampling_temperature=0.1,
    sampling_topk=1,
    max_length=1,
    return_log_prob=True,
)

for step_result in step_results:
    print('generated token', step_result.token)
    print('log_prob', step_result.log_prob)

Output I got:

prompt: ['<s>', '▁Test', ':']
sampling_temperature = 0.8
generated token ▁
log_prob -2.76171875
sampling_temperature = 0.1
generated token ▁
log_prob -2.76171875

What's wrong: although the temperatures are different, the returned log_prob did not change.

You are using sampling_topk = 1. In this case, the best (greedy) sampler is used and sampling_temperature is not applied when picking the token (the random sampler, used when sampling_topk > 1, is the one affected by sampling_temperature). Try increasing sampling_topk.
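
For example, re-running the second call with a larger sampling_topk, following the suggestion above (this sketch is not taken from the issue), would look like:

step_results = generator.generate_tokens(
    prompt_tokens,
    sampling_temperature=0.1,
    sampling_topk=10,  # with sampling_topk > 1, the random sampler is used and the temperature is applied
    max_length=1,
    return_log_prob=True,
)

for step_result in step_results:
    print('generated token', step_result.token)
    print('log_prob', step_result.log_prob)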