[Question] Buffer Priorities Implementation
LucasAlegre opened this issue
Hi!
I noticed that the buffer priorities here are stored in one large torch tensor, and that sampling is done via torch.searchsorted(). This is in contrast with the original LAP code (https://github.com/sfujim/LAP-PAL/blob/master/continuous/utils.py), which uses a sum tree data structure implemented in numpy.
Have you seen performance improvements with this approach? Is it faster than implementing a sum tree in torch, for instance?
Thanks a lot for providing the code :)
In practice it's marginally faster, but this may change depending on hardware and replay buffer size. In theory, sampling with this version is O(n) in the buffer size, compared to O(log n) with the sum tree, but because it runs on GPU in PyTorch it ends up being quicker regardless. For buffers larger than about 1M entries, or when running on CPU, the sum tree version is probably faster.
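For reference, a minimal sketch of the cumulative-sum idea described above. This is not the repository's exact code: the function name `sample_indices` and its arguments are hypothetical, and it assumes a 1-D tensor of non-negative priorities with at least one positive entry.

```python
import torch


def sample_indices(priorities: torch.Tensor, batch_size: int) -> torch.Tensor:
    """Sample indices with probability proportional to priority (illustrative)."""
    # Prefix sums over the whole buffer: this is the O(n) step.
    csum = torch.cumsum(priorities, dim=0)
    # Uniform draws in [0, total priority), matching the priorities' dtype/device.
    targets = torch.rand(
        batch_size, dtype=priorities.dtype, device=priorities.device
    ) * csum[-1]
    # Binary search per draw; right=True means zero-priority slots are skipped.
    idx = torch.searchsorted(csum, targets, right=True)
    # Guard against float rounding pushing a draw past the last slot.
    return idx.clamp(max=priorities.numel() - 1)
```

The cumulative sum is recomputed on every call, which is where the O(n) cost comes from; on GPU it is a single parallel pass, which is likely why it stays competitive for moderate buffer sizes.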
The secondary benefit is that it's very easy to implement compared to the sum tree. Thanks for the question!
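To make the comparison concrete, here is a rough pure-Python sketch of a sum tree. It is not the LAP implementation linked above; the class name `SumTree` and its methods are assumptions chosen for illustration, and the capacity is rounded up to a power of two to keep the descent simple.

```python
import random


class SumTree:
    """Array-backed binary tree whose internal nodes store sums of their children."""

    def __init__(self, capacity: int):
        # Round up to a power of two; unused leaves keep priority 0.
        self.size = 1
        while self.size < capacity:
            self.size *= 2
        # Node 1 is the root; leaves occupy positions [size, 2 * size).
        self.nodes = [0.0] * (2 * self.size)

    def update(self, index: int, priority: float) -> None:
        # Set the leaf, then refresh the sums on the path to the root: O(log n).
        pos = index + self.size
        self.nodes[pos] = priority
        pos //= 2
        while pos >= 1:
            self.nodes[pos] = self.nodes[2 * pos] + self.nodes[2 * pos + 1]
            pos //= 2

    def total(self) -> float:
        return self.nodes[1]

    def find(self, target: float) -> int:
        # Descend from the root, going left when the target falls inside
        # the left subtree's mass and right otherwise: O(log n).
        pos = 1
        while pos < self.size:
            left = 2 * pos
            if target < self.nodes[left]:
                pos = left
            else:
                target -= self.nodes[left]
                pos = left + 1
        return pos - self.size

    def sample(self) -> int:
        # Assumes total() > 0, i.e. at least one positive priority is stored.
        return self.find(random.random() * self.total())
```

Both updating and sampling touch only one root-to-leaf path, which is the source of the O(log n) bound, but the bookkeeping is noticeably more involved than a cumulative sum followed by a search.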