Use no_sync when doing gradient accumulation
achalddave opened this issue
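By default, FSDP reduces gradients on every backward() call, which is slow in multi-node settings because every reduction crosses the slower inter-node links. When doing gradient accumulation, we should wrap every micro-batch except the last in fsdp.no_sync(), so that gradients are reduced only once, on the last backward call of each accumulation step.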
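A minimal sketch of the proposed loop, assuming `model` is the top-level FSDP (or DDP) wrapper and that `torch.distributed` has already been initialized. The helper name `accumulate_and_step` is hypothetical. Both FSDP and DDP expose a `no_sync()` context manager; the sketch falls back to a no-op for plain modules so it can also run on a single CPU process. It also assumes equal-sized micro-batches, so dividing each loss by the number of micro-batches yields the full-batch mean.

```python
import contextlib
from typing import Callable, List, Tuple

import torch
import torch.nn as nn


def accumulate_and_step(
    model: nn.Module,
    optimizer: torch.optim.Optimizer,
    micro_batches: List[Tuple[torch.Tensor, torch.Tensor]],
    loss_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
) -> float:
    """Hypothetical helper: one optimizer step over several micro-batches.

    Forward/backward for every micro-batch except the last runs inside
    model.no_sync(), so the wrapper skips the gradient reduction and just
    accumulates locally. The last backward runs normally and triggers a
    single reduction of the accumulated gradients.
    """
    optimizer.zero_grad(set_to_none=True)
    n = len(micro_batches)
    total_loss = 0.0
    for i, (inputs, targets) in enumerate(micro_batches):
        is_last = i == n - 1
        if is_last or not hasattr(model, "no_sync"):
            # Last micro-batch (or a plain module): let backward() sync normally.
            ctx = contextlib.nullcontext()
        else:
            # Skip the reduce for this backward; gradients accumulate locally.
            ctx = model.no_sync()
        with ctx:
            # Scale so the summed gradients equal the mean over all micro-batches
            # (assumes equal-sized micro-batches and a mean-reduced loss).
            loss = loss_fn(model(inputs), targets) / n
            loss.backward()
        total_loss += loss.item()
    optimizer.step()
    return total_loss
```

One trade-off to keep in mind: per the PyTorch FSDP documentation, gradients accumulated under `no_sync()` are kept unsharded until the final synchronizing backward, so this saves communication at the cost of higher peak memory.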