lstm sampling crashes for small temperatures
genekogan opened this issue · comments
just made some updates to the rnn guide, so pull first.
first, a quick fix:
if seed is None and len(seed) < max_len:
    raise Exception('Seed text must be at least {} chars long'.format(max_len))
should actually be the following if i'm not mistaken... (i've changed it)
if seed is not None and len(seed) < max_len:
    raise Exception('Seed text must be at least {} chars long'.format(max_len))
then the rest of the guide runs fine, but the sampling function crashes for me when temperature < 1.0, with the following trace. i think the probabilities are somehow not summing to 1? maybe an underflow issue here.
ValueError Traceback (most recent call last)
in ()
8 for temp in [1.0]: #[0.2, 0.5, 1., 1.2]:
9 print('\n\ttemperature:', temp)
---> 10 print(generate(temperature=temp))
in generate(temperature, seed, predicate)
22
23 # sample the character to use based on the predicted probabilities
---> 24 next_idx = sample(probs, temperature)
25 next_char = labels_char[next_idx]
26
in sample(probs, temperature)
33 a = np.log(probs)/temperature
34 a = np.exp(a)/np.sum(np.exp(a))
---> 35 return np.argmax(np.random.multinomial(1, a, 1))
mtrand.pyx in mtrand.RandomState.multinomial (numpy/random/mtrand/mtrand.c:32793)()
ValueError: sum(pvals[:-1]) > 1.0
interesting...you're using a temperature of 1.0 there? can you print out the probabilities too? i'll take a closer look.
oh oops, i switched it to just [1.0] but it was more problematic with smaller values e.g. 0.2.
it looks like the probabilities sometimes sum to just above 1.0...I think I have a workaround, testing it now.
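For anyone hitting the same error, here is a minimal sketch of one possible workaround. It is not necessarily the fix that went into the guide. It assumes the cause is float32 network outputs that sum to slightly more than 1.0 once they are cast up to float64, so it does the temperature scaling in float64 and renormalizes there. The `rng` parameter and the `1e-12` clipping epsilon are my own additions for illustration.

```python
import numpy as np


def sample(probs, temperature=1.0, rng=None):
    """Draw one index from `probs` after temperature scaling."""
    # `rng` is a hypothetical extra argument so that results can be reproduced.
    rng = np.random.default_rng() if rng is None else rng

    # Work in float64 so the renormalized vector sums to 1 within
    # the tolerance the NumPy samplers check for.
    probs = np.asarray(probs, dtype=np.float64)

    # Clip exact zeros so the log stays finite (epsilon chosen arbitrarily).
    logits = np.log(np.clip(probs, 1e-12, None)) / temperature

    # Subtract the max so the largest term is exp(0) = 1 and the
    # sum below cannot underflow to zero at small temperatures.
    logits -= logits.max()
    scaled = np.exp(logits)
    scaled /= scaled.sum()

    return int(rng.choice(len(scaled), p=scaled))
```

Calling `sample(probs, 0.2)` should then behave like the original function, with lower temperatures concentrating the draws on the most likely character.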