HLR / DomiKnowS

TorchModel move(value) recursive calls

AlexWan0 opened this issue

For a data_item with many non-tensor nested lists, the move function recursively calls itself many times, and each call takes a long time for models with a large number of parameters (e.g. ones with many learners).

data_item used, for reference

The move function calls itself 70615 times with this data_item. The number of calls does not depend on the number of learners, but adding learners increases the time taken by each call to

    parameters = list(self.parameters())

which is used to find the device to move the tensor to. Each call of this line takes 4.220008850097656e-05 seconds for 10 learners and 0.005646944046020508 seconds for 1076 learners.

Since the only purpose of this call is to find the correct device to use, it may be better to define the device once, on initialization of the TorchModel, then use that value in subsequent calls.

i.e., adding

    parameters = list(self.parameters())
    self.device = next(self.parameters()).device

to __init__ in TorchModel, and then using self.device in move instead of recomputing the device on every call.

Overall, the move call took 394 seconds, and one entire train iteration took 397 seconds (with 1076 learners). With the fix, move took 0.10 seconds, and a training iteration took 3.30 seconds.
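As a rough sanity check, the reported totals can be compared with the per-call timings quoted earlier. The sketch below assumes those per-call numbers are in seconds (the issue does not state the unit) and simply multiplies them by the number of recursive calls; the variable names are illustrative only.

```python
# Numbers copied from the issue text above.
calls = 70615                          # recursive move() calls for this data_item
per_call_10 = 4.220008850097656e-05    # list(self.parameters()) with 10 learners
per_call_1076 = 0.005646944046020508   # list(self.parameters()) with 1076 learners

# Estimated time spent only on the repeated parameter lookups.
est_10 = calls * per_call_10
est_1076 = calls * per_call_1076

print(f"10 learners:   {est_10:.2f} s")    # about 2.98 s
print(f"1076 learners: {est_1076:.1f} s")  # about 398.8 s
```

The estimate for 1076 learners is close to the measured 394 seconds, which suggests that the parameter lookup accounts for nearly all of the time spent in move.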

@guoquan Do you have any idea about this issue? (I am asking because you implemented the device features; I just want to make sure we are not missing anything from your perspective.)

I did not expect to receive tensors as large nested lists in this case.
The reason there is no TorchModel.device is that we want to support running over multiple devices. Sensors can be placed on different devices (and there is Sensor.device).

If we simply want to avoid repeated calls to parameters = list(self.parameters()), maybe we can let the device argument take priority. Something like:

    def move(self, value, device=None):
        if device is None:
            try:
                device = next(self.parameters()).device
            except StopIteration:  # no parameters
                pass
        # ... some other cases
        if isinstance(value, list):
            return [self.move(v, device=device) for v in value]
        # ... some other cases

Because the recursive calls self.move(v, device=device) pass the device along, it is looked up only once, in the top-level call, and not again for each nested element.
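The effect of passing the device through the recursion can be illustrated without PyTorch. The sketch below is not the DomiKnowS implementation: FakeTensor and CountingModel are hypothetical stand-ins for torch.Tensor and TorchModel, the lookups counter exists only for the demonstration, and the fallback to "cpu" when a model has no parameters is an assumption made here to keep the example short.

```python
from dataclasses import dataclass


@dataclass
class FakeTensor:
    """Hypothetical stand-in for torch.Tensor; it only records a device name."""
    device: str = "cpu"

    def to(self, device):
        return FakeTensor(device=device)


class CountingModel:
    """Hypothetical stand-in for TorchModel that counts parameter lookups."""

    def __init__(self, device="cuda:0"):
        self._params = [FakeTensor(device=device)]
        self.lookups = 0  # number of times parameters() was called

    def parameters(self):
        self.lookups += 1
        return iter(self._params)

    def move(self, value, device=None):
        if device is None:
            # Only the top-level call reaches this branch.
            first = next(self.parameters(), None)
            # Assumption for this sketch: fall back to "cpu" without parameters.
            device = first.device if first is not None else "cpu"
        if isinstance(value, FakeTensor):
            return value.to(device)
        if isinstance(value, (list, tuple)):
            return type(value)(self.move(v, device=device) for v in value)
        if isinstance(value, dict):
            return {k: self.move(v, device=device) for k, v in value.items()}
        return value  # non-tensor leaves are returned unchanged


model = CountingModel()
nested = [[FakeTensor(), "label", [FakeTensor(), 3]], {"x": FakeTensor()}]
moved = model.move(nested)
print(model.lookups)       # 1: the device was resolved once for the whole structure
print(moved[0][0].device)  # cuda:0
```

Since the device is still resolved per top-level call and not stored on the model, a caller could pass a different device explicitly, which keeps the multi-device use case described above open.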

@AlexWan0 can you please try the fix that @guoquan mentioned above and report back the difference in execution time?

Alright, I will do that.

Yeah, the fix works. I've changed my code slightly since then, so the entire move operation takes 240.6984 seconds total without the fix, and 0.869 seconds with the fix.

@AlexWan0 could you make a branch just by isolating that one fix and make a pull request to the main branch?