about "JIT multiple training steps together"
ShiZiqiang opened this issue
Hello, Dr. Song
Thank you for sharing this excellent work.
I saw that a parameter `n_jitted_steps` is used during training, and the code comment says: "JIT multiple training steps together for faster training." Can you explain why and how this "JIT multiple training steps together" is done? Does `n_jitted_steps` affect performance? That is, if I don't use this "JIT multiple training steps together", will the results be the same?
Thank you in advance.
`n_jitted_steps` doesn't affect sample quality or likelihoods. No matter what `n_jitted_steps` you set, you are running exactly the same training procedure. Specifically, you are jit-compiling multiple training steps so that they execute together on GPUs/TPUs, and the number of training steps to jit together is given by `n_jitted_steps`. A larger `n_jitted_steps` can make training faster at the cost of more memory usage.
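To make the idea concrete, here is a minimal sketch of jitting multiple training steps together in JAX. It is not the code from this repository: the toy `train_step`, the quadratic loss, and the learning rate are all hypothetical, and the fusing is done with `jax.lax.scan`, which compiles the whole loop of steps into one XLA program so the device does not return to Python between steps.

```python
import jax
import jax.numpy as jnp

def train_step(params, batch):
    """One plain training step (toy example): SGD on a quadratic loss."""
    grad = jax.grad(lambda p: jnp.mean((p * batch - 1.0) ** 2))(params)
    return params - 0.1 * grad

def make_multi_step(n_jitted_steps):
    """Return a jitted function that runs n_jitted_steps steps in one call."""
    @jax.jit
    def multi_step(params, batches):
        # batches has shape (n_jitted_steps, batch_size): one batch per step.
        # lax.scan threads params through the steps inside a single
        # compiled program, instead of dispatching one kernel launch
        # round-trip per step from Python.
        def body(p, batch):
            return train_step(p, batch), None
        params, _ = jax.lax.scan(body, params, batches)
        return params
    return multi_step

if __name__ == "__main__":
    params = jnp.array(0.0)
    batches = jnp.ones((4, 8))  # 4 steps fused together, batch size 8

    # Same training procedure either way: 4 single steps from Python,
    # or one fused call that runs all 4 steps on device.
    p_single = params
    for b in batches:
        p_single = train_step(p_single, b)
    p_fused = make_multi_step(4)(params, batches)

    print(bool(jnp.allclose(p_single, p_fused)))  # True
```

The fused version produces the same parameters as running the steps one by one; the only trade-off is that the compiled program holds `n_jitted_steps` batches at once, which is where the extra memory usage comes from.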
Hi, Dr. Song, thank you so much for the clear explanation. Totally understood.