chr5tphr / zennit

Zennit is a high-level framework in Python using PyTorch for explaining/exploring neural networks using attribution methods like LRP.

Geek Repo:Geek Repo

Github PK Tool:Github PK Tool

Memory Leak

rachtibat opened this issue · comments

Hi Chris,

Unfortunately, there is a memory leak when you run zennit several times in a row.
You can reproduce the results on a GPU using:



from torchvision.models.vgg import vgg16
import torch
import zennit
from zennit.composites import COMPOSITES

# Reproduction script: run the same attribution 100 times on the GPU and print
# the CUDA memory counters after every pass so that growth becomes visible.
model = vgg16(pretrained=True).to("cuda")
model.eval()

data = torch.randn((1, 3, 224, 224)).to("cuda")
data.requires_grad = True
target = 0

# one-hot row for `target`; passed to backward() as the gradient of the output
eye = torch.eye(1000, device="cuda")
output_relevance = eye[[target]]

for i in range(100):
    print(i)

    # Only the registry lookup is guarded, and only against a missing key, so
    # that unrelated failures (e.g. inside the composite's constructor) are
    # not misreported as an unknown method.
    try:
        composite_factory = COMPOSITES["epsilon_plus_flat"]
    except KeyError as err:
        raise ValueError(f"Method not defined. Available are {list(COMPOSITES.keys())}") from err
    composite = composite_factory()

    # backward() accumulates into `.grad` of leaf tensors; drop the result of
    # the previous iteration so each heatmap reflects a single pass only.
    data.grad = None

    with composite.context(model) as modified:
        pred = modified(data)

        torch.autograd.backward((pred,), (output_relevance,))

    # the attribution will be stored in the gradient's place
    heatmaps = data.grad.sum(1).detach().cpu().numpy()

    r = torch.cuda.memory_reserved(0)
    a = torch.cuda.memory_allocated(0)
    print("memory reserved ", r, "memory allocated ", a)

Do you have a suspicion where the leak is coming from?

I think I have found the problem, check out #13 .
It seems that a reference to LinearHook.input is retained somewhere for some reason (though not one to grad_output, which makes this something of a mystery).
I now handle references to tensors very carefully, and remove them explicitly after removing the hook (which also decreases the refcount of its attributes). I thought the hook itself might be retained somewhere, but this should then also influence LinearHook.grad_output, which seems to have no problem.

Can you check whether the problem is fixed for you in the PR?
Here's also the snippet I used for debugging the problem:

from torchvision.models.vgg import vgg16
import os
import tracemalloc
import gc

import psutil
import torch
from torchvision.models.vgg import vgg16
from zennit.composites import COMPOSITES
from zennit.rules import ZPlus


# Memory-debugging script: repeatedly build a model, run one attribution pass
# with the composite registered, remove the composite again, and print several
# memory/object statistics per iteration so growth across iterations shows up.

# Trace Python allocations, storing up to 25 frames per allocation traceback.
tracemalloc.start(25)

# (disabled) Registry of objects seen so far, used by the object-diff code
# that is commented out at the end of the loop body.
# objects = {}

for _ in range(20):
    # A fresh model is created in every iteration; no device transfer is done
    # here. The commented-out line is a much smaller single-layer alternative.
    # model = torch.nn.Linear(1 << 15, 1000)
    model = vgg16()
    model.eval()

    # Disable gradients for all parameters; only the input below requires one.
    for param in model.parameters():
        param.requires_grad = False

    # The commented-out line is the matching input for the Linear alternative.
    # data = torch.randn((1, 1 << 15))
    data = torch.randn((1, 3, 224, 224))
    data.requires_grad = True
    target = 0

    # One-hot row for `target`, passed to backward() as the output gradient.
    eye = torch.eye(1000)
    output_relevance = eye[[target]]

    composite = COMPOSITES["epsilon_plus_flat"]()
    # (disabled) Alternative setup: register a single ZPlus hook by hand
    # instead of going through a composite.
    # hook = ZPlus()
    # handles = [
    #     model.register_forward_pre_hook(hook.pre_forward),
    #     model.register_forward_hook(hook.forward),
    #     model.register_backward_hook(hook.pre_backward),
    # ]

    # Register the composite on the model, run forward and backward once, then
    # remove it again (presumably this installs/removes the rule hooks --
    # confirm against the zennit documentation).
    composite.register(model)
    pred = model(data)
    torch.autograd.backward((pred,), (output_relevance,))
    composite.remove()
    # for handle in handles:
    #     handle.remove()

    # (disabled) Explicit deletion of every name created above, kept for
    # narrowing down which reference keeps memory alive.
    # del model
    # del composite
    # del hook.input
    # del hook.grad_output
    # del hook
    # del pred
    # del data
    # del target
    # del eye
    # del output_relevance
    # Run a full garbage collection before taking the measurements below.
    gc.collect()

    # Current and peak size of the memory blocks traced by tracemalloc.
    size, peak = tracemalloc.get_traced_memory()
    print(f"{size=}, {peak=}")

    # Resident set size of this process as reported by psutil.
    process = psutil.Process(os.getpid())
    print("psutil: ", process.memory_info().rss)  # in bytes

    print('gc stats:', gc.get_stats())

    # Number of objects currently tracked by the garbage collector.
    print('current objects:', len(gc.get_objects()))
    # (disabled) Diff against the objects seen in earlier iterations and list
    # the shapes of tensors that are new in this iteration.
    # newobjs = {id(obj): obj for obj in gc.get_objects()}
    # diff = {key: newobjs[key] for key in newobjs if key not in objects}

    # print('new objects:', len(diff))
    # print([(key, val.shape) for key, val in diff.items() if isinstance(val, torch.Tensor)])

    # objects.update(newobjs)
    print()

# (disabled) Print the traceback of the first entry of the tracemalloc
# snapshot statistics grouped by traceback.
# stats = tracemalloc.take_snapshot().statistics('traceback')
# stat = stats[0]
# print("%s memory blocks: %.1f KiB" % (stat.count, stat.size / 1024))
# print('\n'.join(stat.traceback.format()))

Hey,

this is awesome! Now it is working smoothly.

Best