google-deepmind / acme

A library of reinforcement learning components and agents

Question: Performance of multi-device prefetching

EdanToledo opened this issue

Hello,

I have a quick question about the implementation of the multi-device put and the prefetching iterator. From reading the code, it seems to do a lot of unnecessary work that can prevent multi-device learning from actually speeding things up. As I understand it, the iterator calls next once per device, so with 8 devices there are 8 sequential next calls. Wouldn't it be much faster to make a single call and split the result? In some experiments I have run, the multi-device put is far slower, and the profiler shows that the time is spent almost entirely in these synchronous next calls. Using prefetching and raising the prefetching thread count helps, but it is still slower than running the same code on a single device. Is there a specific reason you opted for sequential next calls instead of a single next call followed by splitting the data?
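To make the difference concrete, here is a minimal, standard-library-only sketch of the two strategies. It is not Acme's multi_device_put: CountingIterator, per_device_next and single_next_then_split are hypothetical names, and Python lists stand in for arrays. In a real JAX setup the resulting list of shards would then be handed to something like jax.device_put_sharded(shards, devices).

```python
import itertools
from typing import Iterator, List


class CountingIterator:
    """Hypothetical stand-in for a slow data source such as a Reverb-backed iterator.

    Each __next__ returns a batch (a list of consecutive ints) and increments
    `calls`, so we can count how many fetches each strategy performs.
    """

    def __init__(self, batch_size: int) -> None:
        self.batch_size = batch_size
        self.calls = 0
        self._counter = itertools.count()

    def __iter__(self) -> "CountingIterator":
        return self

    def __next__(self) -> List[int]:
        self.calls += 1
        return [next(self._counter) for _ in range(self.batch_size)]


def per_device_next(it: Iterator[List[int]], num_devices: int) -> List[List[int]]:
    """Strategy described in the question: one sequential next() per device."""
    return [next(it) for _ in range(num_devices)]


def single_next_then_split(it: Iterator[List[int]], num_devices: int) -> List[List[int]]:
    """Alternative: one next() for a num_devices-times larger batch, then split it."""
    batch = next(it)
    if len(batch) % num_devices != 0:
        raise ValueError("batch size must be divisible by num_devices")
    shard_size = len(batch) // num_devices
    return [batch[i * shard_size:(i + 1) * shard_size] for i in range(num_devices)]


num_devices, per_device_batch = 8, 4

source_a = CountingIterator(per_device_batch)
shards_a = per_device_next(source_a, num_devices)

source_b = CountingIterator(per_device_batch * num_devices)
shards_b = single_next_then_split(source_b, num_devices)

print(source_a.calls, source_b.calls)  # 8 1
print(shards_a == shards_b)  # True
```

Both strategies deliver identical shards here; the only difference is one fetch versus eight. Whether a single larger fetch is actually cheaper depends on the data source, so that should be confirmed with a profiler rather than assumed.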

Some of the agent implementations have an optional num_sgd_steps_per_step argument (e.g., https://github.com/deepmind/acme/blob/master/acme/agents/jax/sac/learning.py#L64), which lets you run multiple SGD steps per batch by splitting the batch. In that case, it might be possible to increase the batch size by a factor of 8 and set num_sgd_steps_per_step to 8 to compensate, which would mean a single call to the underlying iterator with a larger batch. Would that be useful?
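The num_sgd_steps_per_step idea can be sketched as: fetch one batch that is num_sgd_steps_per_step times larger, split it into equal minibatches, and thread the learner state through one update per minibatch. This is a hypothetical, dependency-free illustration (split_into_minibatches, multiple_sgd_steps and sgd_update are made-up names, and the "learner state" is a single float); it does not reproduce Acme's own code. In a JAX learner the loop would more naturally be a jax.lax.scan over the batch reshaped to [num_sgd_steps_per_step, batch_size, ...].

```python
from typing import Callable, List, Sequence


def split_into_minibatches(batch: Sequence[float], num_minibatches: int) -> List[Sequence[float]]:
    """Split one large batch into `num_minibatches` equal, contiguous pieces."""
    if len(batch) % num_minibatches != 0:
        raise ValueError("batch size must be divisible by num_minibatches")
    size = len(batch) // num_minibatches
    return [batch[i * size:(i + 1) * size] for i in range(num_minibatches)]


def multiple_sgd_steps(
    update_fn: Callable[[float, Sequence[float]], float],
    num_sgd_steps_per_step: int,
) -> Callable[[float, Sequence[float]], float]:
    """Wrap a single-minibatch update so it consumes one large batch per call."""

    def step(state: float, batch: Sequence[float]) -> float:
        for minibatch in split_into_minibatches(batch, num_sgd_steps_per_step):
            # Sequential: each update sees the state produced by the previous one.
            state = update_fn(state, minibatch)
        return state

    return step


def sgd_update(param: float, minibatch: Sequence[float], lr: float = 0.1) -> float:
    """One SGD step on the loss mean((param - x)**2); its gradient is 2 * (param - mean(x))."""
    mean = sum(minibatch) / len(minibatch)
    return param - lr * 2.0 * (param - mean)


num_sgd_steps_per_step, batch_size = 8, 4
# A single fetch of 8 * 4 = 32 samples replaces eight fetches of 4 samples each.
big_batch = [float(i % batch_size) for i in range(num_sgd_steps_per_step * batch_size)]

learner_step = multiple_sgd_steps(sgd_update, num_sgd_steps_per_step)
param = learner_step(0.0, big_batch)
print(param)
```

Note that in this sketch the minibatches are processed one after another on a single device, so it trades extra iterator calls for sequential updates; it is not the same computation as 8 devices each taking one step on its own shard.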

Yes, that would work better, thanks; I'll look into it for my specific use case. Still, I don't think multiple next calls should be the default behaviour, since that seems like quite a bottleneck. For example, Reverb next calls can already be slow compared to the actual SGD step, and this is amplified further when using multiple TPU cores, because next is called sequentially for each core.
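When each next call is slow relative to the SGD step, a background producer thread can at least overlap fetching with computation; this is the kind of prefetching mentioned in the original question. Below is a minimal standard-library sketch. prefetch and slow_source are hypothetical helpers, not Acme's implementation, and this approach only hides latency: it does not reduce the number of sequential next calls made per device.

```python
import queue
import threading
import time
from typing import Iterable, Iterator, TypeVar

T = TypeVar("T")


def prefetch(iterable: Iterable[T], buffer_size: int = 2) -> Iterator[T]:
    """Return an iterator whose items are fetched ahead of time by a daemon thread.

    Up to `buffer_size` items are buffered; errors raised by the source are
    re-raised in the consumer.
    """
    buffer: "queue.Queue[tuple]" = queue.Queue(maxsize=buffer_size)

    def producer() -> None:
        try:
            for item in iterable:
                buffer.put(("item", item))  # blocks while the buffer is full
        except Exception as exc:  # forward the failure to the consumer
            buffer.put(("error", exc))
            return
        buffer.put(("end", None))

    # Start fetching immediately, before the consumer asks for the first item.
    threading.Thread(target=producer, daemon=True).start()

    def consume() -> Iterator[T]:
        while True:
            kind, payload = buffer.get()
            if kind == "item":
                yield payload
            elif kind == "error":
                raise payload
            else:
                return

    return consume()


def slow_source(n: int, delay_s: float) -> Iterator[int]:
    """Hypothetical slow data source standing in for a Reverb-style next()."""
    for i in range(n):
        time.sleep(delay_s)
        yield i


for batch in prefetch(slow_source(5, delay_s=0.01), buffer_size=2):
    time.sleep(0.01)  # stand-in for an SGD step that overlaps with the next fetch
    print(batch)
```

If the consumer stops early, the producer thread stays blocked on the full queue; making it a daemon thread keeps it from preventing interpreter exit.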