TorchModel move(value) recursive calls
AlexWan0 opened this issue
For a data_item containing many non-tensor nested lists, the move function calls itself recursively many times, and each call takes a long time for models with a large number of parameters (e.g. ones with many learners).
The move function calls itself 70615 times with this data_item. The number of calls is the same regardless of how many learners there are, but adding learners increases the time it takes to execute

parameters = list(self.parameters())

which is used to find the device to move the tensor to. Each execution of this line takes 4.220008850097656e-05 seconds with 10 learners and 0.005646944046020508 seconds with 1076 learners.
Since the only purpose of this call is to find the correct device, it may be better to determine the device once, when the TorchModel is initialized, and reuse that value in subsequent calls, i.e. adding

self.device = next(self.parameters()).device

to TorchModel.__init__ and using self.device instead of recomputing it every time move is called.
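A minimal, stdlib-only sketch of the proposed caching, assuming the model's parameters sit on a single device that does not change after construction. FakeParam and CachedDeviceModel are hypothetical stand-ins for torch parameters and TorchModel, and move pairs each leaf with the device instead of calling .to(device); the counter shows that parameters() is consulted only once, however deeply the data is nested.

```python
from dataclasses import dataclass
from typing import Any, Iterator, List


@dataclass
class FakeParam:
    """Hypothetical stand-in for a torch parameter; only records its device."""
    device: str


class CachedDeviceModel:
    """Hypothetical stand-in for TorchModel that looks up its device once, in __init__."""

    def __init__(self, params: List[FakeParam]) -> None:
        self._params = params
        self.parameters_calls = 0  # how many times parameters() has been invoked
        # Determined once at construction, as proposed in the issue.
        self.device = next(self.parameters()).device

    def parameters(self) -> Iterator[FakeParam]:
        self.parameters_calls += 1
        return iter(self._params)

    def move(self, value: Any) -> Any:
        # Recurse into nested lists, reusing the cached device at every level.
        if isinstance(value, list):
            return [self.move(v) for v in value]
        # Stand-in for value.to(self.device): pair the leaf with its target device.
        return (value, self.device)


model = CachedDeviceModel([FakeParam("cuda:0") for _ in range(1076)])
moved = model.move([[1, 2], [3, [4]]])
print(model.parameters_calls)  # 1
```

One caveat of this design: if the model is later moved with .to(...), a cached attribute like this goes stale unless it is refreshed.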
Overall, with 1076 learners, the move calls took 394 seconds in total and one entire training iteration took 397 seconds. With the fix, move took 0.10 seconds and a training iteration took 3.30 seconds.
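As a rough consistency check, assuming the per-call timings above are in seconds and that the parameter lookup dominates each move call, multiplying the call count by the per-lookup cost roughly reproduces the measured totals:

```python
calls = 70615                          # recursive move() calls for this data_item
lookup_10 = 4.220008850097656e-05      # seconds per parameters() lookup, 10 learners
lookup_1076 = 0.005646944046020508     # seconds per parameters() lookup, 1076 learners

estimate_10 = calls * lookup_10        # predicted time spent in lookups, 10 learners
estimate_1076 = calls * lookup_1076    # predicted time spent in lookups, 1076 learners
move_speedup = 394 / 0.10              # measured move() time, before vs. after the fix
iteration_speedup = 397 / 3.30         # measured training-iteration time, before vs. after

print(f"{estimate_10:.2f} s, {estimate_1076:.1f} s")     # 2.98 s, 398.8 s
print(f"{move_speedup:.0f}x, {iteration_speedup:.0f}x")  # 3940x, 120x
```

The estimate of about 399 seconds for 1076 learners is close to the measured 394 seconds, which is consistent with the repeated lookup being the bottleneck.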
@guoquan Do you have any thoughts on this issue? (I'm asking because you implemented the device features; I just want to make sure we aren't missing anything from your perspective.)
I did not expect to receive tensors as large nested lists in this case.
The reason there is no TorchModel.device is that we want to support running across multiple devices: sensors can be placed on different devices (and there is a Sensor.device).
If we simply want to avoid repeated calls to parameters = list(self.parameters()), maybe we can let a device argument take priority. Something like:
def move(self, value, device=None):
    if device is None:
        try:
            device = next(self.parameters()).device
        except StopIteration:  # no parameters
            pass
    # ... some other cases
    if isinstance(value, list):
        return [self.move(v, device=device) for v in value]
    # ... some other cases
Because the recursive calls pass the device along (self.move(v, device=device)), they should avoid retrieving the device again.
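To check that passing the device down the recursion really avoids repeated lookups, here is a stdlib-only sketch of the same pattern. FakeParam and DeviceArgModel are hypothetical stand-ins (not the real TorchModel), and leaves are paired with a device label rather than moved with .to(device). Only the top-level call resolves the device; a caller that already knows its target device (presumably a sensor with its own Sensor.device) can pass it explicitly and skip the lookup entirely.

```python
from dataclasses import dataclass
from typing import Any, Iterator, List, Optional


@dataclass
class FakeParam:
    """Hypothetical stand-in for a torch parameter; only records its device."""
    device: str


class DeviceArgModel:
    """Hypothetical stand-in for TorchModel using the device-argument approach."""

    def __init__(self, params: List[FakeParam]) -> None:
        self._params = params
        self.parameters_calls = 0  # how many times parameters() has been invoked

    def parameters(self) -> Iterator[FakeParam]:
        self.parameters_calls += 1
        return iter(self._params)

    def move(self, value: Any, device: Optional[str] = None) -> Any:
        if device is None:
            try:
                device = next(self.parameters()).device
            except StopIteration:  # no parameters: device stays None
                pass
        if isinstance(value, list):
            # Pass the resolved device down so nested calls skip the lookup.
            return [self.move(v, device=device) for v in value]
        # Stand-in for value.to(device): pair the leaf with its target device.
        return (value, device)


model = DeviceArgModel([FakeParam("cuda:0"), FakeParam("cuda:1")])
a = model.move([[1, 2], [3, [4]]])      # one lookup for the whole nested list
b = model.move([5, [6]], device="cpu")  # caller-supplied device: no lookup
print(model.parameters_calls)  # 1
```

Compared with caching one device in __init__, the default here is re-resolved on each top-level call and can be overridden per call, which fits the multi-device setup described above.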
Alright, I will do that.
Yeah, the fix works. I've changed my code slightly since then, so the numbers differ: the entire move operation now takes 240.6984 seconds in total without the fix and 0.869 seconds with it.
@AlexWan0 Could you make a branch that isolates just that one fix and open a pull request against the main branch?