Mixing teacher forcing with "feed previous"
suriyadeepan opened this issue
As you mentioned at the start of the 2nd tutorial, it is a good idea to mix teacher forcing with the "feed previous" technique while decoding. I thought I could share some ideas on how to do that.
prob = 0.5  # probability of feeding the previous output; set as placeholder or tf.constant
r = tf.random_uniform(shape=[], minval=0.0, maxval=1.0, dtype=tf.float32)  # uniform sample in [0, 1)
feed_previous = r < prob  # scalar bool tensor, True with probability prob
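To check that a sampled flag fires with the intended probability, here is a minimal plain-Python sketch (no TensorFlow; the helper name `sample_feed_previous` is made up for illustration). It compares a uniform draw in [0, 1) against `prob`. A draw from a normal distribution centred on `prob` would exceed `prob` about half the time whatever `prob` is, so a uniform draw is the safer choice when `prob` is meant to control the mix.

```python
import random

def sample_feed_previous(prob, rng):
    """Return True with probability `prob` (hypothetical helper)."""
    # rng.random() is uniform in [0, 1), so this is True a fraction `prob` of the time
    return rng.random() < prob

rng = random.Random(0)   # fixed seed so the check is repeatable
prob = 0.25
trials = 100000
hits = sum(sample_feed_previous(prob, rng) for _ in range(trials))
fraction = hits / trials  # should land close to prob
print(round(fraction, 2))
```

The same comparison carries over to a graph-mode tensor: draw one uniform scalar per step and compare it with `prob`.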
In the loop_fn_transition function, you could add a condition like this.
# feed_previous is a tensor, so branch with tf.cond rather than a Python "if"
def get_next_input():
    return tf.cond(feed_previous, search_for_next_input, fetch_next_decoder_target)
input = tf.cond(finished, padded_next_input, get_next_input)
The fetch_next_decoder_target function is supposed to fetch the next decoder target by indexing decoder_targets with time: decoder_targets[time]. Note that you need to transpose decoder_targets to "time major" format first.
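The indexing and the per-step choice can be sketched in plain Python (not TensorFlow graph code). The helpers `to_time_major` and `mixed_inputs` and the `previous_outputs` values are assumptions made up for illustration. In a real decoder the previous output depends on what was fed in at earlier steps, so a fixed list only stands in for it here.

```python
import random

def to_time_major(batch_major):
    """Transpose [batch][time] nested lists into [time][batch]."""
    return [list(step) for step in zip(*batch_major)]

def mixed_inputs(decoder_targets, previous_outputs, prob, rng):
    """Per step, pick the previous output with probability `prob`,
    otherwise the ground-truth target (teacher forcing)."""
    chosen = []
    for time in range(len(decoder_targets)):
        if rng.random() < prob:
            chosen.append(previous_outputs[time])  # "feed previous"
        else:
            chosen.append(decoder_targets[time])   # teacher forcing
    return chosen

batch_major_targets = [[1, 2, 3], [4, 5, 6]]          # 2 sequences, 3 steps each
decoder_targets = to_time_major(batch_major_targets)  # [[1, 4], [2, 5], [3, 6]]
previous_outputs = [[9, 9], [8, 8], [7, 7]]           # made-up predictions, time major

all_teacher = mixed_inputs(decoder_targets, previous_outputs, 0.0, random.Random(0))
all_previous = mixed_inputs(decoder_targets, previous_outputs, 1.0, random.Random(0))
mixed = mixed_inputs(decoder_targets, previous_outputs, 0.5, random.Random(0))
```

With `prob` at 0.0 every step uses the target, with 1.0 every step uses the previous output, and values in between mix the two step by step.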
Hope this helps. I will try this and add a pull request if I find time.
How might I go about mixing teacher forcing with feed previous using the code from tutorial #3?