MushroomRL / mushroom-rl

Python library for Reinforcement Learning.

Question: TorchApproximator.predict - Why no torch.no_grad() and why call forward directly?

VanillaWhey opened this issue

Is your feature request related to a problem? Please describe.
I was wondering why the predict method of the TorchApproximator class computes gradients and calls self.network.forward(*torch_args, **kwargs) directly.

Describe the solution you'd like
Why not wrap the forward pass in a with torch.no_grad() block? That would save memory on the one hand and make the detach() call unnecessary on the other. Furthermore, calling self.network(*torch_args, **kwargs) instead of forward directly is generally better practice (unless there is a good reason to do otherwise), since __call__ also runs any registered hooks. A sketch of the suggested pattern is shown below.
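For illustration, a minimal sketch of what the suggested change would look like, using a hypothetical helper and a toy network rather than MushroomRL's actual predict implementation:

```python
import torch
import torch.nn as nn


def predict_no_grad(network, *torch_args, **kwargs):
    """Hypothetical helper sketching the suggested pattern (not MushroomRL code)."""
    with torch.no_grad():
        # __call__ dispatches to forward() but also runs any registered hooks.
        output = network(*torch_args, **kwargs)
    # No .detach() is needed: tensors produced under no_grad carry no gradient history.
    return output


# Minimal usage with a toy network (purely illustrative):
net = nn.Linear(4, 2)
out = predict_no_grad(net, torch.randn(8, 4))
assert not out.requires_grad
```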

Describe alternatives you've considered
If the desired behaviour is that the output_tensor flag means a tensor with gradients should be returned, then the first point is moot.

Additional context
Additionally, in line 254 of the same file, the .requires_grad_(False) call has no effect, since the preceding .detach() call has already taken care of it.
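A quick standalone check of this claim (generic PyTorch behaviour, independent of the MushroomRL source):

```python
import torch

x = torch.ones(3, requires_grad=True)
y = x.detach()
print(y.requires_grad)   # False: detach() already returns a tensor without gradient tracking
y.requires_grad_(False)  # redundant here, y already has requires_grad == False
print(y.requires_grad)   # still False
```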

About torch.no_grad: we cannot use it in the forward pass, as it is also used when computing the losses in actor-critic algorithms.
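For context, a minimal illustration of this constraint with a toy network and loss (purely hypothetical, not MushroomRL's actor-critic code): the same forward pass that serves predictions also has to build the graph used by the loss, so wrapping it in torch.no_grad() would break backpropagation.

```python
import torch
import torch.nn as nn

# Toy network and loss, purely illustrative: if this forward pass ran under
# torch.no_grad(), no graph would be recorded and loss.backward() would raise
# "element 0 of tensors does not require grad".
net = nn.Linear(4, 2)
states = torch.randn(8, 4)
targets = torch.randn(8, 2)

values = net(states)                            # graph is recorded here
loss = nn.functional.mse_loss(values, targets)
loss.backward()                                 # works because gradients were tracked
```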
I don't remember if there was a specific reason for using forward instead of the function call; we may have had some issues years ago, but I can't recall the details, and I don't have time to check this now.

"the output_tensor flag means a tensor with gradients should be returned, the first point is void" -> Yes, it means that a tensor with gradients is returned.

".requires_grad_(False) has no effect since the .detach() call has already taken care of it" -> That could well be, but I'm not entirely sure, and I have no time to check it now. We wrote this quite a long time ago, and there was a reasoning behind it at the time.