ematvey / tensorflow-seq2seq-tutorials

Dynamic seq2seq in TensorFlow, step by step



Mixing teacher forcing with "feed previous"

suriyadeepan opened this issue

As you mentioned at the start of the second tutorial, it is a good idea to mix teacher forcing with the "feed previous" technique while decoding. I thought I'd share some ideas on how to do that.

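```python
import tensorflow as tf
prob = 0.5  # probability of feeding previous predictions; could be a tf.placeholder or tf.constant
r = tf.random_uniform(shape=[], minval=0.0, maxval=1.0, dtype=tf.float32)  # uniform sample in [0, 1)
feed_previous = r < prob  # scalar bool tensor: True with probability `prob`
```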
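The choice of distribution matters here. The quick pure-Python check below uses the standard `random` module as a stand-in for the TensorFlow sampling ops; the helper names are illustrative, not from the tutorial. It compares drawing `r` from a normal distribution centered at `prob` and testing `r > prob` with drawing `r` uniformly from [0, 1) and testing `r < prob`:

```python
import random

def feed_previous_rate(rule, prob, trials=100_000, seed=0):
    """Fraction of trials on which rule(rng, prob) returns True."""
    rng = random.Random(seed)
    return sum(rule(rng, prob) for _ in range(trials)) / trials

def normal_rule(rng, prob):
    # r ~ Normal(mean=prob, stddev=0.5), then r > prob
    return rng.gauss(prob, 0.5) > prob

def uniform_rule(rng, prob):
    # r ~ Uniform[0, 1), then r < prob
    return rng.random() < prob

for p in (0.1, 0.5, 0.9):
    print(p,
          round(feed_previous_rate(normal_rule, p), 2),
          round(feed_previous_rate(uniform_rule, p), 2))
```

Because a normal distribution is symmetric about its mean, `r > prob` is a fair coin flip for every `prob`. The uniform rule is what turns `prob` into an actual feed-previous probability.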

In the loop_fn_transition function, you could add an outer condition like this.

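```python
# feed_previous is a scalar bool tensor, so branch with tf.cond, not a Python `if`.
# Every branch must be a callable that returns the next input tensor.
input = tf.cond(
    finished,
    padded_next_input,
    lambda: tf.cond(feed_previous, search_for_next_input, fetch_next_decoder_target))
```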
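To make the per-step logic concrete, here is a minimal pure-Python sketch of which token a single sequence receives at each decoding step under the two modes. It assumes the tutorial's convention that step 0 is fed EOS, so the teacher-forced input at step t is the ground-truth target from step t - 1. PAD = 0 and EOS = 1 follow the tutorials' toy vocabulary. `decoder_inputs` and the `predict` callback are hypothetical names standing in for the real graph:

```python
from typing import Callable, List

PAD, EOS = 0, 1  # token ids assumed from the tutorials' toy vocabulary

def decoder_inputs(
    targets: List[int],
    length: int,
    predict: Callable[[int], int],
    feed_previous: bool,
) -> List[int]:
    """Token fed to the decoder at each step of one sequence.

    `predict` is a hypothetical stand-in for "run the decoder one step and
    take the argmax"; here it only sees the previous input token.
    """
    inputs = [EOS]  # step 0 is always fed EOS
    for t in range(1, len(targets)):
        if t >= length:
            inputs.append(PAD)  # sequence finished: feed padding
        elif feed_previous:
            inputs.append(predict(inputs[-1]))  # feed previous: model's own guess
        else:
            inputs.append(targets[t - 1])  # teacher forcing: ground truth from step t - 1
    return inputs

targets = [5, 6, 7, EOS, PAD]
print(decoder_inputs(targets, 4, lambda tok: tok + 1, feed_previous=False))  # [1, 5, 6, 7, 0]
print(decoder_inputs(targets, 4, lambda tok: tok + 1, feed_previous=True))   # [1, 2, 3, 4, 0]
```

Here the mode is fixed for the whole sequence. This matches a single scalar `feed_previous` that is sampled once per `session.run`. Drawing a fresh flip at every step would instead mix the two modes within one sequence.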

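The fetch_next_decoder_target function should return the embedded ground-truth token for the current step, i.e. tf.nn.embedding_lookup(embeddings, decoder_targets[time - 1]). raw_rnn calls loop_fn with the index of the step about to run, and step 0 is fed the EOS token, so the ground-truth input for step `time` is the target from step `time - 1`. Indexing by time this way requires decoder_targets in "time major" format ([max_time, batch_size]), so transpose it first if it is batch-major.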
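The time-major layout and the indexing can be checked on plain Python lists. In the sketch below, `to_time_major` is a hypothetical helper that plays the role of tf.transpose on a 2-D [batch_size, max_time] tensor. It assumes, as in the tutorial's raw_rnn decoder, that step 0 is fed EOS, so step `time` consumes the targets from row `time - 1`:

```python
from typing import List

def to_time_major(batch_major: List[List[int]]) -> List[List[int]]:
    """[batch_size][max_time] -> [max_time][batch_size], like tf.transpose on a 2-D tensor."""
    return [list(step) for step in zip(*batch_major)]

# Two padded target sequences, batch-major: one row per example.
decoder_targets_bm = [
    [5, 6, 1, 0],
    [7, 8, 9, 1],
]
decoder_targets = to_time_major(decoder_targets_bm)  # one row per time step

# Row i holds the batch of targets for step i, so the teacher-forced
# inputs for step `time` come from row time - 1.
time = 2
print(decoder_targets[time - 1])  # [6, 8]
```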

Hope this helps. I will try this and open a pull request if I find time.

How might I go about mixing teacher forcing with "feed previous" using the code from tutorial #3?