Stonesjtu / pytorch_memlab

Profiling and inspecting memory in pytorch

Jupyter notebook support

willprice opened this issue

Hi,
Thanks for the super useful package. Currently, it seems it is not possible to use it in a Jupyter notebook. If I create a new notebook and add a cell with the following contents:

import torch
from pytorch_memlab import profile
@profile
def work():
    linear = torch.nn.Linear(100, 100).cuda()
    linear2 = torch.nn.Linear(100, 100).cuda()
    linear3 = torch.nn.Linear(100, 100).cuda()

work()

I get no results printed.

However, when I use MemReporter I do get results:

import torch
from pytorch_memlab import MemReporter

linear = torch.nn.Linear(1024, 1024).cuda()
reporter = MemReporter()
reporter.report()
Element type                                            Size  Used MEM
-------------------------------------------------------------------------------
Storage on cuda:0
Parameter0                                      (1024, 1024)     4.00M
Parameter1                                           (1024,)     4.00K
Parameter2                                      (1024, 1024)     4.00M
Parameter3                                           (1024,)     4.00K
-------------------------------------------------------------------------------
Total Tensors: 2099200 	Used Memory: 8.01M
The allocated memory on cuda:0: 8.01M
-------------------------------------------------------------------------------
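The totals in this report can be checked by hand. A minimal sketch, assuming the parameters are float32 (4 bytes per element) and that the "M"/"K" suffixes are binary units (MiB/KiB); the four shapes are copied from the report:

```python
import math

# Assumption: float32 parameters, i.e. 4 bytes per element
BYTES_PER_ELEMENT = 4

# Shapes of the four tensors listed in the report (two weights, two biases)
shapes = [(1024, 1024), (1024,), (1024, 1024), (1024,)]

# Total number of elements across all tensors ("Total Tensors" in the report)
total_elements = sum(math.prod(s) for s in shapes)
total_bytes = total_elements * BYTES_PER_ELEMENT

print(f"Total Tensors: {total_elements}")
print(f"Used Memory: {total_bytes / 1024**2:.2f}M")
# Per-tensor sizes: a 1024x1024 weight and a 1024-element bias
print(f"Weight: {math.prod((1024, 1024)) * BYTES_PER_ELEMENT / 1024**2:.2f}M")
print(f"Bias: {math.prod((1024,)) * BYTES_PER_ELEMENT / 1024:.2f}K")
```

This reproduces 2099200 elements and 8.01M, i.e. 2 × (1024 × 1024 + 1024) elements at 4 bytes each.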

Additionally, I wonder if it is possible to add a line magic similar to %lprun for profiling cells.

I suspect simply adding a line magic won't help: currently, @profile only prints its usage information when the Python process terminates, and a notebook kernel keeps running between cells.
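To illustrate why nothing shows up, here is a plain-Python sketch of a decorator that defers its report to interpreter exit via `atexit`. `profile_at_exit` is a hypothetical name and this is not pytorch_memlab's code; it only mimics the "print at termination" behaviour described above, with a call counter standing in for real memory stats:

```python
import atexit
import functools


def profile_at_exit(func):
    """Hypothetical decorator: collect stats per call, report only at exit."""
    stats = {"calls": 0}

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        stats["calls"] += 1  # stand-in for real per-line memory bookkeeping
        return func(*args, **kwargs)

    # The report is deferred until the Python process terminates. A Jupyter
    # kernel stays alive between cells, so this does not fire after a cell.
    atexit.register(lambda: print(f"{func.__name__}: {stats['calls']} call(s)"))
    wrapper.stats = stats
    return wrapper


@profile_at_exit
def work():
    return sum(range(10))


work()
work()  # still nothing printed; the report only appears when the process exits
```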

An ideal solution would be to find out whether Jupyter notebook emits a signal when a cell finishes executing.
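IPython does provide such events: callbacks registered for `post_run_cell` run after every executed cell. A minimal sketch, assuming IPython is installed; `install_post_cell_hook` is a hypothetical helper, and passing the profiler's print function as the callback is an untested assumption, not necessarily what was eventually implemented:

```python
def install_post_cell_hook(callback):
    """Run `callback()` after every executed cell when inside IPython/Jupyter.

    Returns True if the hook was registered, False otherwise
    (IPython not installed, or running as a plain Python script).
    """
    try:
        from IPython import get_ipython
    except ImportError:
        return False

    shell = get_ipython()
    if shell is None:  # plain Python interpreter, not an IPython kernel
        return False

    def _after_cell(result=None):
        # IPython passes an ExecutionResult to post_run_cell callbacks;
        # it is ignored here and the user's callback is simply invoked.
        callback()

    shell.events.register("post_run_cell", _after_cell)
    return True
```

In a notebook, something like `install_post_cell_hook(lambda: print("cell finished"))` would then print after each cell; outside IPython the function just returns False.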

However, you can work around this by decorating the function with @profile_every(N), which prints the memory usage every N calls instead of only at exit.
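The idea behind reporting every N calls can be sketched in plain Python. `report_every` is a hypothetical decorator, not pytorch_memlab's implementation; it reports after every n-th call, while the process (or notebook kernel) is still alive:

```python
import functools


def report_every(n, report=print):
    """Hypothetical decorator: call `report` after every `n`-th call of the function."""
    if n < 1:
        raise ValueError("n must be >= 1")

    def decorator(func):
        calls = 0

        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            nonlocal calls
            result = func(*args, **kwargs)
            calls += 1
            if calls % n == 0:
                # Report immediately instead of waiting for process exit
                report(f"{func.__name__}: reached {calls} call(s)")
            return result

        return wrapper

    return decorator


# Usage: collect reports in a list instead of printing them
reports = []


@report_every(2, report=reports.append)
def step():
    pass


for _ in range(5):
    step()
```

After five calls with `n=2`, `reports` holds two entries, one for the 2nd call and one for the 4th.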

You can also print the results manually (it's a bit of a trick, though):

from pytorch_memlab import global_line_profiler
# Print the line-by-line stats collected so far by @profile
global_line_profiler.print_stats()

Resolved in #8