TylerYep / torchinfo

View model summaries in PyTorch!

Remove gradient calculation during inferencing (torch.no_grad)

SinChee opened this issue

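During inference, wrapping the forward pass in torch.no_grad() should save memory: autograd records no graph (grad_fn is None), so intermediate activations are not kept alive for a backward pass.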
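A minimal illustration of the grad_fn point, using a small hypothetical nn.Sequential model in place of ResNet-18 (any nn.Module behaves the same way): under torch.no_grad() the output has no grad_fn and does not require grad, while the values themselves are unchanged.

```python
import torch
import torch.nn as nn

# Hypothetical toy model standing in for resnet18; the behaviour is the same.
model = nn.Sequential(nn.Linear(8, 4), nn.ReLU(), nn.Linear(4, 2))
x = torch.randn(16, 8)

# Default grad mode: autograd records the graph, so the output has a grad_fn.
out_train = model(x)
print(out_train.grad_fn is not None)  # True

# Under torch.no_grad(), no graph is recorded, so grad_fn is None.
with torch.no_grad():
    out_infer = model(x)
print(out_infer.grad_fn is None)  # True
print(out_infer.requires_grad)  # False
```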

I have the following code:

import torch
import torchvision
from torchinfo import summary

def print_model_intermediate_info(model, input_size=(16, 3, 224, 224)):
    # total_output_bytes is in bytes; multiply by 1e-6 to report MB
    print("Train", summary(model, input_size=input_size, mode="train").total_output_bytes * 1e-6)
    print("Eval", summary(model, input_size=input_size, mode="eval").total_output_bytes * 1e-6)

    # Same eval-mode summary, but called from inside torch.no_grad()
    with torch.no_grad():
        print("no grad", summary(model, input_size=input_size, mode="eval").total_output_bytes * 1e-6)

model = torchvision.models.resnet18()
print_model_intermediate_info(model)

Output:
Train 57.802752
Eval 57.802752
no grad 57.802752

It would be a great feature to support this use case.

We do use no_grad: https://github.com/TylerYep/torchinfo/blob/main/torchinfo/torchinfo.py#L286
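One way to see why the three numbers match is with forward hooks. The helper output_bytes below is hypothetical and is not torchinfo's implementation; it only assumes that a metric like total_output_bytes is built from the sizes of per-layer output tensors. Output shapes and dtypes do not depend on grad mode, so such a sum comes out the same with or without no_grad. What no_grad actually saves is the activations autograd would otherwise keep alive for backward, and a size count of outputs does not capture that.

```python
import torch
import torch.nn as nn


def output_bytes(model: nn.Module, x: torch.Tensor, grad_enabled: bool) -> int:
    """Hypothetical helper: sum the bytes of every leaf module's output in one forward pass."""
    total = 0

    def hook(module, inputs, output):
        nonlocal total
        # Size of the output tensor only; independent of whether autograd is on.
        total += output.numel() * output.element_size()

    # Hook only leaf modules (no children) so nothing is counted twice.
    handles = [
        m.register_forward_hook(hook)
        for m in model.modules()
        if next(m.children(), None) is None
    ]
    try:
        with torch.set_grad_enabled(grad_enabled):
            model(x)
    finally:
        for h in handles:
            h.remove()
    return total


model = nn.Sequential(nn.Linear(8, 4), nn.ReLU(), nn.Linear(4, 2))
x = torch.randn(16, 8)
print(output_bytes(model, x, grad_enabled=True))   # 640
print(output_bytes(model, x, grad_enabled=False))  # 640
```

Both calls print 640 for this toy model: (16×4 + 16×4 + 16×2) float32 values at 4 bytes each.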

Closing for now, feel free to reply or reopen if this is not what you are asking about.