luofuli / DualRL

A Dual Reinforcement Learning Framework for Unsupervised Text Style Transfer (IJCAI 2019)

Home Page: https://export.arxiv.org/pdf/1905.10060

Question about the pseudo-parallel data in the DualRL training stage

sunny371 opened this issue

Thanks for your great work!

I have a question about the pseudo-parallel data used in the annealing pseudo teacher-forcing stage. From the paper, I understood that the pseudo-parallel data is generated on the fly using the latest model. However, the code for this stage in dual_training.py, shown below,

DualRL/dual_training.py

Lines 84 to 86 in 7983ec0

```python
paired_src_train_iterator = load_paired_dataset(args.tsf_train_data[B], args.train_data[B],
                                                src_vocab, tgt_vocab, batch_size=args.batch_size,
                                                min_seq_len=min_seq_len)
```

DualRL/dual_training.py

Lines 308 to 317 in 7983ec0

```python
if n_batch % gap == 0:
    data = sess.run(paired_train_data_next[A])  # get real data!!
    feed_dict = {
        nmts_train[A].input_ids: data["ids"],
        nmts_train[A].input_length: data["length"],
        nmts_train[A].target_ids_in: data["trans_ids_in"],
        nmts_train[A].target_ids_out: data["trans_ids_out"],
        nmts_train[A].target_length: data["trans_length"],
    }
    nmtA_pseudo_loss_, _ = sess.run([nmts_train[A].loss, nmts_train[A].train_op], feed_dict=feed_dict)
```

simply loads paired data from args.tsf_train_data, which stores the pseudo-parallel data generated by the template-based approach for pretraining, instead of generating it with the latest model.

Did I understand this correctly? Have I missed anything?
Thank you.
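
For clarity, here is a minimal, framework-free sketch of the two training-signal sources I am contrasting; `pair_batches`, `transfer`, and `train_step` are illustrative stand-ins, not functions from the DualRL repository:

```python
# Minimal sketch of the two ways the pseudo-parallel signal can be produced.
# All names here (pair_batches, transfer, train_step) are illustrative
# assumptions, not the repository's API.

def train_with_static_pairs(pair_batches, train_step):
    """(a) What the released snippet does: pseudo targets are read from a
    file generated offline by the template-based approach."""
    for src_ids, pseudo_tgt_ids in pair_batches:
        train_step(src_ids, pseudo_tgt_ids)

def train_on_the_fly(src_batches, transfer, train_step):
    """(b) What the paper describes: pseudo targets are decoded with the
    latest model weights at every step."""
    for src_ids in src_batches:
        pseudo_tgt_ids = transfer(src_ids)  # decode with the current model
        train_step(src_ids, pseudo_tgt_ids)
```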

Uh oh! I made a mistake here when I cleaned up the code. I deleted the code that generates the pseudo-data on the fly but kept the code that uses the original pseudo-parallel data.

This figure shows the original code (before cleaning):
[screenshot of the original dual_training.py code]
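
Since the screenshot does not survive in text form, here is a hedged reconstruction of what the deleted on-the-fly block could look like, written in the style of the released snippet above. `nmts_infer`, `train_data_next`, and the `mid` fetch names are assumptions, not the actual code:

```python
# Hedged reconstruction (NOT the author's exact code): decode the current
# batch with the latest model and teacher-force on that fresh output.
if n_batch % gap == 0:
    src = sess.run(train_data_next[A])  # assumed unpaired-data iterator
    mid = sess.run(nmts_infer[A].predictions,  # A -> B transfer with current weights
                   feed_dict={nmts_infer[A].input_ids: src["ids"],
                              nmts_infer[A].input_length: src["length"]})
    feed_dict = {
        nmts_train[A].input_ids: src["ids"],
        nmts_train[A].input_length: src["length"],
        nmts_train[A].target_ids_in: mid["ids_in"],    # cf. mid_ids_in_bs below
        nmts_train[A].target_ids_out: mid["ids_out"],
        nmts_train[A].target_length: mid["length"],
    }
    nmtA_pseudo_loss_, _ = sess.run([nmts_train[A].loss, nmts_train[A].train_op],
                                    feed_dict=feed_dict)
```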

Wait a minute, I will update the code! Thanks!

Thank you very much for your reply @luofuli.
But it seems there is still a problem. mid_ids_in_bs is the transferred output from source style A to target style B. Shouldn't it be used as the input_ids, with src["ids"] as the targets, to train nmts_train[B] instead of nmts_train[A], since the generated text should be placed on the source side?

Yes, you are right! I did not show you the whole code. The previous_b branch matches what you describe. I ran several ablation studies to check which variant works, and found that previous_b is the one that works!

[screenshot of the original code, including the previous_b variant]
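
For concreteness, a minimal sketch of the previous_b direction described above: the A -> B transfer output goes on the source side and the original sentence is the target, so the reverse model is trained on a (generated, real) pair. Apart from nmts_train, sess, and the placeholder names quoted earlier, the variables here (mid, src) are assumptions:

```python
# Sketch of the `previous_b` wiring discussed above (hedged, not verbatim):
# the generated text (mid_ids_in_bs) is the INPUT, the real sentence the TARGET.
feed_dict = {
    nmts_train[B].input_ids: mid["ids"],          # generated text on the source side
    nmts_train[B].input_length: mid["length"],
    nmts_train[B].target_ids_in: src["ids_in"],   # real sentence as the target
    nmts_train[B].target_ids_out: src["ids_out"],
    nmts_train[B].target_length: src["length"],
}
nmtB_pseudo_loss_, _ = sess.run([nmts_train[B].loss, nmts_train[B].train_op],
                                feed_dict=feed_dict)
```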

Got it! Thanks a lot! @luofuli

I have updated the code! If you don't have any further questions, you can close the issue. Thanks for your feedback!