ematvey / tensorflow-seq2seq-tutorials

Dynamic seq2seq in TensorFlow, step by step

Tutorial 1: decoder is fed correct values during prediction

As far as I can tell, the decoder in tutorial 1 is always fed the true values, not just during training but also for inference.

decoder_logits, decoder_final_state = tf.nn.dynamic_rnn(
    decoder_cell, decoder_inputs_onehot,
    initial_state=encoder_final_state,
    dtype=tf.float32, time_major=True, scope="plain_decoder",
)

For training, it makes sense that the inputs are decoder_inputs_onehot. However, for inference, the inputs should be the decoder's prediction from the previous timestep.

Am I misunderstanding something, or is this a bug?

Actually, tutorial 1 only illustrates training. Inference with a tf.nn.dynamic_rnn decoder is tricky. All I can think of is defining an inference dynamic_rnn decoder that shares weights with the training decoder but takes as input a [1, batch_size] int32 tensor, i.e. a single timestep of a minibatch. We would then unroll the decoder manually in a Python for loop, calling session.run for every timestep.

This is totally impractical in TF. It might be a nice illustration, but I think it belongs more to the scope of tutorial 2 / raw_rnn. It might be worth drawing attention to it in the text, though. What do you think?
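For the record, here is a minimal sketch of what that manual unrolling could look like. It assumes a cell whose state is a single tensor (the tutorial's LSTMCell would need one placeholder per state component) and a cell whose output size equals vocab_size, so the dynamic_rnn outputs can be read directly as logits, as in the snippet quoted above. decoder_cell, decoder_hidden_units, vocab_size, encoder_final_state, batch_size and sess follow the notebook's notation; step_input, step_state, encoder_feed, EOS and max_decode_len are illustrative names, not part of the tutorial:

import numpy as np
import tensorflow as tf

# One-timestep inference decoder that reuses the training decoder's weights.
# step_input carries a single timestep of a minibatch: shape [1, batch_size].
step_input = tf.placeholder(tf.int32, shape=(1, None), name="step_input")
step_state = tf.placeholder(tf.float32, shape=(None, decoder_hidden_units),
                            name="step_state")
step_input_onehot = tf.one_hot(step_input, vocab_size, dtype=tf.float32)

with tf.variable_scope(tf.get_variable_scope(), reuse=True):
    step_logits, next_state = tf.nn.dynamic_rnn(
        decoder_cell, step_input_onehot,
        initial_state=step_state,
        dtype=tf.float32, time_major=True, scope="plain_decoder",
    )
step_prediction = tf.cast(tf.argmax(step_logits, axis=-1), tf.int32)

# Greedy decoding: one session.run call per generated timestep.
state = sess.run(encoder_final_state, feed_dict=encoder_feed)  # encode once
tokens = np.full((1, batch_size), EOS, dtype=np.int32)         # start with <EOS>
decoded = []
for _ in range(max_decode_len):
    tokens, state = sess.run(
        [step_prediction, next_state],
        feed_dict={step_input: tokens, step_state: state})
    decoded.append(tokens[0])

Every generated timestep costs a full session.run round trip, which is exactly why this is impractical beyond illustration.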

The two-graph approach is used in raindeer/seq2seq_experiments, though they don't use dynamic_rnn. I think you're right that it's closer to tutorial 2's scope, but I'm not sure it's worth putting a lot of energy into now that the seq2seq API is about to change (again!). However, it's probably worth noting in the text of tutorial 1 that inference is difficult; it might save someone a bit of time.

Any info on whether tf.nn.raw_rnn will change? I find the new TF seq2seq API rather cumbersome: it hides important details while not providing much simplification over tf.nn.raw_rnn. Perhaps using tf.nn.raw_rnn directly is the better way.
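For completeness, a rough sketch of a raw_rnn loop_fn that feeds the previous argmax prediction back in as the next decoder input, which is roughly the approach tutorial 2 takes. decoder_cell, encoder_final_state, embeddings, W, b, decoder_lengths, eos_step_embedded and pad_step_embedded are assumed to exist as in the notebooks; anything else is illustrative:

import tensorflow as tf

def loop_fn(time, previous_output, previous_state, previous_loop_state):
    if previous_state is None:
        # Initialization call (time == 0): start from the encoder state
        # and feed the embedded <EOS> token as the first decoder input.
        elements_finished = (0 >= decoder_lengths)  # all False for positive lengths
        return (elements_finished, eos_step_embedded, encoder_final_state,
                None,   # emit structure defaults to the raw cell output
                None)   # no extra loop state needed

    def next_input():
        # Project the previous cell output to logits, take the argmax,
        # and embed that token as the next input.
        logits = tf.add(tf.matmul(previous_output, W), b)
        prediction = tf.argmax(logits, axis=1)
        return tf.nn.embedding_lookup(embeddings, prediction)

    elements_finished = (time >= decoder_lengths)
    finished = tf.reduce_all(elements_finished)
    step_input = tf.cond(finished, lambda: pad_step_embedded, next_input)
    return (elements_finished, step_input, previous_state, previous_output, None)

outputs_ta, final_state, _ = tf.nn.raw_rnn(decoder_cell, loop_fn)
decoder_outputs = outputs_ta.stack()  # cell outputs; still need the W/b projection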

I added a bit that explains that inference is not possible with tf.nn.dynamic_rnn. Closing now.

Thank you! Unfortunately I don't know if raw_rnn is changing - I haven't heard anything about that happening, but I'm not very in the know.