Question regarding training speed
LeoXinhaoLee opened this issue · comments
Hi, thank you so much for releasing code for this inspiring work.
When I run distributed training on a TPU v3-128 pod for ImageNet classification with a global batch size of 1024, I notice that steps_per_sec
is about 2.92, which is much lower than the reported 4.85.
To run my code, I simply set IS_LOCAL=False, local_batch_size = 8, and num_devices = 128,
with no further modification, which I believe corresponds to the Perceiver I/O 2D FF configuration. The reported speed is 4.85 steps per second (Perceiver I/O paper, Table 8, row 2).
Theoretically, since we run on a v3-128 pod, we should expect roughly a 2x speedup over the reported setup, i.e. about 9.7 steps per second, yet we only observe 2.92.
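For reference, here is a small sketch of the arithmetic behind these numbers. The variable names mirror the config flags from the issue; the 2x-scaling assumption is mine (it presumes the reported figure was measured on half as many devices and that throughput scales linearly):

```python
# Batch-size / throughput arithmetic for the run described above.
local_batch_size = 8
num_devices = 128  # TPU v3-128 pod

# Global batch size is per-device batch times device count.
global_batch_size = local_batch_size * num_devices
print(global_batch_size)  # 1024, matching the run above

reported_steps_per_sec = 4.85  # Perceiver I/O paper, Table 8, row 2
# Assumption: linear scaling with a doubling of devices.
expected_steps_per_sec = reported_steps_per_sec * 2
observed_steps_per_sec = 2.92
print(expected_steps_per_sec, observed_steps_per_sec)  # 9.7 expected vs. 2.92 measured
```

This makes the gap explicit: the observed throughput is roughly 3.3x below the linearly scaled expectation.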
Thank you very much for your time and help!