chr5tphr / zennit

Zennit is a high-level framework in Python using PyTorch for explaining/exploring neural networks using attribution methods like LRP.

Support UNet with Upsample or ConvTranspose2d layers?

bugsuse opened this issue

I want to use LRP to explain a semantic segmentation task using a UNet model (PyTorch). I tested LRP in Captum, but it does not support nn.Upsample and nn.ConvTranspose2d. I would like to know whether semantic segmentation models like UNet can be supported, and if not, how this should be implemented. Any help would be appreciated!

Hey @bugsuse ,

I think this should simply work if you use any of the composites.

Since the gradient of nn.Upsample is just a constant that depends on the upsampling factor, this will probably scale the full attribution somewhat, but this should not be a problem. nn.ConvTranspose2d is simply a linear layer, and is therefore supported.
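The point about nn.Upsample can be illustrated with a minimal check in plain PyTorch (not Zennit code). It assumes the default mode='nearest' and scale_factor=2. Each input pixel is copied into a 2x2 block, so its gradient is the sum over those four copies. With an all-ones upstream gradient, every input pixel therefore receives exactly 4, which is the constant factor mentioned above. This sketch only checks the 'nearest' mode.

```python
import torch
from torch import nn

upsample = nn.Upsample(scale_factor=2, mode='nearest')
x = torch.arange(1., 5.).reshape(1, 1, 2, 2).requires_grad_()
y = upsample(x)  # shape (1, 1, 4, 4): every input pixel is copied into a 2x2 block
y.backward(torch.ones_like(y))
print(x.grad)  # every entry is 4 = 2 * 2, the number of copies of each input pixel
```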

If you simply use one of the Composites like EpsilonGammaBox or EpsilonAlpha2Beta1Flat (or anything, really), this should just work out of the box, though I have not tried it with UNet specifically yet. Unless you are using BatchNorm, you also do not have to supply a Canonizer. Have a look at the example.

Just in case the gradients get too large with the nn.Upsample, you can use the Norm rule and build your own composite:

import torch

from zennit.rules import Gamma, Epsilon, ZBox, Norm
from zennit.types import Convolution
from zennit.composites import SpecialFirstLayerMapComposite, LAYER_MAP_BASE


class UpsampledEpsilonGammaBox(SpecialFirstLayerMapComposite):
    '''An explicit composite using the ZBox rule for the first convolutional layer, gamma rule for all following
    convolutional layers, and the epsilon rule for all fully connected layers.
    Additionally, this uses the `Norm` rule for `nn.Upsample`.

    Parameters
    ----------
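    low: :py:obj:`torch.Tensor`
        A tensor with the same size as the input, describing the lowest possible pixel values.
    high: :py:obj:`torch.Tensor`
        A tensor with the same size as the input, describing the highest possible pixel values.
    canonizers: list of :py:obj:`zennit.canonizers.Canonizer`, optional
        Canonizers to apply to the model before the rules are attached, e.g. to merge BatchNorm layers.
    '''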
    def __init__(self, low, high, canonizers=None):
        layer_map = LAYER_MAP_BASE + [
            (Convolution, Gamma(gamma=0.25)),
            (torch.nn.Linear, Epsilon()),
            (torch.nn.Upsample, Norm()),
        ]
        first_map = [
            (Convolution, ZBox(low, high))
        ]
        super().__init__(layer_map, first_map, canonizers=canonizers)

But I would first just try it with the plain gradient for nn.Upsample, i.e., use one of the built-in composites.

Hey @bugsuse,

I have already used Zennit for a UNet with the nn.Upsample layer and as @chr5tphr said, this worked even when just using the gradient (no extra rule). In fact, in my case it did not make a difference whether I used the gradient, Norm() or Epsilon() rule for nn.Upsample.

@chr5tphr Thanks sooo much! I will try it!

@maxdreyer That's great! Do you have a relevant notebook or example blog post that you could share with me?

@bugsuse,

you can basically use the code that is written in zennit/share/example/feed_forward.py. After loading your model and data, you only have to adapt the choice of the output_relevance (the output that is propagated backwards).

For classification tasks, it makes sense to propagate zeros everywhere but the index of the target class:
output_relevance = torch.eye(n_outputs, device=device)[target]
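To make the indexing trick concrete: indexing an identity matrix with a batch of target labels yields one one-hot row per sample. This sketch uses made-up values (`n_outputs = 5`, targets 1 and 3) and compares the result against `torch.nn.functional.one_hot`.

```python
import torch
import torch.nn.functional as F

n_outputs = 5
target = torch.tensor([1, 3])  # made-up class indices for a batch of two samples

# indexing the identity matrix with the targets selects one one-hot row per sample
output_relevance = torch.eye(n_outputs)[target]
print(output_relevance)
# tensor([[0., 1., 0., 0., 0.],
#         [0., 0., 0., 1., 0.]])

# equivalent formulation with F.one_hot (which returns int64, hence the cast)
assert torch.equal(output_relevance, F.one_hot(target, n_outputs).float())
```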

For segmentation, it depends on what you are interested in. In my case, I propagated back the output channel of a specific class, like this:

output = model(input)
output_relevance = torch.zeros_like(output)
output_relevance[:, class_index, :, :] = output[:, class_index, :, :]
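Here is the same selection wrapped in a small helper. The name `class_channel_relevance` is hypothetical, and the sketch assumes an output of shape `(batch, classes, height, width)`. The output is detached first. The relevance only serves as the initial gradient (the attributor passes it as `grad_outputs`, as the traceback further down shows), so it does not need to be part of the autograd graph.

```python
import torch


def class_channel_relevance(output, class_index):
    """Zero everywhere except channel `class_index`, which keeps the model output itself."""
    output = output.detach()  # only used as the initial gradient, no autograd graph needed
    output_relevance = torch.zeros_like(output)
    output_relevance[:, class_index] = output[:, class_index]
    return output_relevance


output = torch.randn(2, 3, 4, 4)  # (batch, classes, height, width)
relevance = class_channel_relevance(output, class_index=1)
```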

@maxdreyer I tested it following feed_forward.py, but it raised a RuntimeError:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-44-740b8083f87f> in <module>
     45 
     46         # this will compute the modified gradient of model, with the on
---> 47         output, relevance = attributor(data.cuda(), output_relevance.cuda())
     48 
     49         # sum over the color channel for visualization

~/tools/miniconda3/envs/pytorch/lib/python3.8/site-packages/zennit/attribution.py in __call__(self, input, attr_output)
    130 
    131         if self.composite is None or self.composite.handles:
--> 132             return self.forward(input, attr_output_fn)
    133 
    134         with self:

~/tools/miniconda3/envs/pytorch/lib/python3.8/site-packages/zennit/attribution.py in forward(self, input, attr_output_fn)
    175         input = input.detach().requires_grad_(True)
    176         output = self.model(input)
--> 177         gradient, = torch.autograd.grad((output,), (input,), grad_outputs=(attr_output_fn(output.detach()),))
    178         return output, gradient
    179 

~/tools/miniconda3/envs/pytorch/lib/python3.8/site-packages/torch/autograd/__init__.py in grad(outputs, inputs, grad_outputs, retain_graph, create_graph, only_inputs, allow_unused)
    200         retain_graph = create_graph
    201 
--> 202     return Variable._execution_engine.run_backward(
    203         outputs, grad_outputs_, retain_graph, create_graph,
    204         inputs, allow_unused)

~/tools/miniconda3/envs/pytorch/lib/python3.8/site-packages/zennit/core.py in wrapper(grad_input, grad_output)
    139         @functools.wraps(self.backward)
    140         def wrapper(grad_input, grad_output):
--> 141             return hook_ref().backward(module, grad_input, hook_ref().stored_tensors['grad_output'])
    142 
    143         if not isinstance(input, tuple):

~/tools/miniconda3/envs/pytorch/lib/python3.8/site-packages/zennit/core.py in backward(self, module, grad_input, grad_output)
    279             input = in_mod(original_input).requires_grad_()
    280             with mod_params(module, param_mod, **param_kwargs) as modified, torch.autograd.enable_grad():
--> 281                 output = modified.forward(input)
    282                 output = out_mod(output)
    283             inputs.append(input)

~/tools/miniconda3/envs/pytorch/lib/python3.8/site-packages/torch/nn/modules/conv.py in forward(self, input)
    421 
    422     def forward(self, input: Tensor) -> Tensor:
--> 423         return self._conv_forward(input, self.weight)
    424 
    425 class Conv3d(_ConvNd):

~/tools/miniconda3/envs/pytorch/lib/python3.8/site-packages/torch/nn/modules/conv.py in _conv_forward(self, input, weight)
    417                             weight, self.bias, self.stride,
    418                             _pair(0), self.dilation, self.groups)
--> 419         return F.conv2d(input, weight, self.bias, self.stride,
    420                         self.padding, self.dilation, self.groups)
    421 

RuntimeError: Expected 4-dimensional input for 4-dimensional weight [16, 24, 3, 3], but got 3-dimensional input of size [16, 64, 64] instead

The example code is as follows:

device = 'cuda:0'
model.to(device)
model.eval()

# disable requires_grad for all parameters, we do not need their modified gradients
for param in model.parameters():
    param.requires_grad = False

output = model(input.cuda())
output_relevance = torch.zeros_like(output)

# create a composite if composite_name was set, otherwise we do not use a composite
composite = None
if composite_name is not None:
    composite_kwargs = {}
    if composite_name == 'upsample_epsilon_gamma_box':
        # the maximal input shape, needed for the ZBox rule
        shape = (batch_size, 64, 64)

        # the highest and lowest pixel values for the ZBox rule
        composite_kwargs['low'] = torch.zeros(*shape, device=device)
        composite_kwargs['high'] = torch.ones(*shape, device=device)

    # use torchvision specific canonizers, as supplied in the MODELS dict
    composite_kwargs['canonizers'] = [MODELS[model_name][1]()]

    # create a composite specified by a name; the COMPOSITES dict includes all preset composites provided by zennit.
    composite = COMPOSITES[composite_name](**composite_kwargs)

# specify some attributor-specific arguments
attributor_kwargs = {
    'smoothgrad': {'noise_level': 0.1, 'n_iter': 20},
    'integrads': {'n_iter': 20},
    'occlusion': {'window': (56, 56), 'stride': (28, 28)},
}.get(attributor_name, {})

attributor = ATTRIBUTORS[attributor_name](model, composite, **attributor_kwargs)

sample_index = 0

with attributor:
    for data, target in valid_loader:
        output_relevance = torch.zeros_like(torch.squeeze(target))
        output, relevance = attributor(data.cuda(), output_relevance.cuda())

## the rest is the same as `feed_forward.py`
...

Thanks for your kind help!

I expect you already made sure the model runs when doing a forward pass without Zennit? (I.e., without the Attributor context)
Could you supply the model you are using?
There may be a problem when passing tuples instead of tensors between layers, is this the case for you?

@bugsuse also, you set the output_relevance twice:
in the beginning output_relevance = torch.zeros_like(output)
and in the loop later output_relevance = torch.zeros_like(torch.squeeze(target))

output_relevance should have the same shape as output. torch.zeros_like(torch.squeeze(target)) could have the wrong shape.
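A small sanity check along these lines can catch such a mismatch before the attributor is called. This is a hypothetical helper: `check_attribution_shapes` is not part of Zennit. The `low`/`high` check follows the docstring of the composite earlier in this thread, which asks for tensors with the same size as the input. The example shows how `torch.squeeze` drops the singleton channel dimension of a `(batch, 1, H, W)` target.

```python
import torch


def check_attribution_shapes(data, output, output_relevance, low=None, high=None):
    """Raise ValueError if the tensors for the attributor or the ZBox bounds have inconsistent shapes."""
    if output_relevance.shape != output.shape:
        raise ValueError(
            f'output_relevance {tuple(output_relevance.shape)} does not match output {tuple(output.shape)}')
    for name, bound in (('low', low), ('high', high)):
        if bound is not None and bound.shape != data.shape:
            raise ValueError(f'{name} {tuple(bound.shape)} does not match input {tuple(data.shape)}')


data = torch.rand(2, 3, 8, 8)
output = torch.rand(2, 1, 8, 8)
target = torch.rand(2, 1, 8, 8)

check_attribution_shapes(data, output, torch.zeros_like(output))  # passes silently
try:
    # squeeze drops the singleton channel dimension: (2, 1, 8, 8) -> (2, 8, 8)
    check_attribution_shapes(data, output, torch.zeros_like(torch.squeeze(target)))
except ValueError as error:
    print(error)  # output_relevance (2, 8, 8) does not match output (2, 1, 8, 8)
```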

@chr5tphr Yes, I load the pretrained weights directly. I'm using a UNet model. More of the code is below:

class UNetCanonizer(SequentialMergeBatchNorm):
    '''Canonizer for the UNet model. This is so far identical to a SequentialMergeBatchNorm'''

MODELS = {
    'vgg16': (vgg16, VGGCanonizer),
    'vgg16_bn': (vgg16_bn, VGGCanonizer),
    'resnet50': (resnet50, ResNetCanonizer),
    'unet': (Unet(), UNetCanonizer)
}

ATTRIBUTORS = {
    'gradient': Gradient,
    'smoothgrad': SmoothGrad,
    'integrads': IntegratedGradients,
    'occlusion': Occlusion,
}


class BatchNormalize:
    def __init__(self, mean, std, device=None):
        self.mean = torch.tensor(mean, device=device)[None, :, None, None]
        self.std = torch.tensor(std, device=device)[None, :, None, None]

    def __call__(self, tensor):
        return (tensor - self.mean) / self.std


class AllowEmptyClassImageFolder(ImageFolder):
    '''Subclass of ImageFolder, which only finds non-empty classes, but with their correct indices given other empty
    classes. This counter-acts the changes in torchvision 0.10.0, in which DatasetFolder does not allow empty classes
    anymore by default. Versions before 0.10.0 do not expose `find_classes`, and thus this change does not change the
    functionality of `ImageFolder` in earlier versions.
    '''
    def find_classes(self, directory):
        with os.scandir(directory) as scanit:
            class_info = sorted((entry.name, len(list(os.scandir(entry.path)))) for entry in scanit if entry.is_dir())
        class_to_idx = {class_name: index for index, (class_name, n_members) in enumerate(class_info) if n_members}
        if not class_to_idx:
            raise FileNotFoundError(f'No non-empty classes found in \'{directory}\'.')
        return list(class_to_idx), class_to_idx

COMPOSITES.update({'upsample_epsilon_gamma_box': UpsampledEpsilonGammaBox})

model = Unet(num_channels_in,  num_channels_out)
model = model.load_from_checkpoint('results/weight/unet_ci-unet-epoch=28-val_loss=101.73.ckpt', 
                                   hparams_file='results/log/unet/version_0/hparams.yaml')

attributor_name = 'gradient'  
composite_name = 'upsample_epsilon_gamma_box'  
model_name = 'unet'  
batch_size = 16      
shuffle = False
relevance_norm = 'symmetric' 
cmap = 'coldnhot'   
level = 1.0    
seed = 21  
cpu = True

# create a composite if composite_name was set, otherwise we do not use a composite
composite = None
if composite_name is not None:
    composite_kwargs = {}
    if composite_name == 'upsample_epsilon_gamma_box':
        # the maximal input shape, needed for the ZBox rule
        shape = (batch_size, 64, 64)

        # the highest and lowest pixel values for the ZBox rule
        composite_kwargs['low'] = torch.zeros(*shape, device=device)
        composite_kwargs['high'] = torch.ones(*shape, device=device)

    # use torchvision specific canonizers, as supplied in the MODELS dict
    composite_kwargs['canonizers'] = [MODELS[model_name][1]()]

    # create a composite specified by a name; the COMPOSITES dict includes all preset composites provided by zennit.
    composite = COMPOSITES[composite_name](**composite_kwargs)

# specify some attributor-specific arguments
attributor_kwargs = {
    'smoothgrad': {'noise_level': 0.1, 'n_iter': 20},
    'integrads': {'n_iter': 20},
    'occlusion': {'window': (56, 56), 'stride': (28, 28)},
}.get(attributor_name, {})

# create an attributor, given the ATTRIBUTORS dict given above. If composite is None, the gradient will not be
# modified for the attribution
attributor = ATTRIBUTORS[attributor_name](model, composite, **attributor_kwargs)

# the current sample index for creating file names
sample_index = 0

# the accuracy
accuracy = 0.

# enter the attributor context outside the data loader loop, such that its canonizers and hooks do not need to be
# registered and removed for each step. This registers the composite (and applies the canonizer) to the model
# within the with-statement
with attributor:
    for data, target in valid_loader:
        # we use data without the normalization applied for visualization, and with the normalization applied as
        # the model input

        output_relevance = torch.zeros_like(torch.squeeze(target))

        # this will compute the modified gradient of model, with the on
        output, relevance = attributor(data.cuda(), output_relevance.cuda())

        # sum over the color channel for visualization
        relevance = np.array(relevance.sum(1).detach().cpu())

        # normalize between 0. and 1. given the specified strategy
        if relevance_norm == 'symmetric':
            # 0-aligned symmetric relevance, negative and positive can be compared, the original 0. becomes 0.5
            amax = np.abs(relevance).max((1, 2), keepdims=True)
            relevance = (relevance + amax) / 2 / amax
        elif relevance_norm == 'absolute':
            # 0-aligned absolute relevance, only the amplitude of relevance matters, the original 0. becomes 0.
            relevance = np.abs(relevance)
            relevance /= relevance.max((1, 2), keepdims=True)
        elif relevance_norm == 'unaligned':
            # do not align, the original minimum value becomes 0., the original maximum becomes 1.
            rmin = relevance.min((1, 2), keepdims=True)
            rmax = relevance.max((1, 2), keepdims=True)
            relevance = (relevance - rmin) / (rmax - rmin)

        for n in range(len(data)):
            fname = relevance_format.format(sample=sample_index + n)
            # zennit.image.imsave will create an appropriate heatmap given a cmap specification
            imsave(fname, relevance[n], vmin=0., vmax=1., level=level, cmap=cmap)
            if input_format is not None:
                fname = input_format.format(sample=sample_index + n)
                # if there are 3 color channels, imsave will not create a heatmap, but instead save the image with
                # its appropriate colors
                imsave(fname, data[n])
        sample_index += len(data)

        # update the accuracy
        accuracy += (output.argmax(1) == target).sum().detach().cpu().item()

accuracy /= len(dataset)
print(f'Accuracy: {accuracy:.2f}')

There may be a problem when passing tuples instead of tensors between layers, is this the case for you?

Are you saying that I should supply the tensor sizes for each layer of the UNet model as a tuple to the attributor?

@maxdreyer Yes, but I'm sure output_relevance has the same shape as output. Could the error instead be raised by intermediate layers, such as Upsample?

Are you saying that I should supply the tensor sizes for each layer of the UNet model as a tuple to the attributor?

No, I was only asking whether each layer in your UNet produces a single output, or whether some produce multiple outputs.

Anyway, beyond the sanity check that torch.zeros_like(torch.squeeze(target)) and output must have the same shape, I can only guess the problem you are having without knowing the precise code of your UNet and shape of your dataset. Try running pdb and see if all the shapes are correct. The problem you are having seems to be related to the batch-dimension getting consumed somewhere.
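As an alternative to stepping through with pdb, forward hooks can log the output shape of every layer during a single forward pass. This makes it easy to spot where a dimension disappears. The helper name `record_output_shapes` and the toy model are illustrative assumptions; `named_modules` and `register_forward_hook` are standard PyTorch APIs. Only leaf modules are hooked, and the forward pass has to complete for the shapes to be returned.

```python
import torch
from torch import nn


def record_output_shapes(model, *inputs):
    """Run one forward pass and return (module name, output shape) for every leaf module."""
    shapes = []
    handles = []
    for name, module in model.named_modules():
        if len(list(module.children())) == 0:  # only leaf modules
            def hook(mod, args, output, name=name):
                if isinstance(output, torch.Tensor):
                    shapes.append((name, tuple(output.shape)))
                else:  # e.g. a layer returning a tuple instead of a single tensor
                    shapes.append((name, type(output).__name__))
            handles.append(module.register_forward_hook(hook))
    try:
        with torch.no_grad():
            model(*inputs)
    finally:
        for handle in handles:
            handle.remove()
    return shapes


toy_model = nn.Sequential(nn.Conv2d(3, 4, 3, padding=1), nn.ReLU(), nn.Conv2d(4, 1, 1))
shapes = record_output_shapes(toy_model, torch.randn(2, 3, 8, 8))
for name, shape in shapes:
    print(name, shape)  # a layer whose output lost the batch dimension would stand out here
```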

@chr5tphr The UNet model has a single output with shape (batch_size, 1, width, height), where width and height are both 64. All of the code and data, including the UNet model and test data, have now been uploaded to Colab. I have tried to check it according to your suggestions. Could you help me check it? Thanks a lot!

@chr5tphr @maxdreyer You were both right! The UNet model requires a 5-D input and returns a 3-D output, which caused the problem above. I have fixed it by changing both the input and the output to 4-D, but the results are very strange. The relevance of the input channels does not seem to be consistent with the model output.


Could this be caused by the BatchNorm layers, or by something else?
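The change to 4-D tensors described above could look roughly like the following sketch. It assumes the 5-D input is laid out as `(batch, frames, channels, height, width)` and the 3-D output as `(batch, height, width)`; the actual layout in the Colab code is not shown in this thread. Merging frames into channels gives the usual `(batch, channels, height, width)` shape that `nn.Conv2d` expects. The first convolution then has to accept `frames * channels` input channels. The helper name `frames_to_channels` is hypothetical.

```python
import torch


def frames_to_channels(x):
    """(batch, frames, channels, H, W) -> (batch, frames * channels, H, W); frame f, channel c -> f * channels + c."""
    return x.flatten(1, 2)


batch, frames, channels, height, width = 2, 3, 4, 8, 8
x5 = torch.randn(batch, frames, channels, height, width)
x4 = frames_to_channels(x5)
print(x4.shape)  # torch.Size([2, 12, 8, 8])

# a 3-D output of shape (batch, H, W) gets an explicit channel dimension back
out3 = torch.randn(batch, height, width)
out4 = out3.unsqueeze(1)
print(out4.shape)  # torch.Size([2, 1, 8, 8])
```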

Do you only propagate zeros back? You should try it the way @maxdreyer suggested, but since you only have a single class, I would simply use the full output (or the ground truth), i.e.

output_relevance = target.to(device)
# or
output_relevance = model(data.to(device))

Otherwise, if this still produces unexpected results, maybe try epsilon_alpha2_beta1_flat or epsilon_plus and see how those behave.

@chr5tphr The result above was obtained the way @maxdreyer suggested. I will try the epsilon_alpha2_beta1_flat and epsilon_plus rules. Thanks a lot!

@maxdreyer May I ask which rules you used? Was the result as expected?

@bugsuse using the rule epsilon_plus_flat I received a heatmap that was similar to the output mask.

@maxdreyer Thanks so much! I will test it!

I tried it with the epsilon_plus_flat rule, but the result is strange: the relevance always seems to be concentrated at the center and does not correspond to the input. What could be the cause?

I do not really know anything about the data, but this could mean that the prediction for the positive mask values is very local. Alternatively, you could try setting all pixels of the output_relevance to 1 for the target channel.
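Setting all pixels of the target channel to 1 could be done as in the following sketch; the helper name is hypothetical. Compared with weighting by the output, as suggested earlier, every pixel of the channel now starts with the same relevance regardless of its output value.

```python
import torch


def uniform_channel_relevance(output, class_index):
    """Relevance of 1 for every pixel of channel `class_index`, 0 everywhere else."""
    output_relevance = torch.zeros_like(output.detach())
    output_relevance[:, class_index] = 1.
    return output_relevance


output = torch.randn(2, 1, 64, 64)  # single-class output shape as described in this thread
relevance = uniform_channel_relevance(output, class_index=0)
```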

Hi,
@bugsuse I see that you use the SequentialMergeBatchNorm canonizer. As far as I understand, this canonizer only handles BatchNorm layers. Does your UNet implementation contain torch.add operations?

@chr5tphr For torch.cat, we do not need a canonizer, right?

Best

@rachtibat Have a look at the supplied code in Colab: the UNet implementation includes BatchNorms, and the SequentialMergeBatchNorm is applied correctly (the BatchNorms in DoubleConv are assigned in the same order they are called, so they will be detected correctly).

No, concatenation does not need a canonizer.

But I just noticed that in the Down module

from torch import nn


class Down(nn.Module):
    '''Downscaling with conv(stride = 2) then double conv'''
    def __init__(self, in_channels, out_channels):
        super().__init__()
        # the strided 3x3 convolution halves the spatial size; DoubleConv is defined elsewhere in this UNet
        self.pool_conv = nn.Sequential(
            nn.Conv2d(in_channels=in_channels, out_channels=in_channels, kernel_size=3, stride=2, padding=1),
            DoubleConv(in_channels, out_channels)
        )

there are two consecutive linear modules (three including the BatchNorm at the beginning of DoubleConv), which technically would need to be merged, like the BatchNorm, to obtain a canonical form that is independent of the implementation.
However, I am not sure whether this would have a significant effect.
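For fully connected layers, merging two consecutive linear modules is exact: `W2 @ (W1 @ x + b1) + b2` equals a single layer with weight `W2 @ W1` and bias `W2 @ b1 + b2`. The sketch below covers only this `nn.Linear` case. It uses plain PyTorch rather than a Zennit canonizer, and `merge_linear` is a hypothetical helper. For the strided 3x3 `Conv2d` followed by another 3x3 `Conv2d` in `Down`, the composed kernel is larger. The zero padding of the intermediate feature map also means a single convolution matches only away from the borders, so a real canonizer would be more involved.

```python
import torch
from torch import nn


def merge_linear(first: nn.Linear, second: nn.Linear) -> nn.Linear:
    """Fuse second(first(x)) into one nn.Linear (exact for fully connected layers)."""
    merged = nn.Linear(first.in_features, second.out_features, bias=True)
    factory = {'dtype': first.weight.dtype, 'device': first.weight.device}
    with torch.no_grad():
        # W = W2 @ W1
        merged.weight.copy_(second.weight @ first.weight)
        # b = W2 @ b1 + b2, treating missing biases as zero
        b1 = first.bias if first.bias is not None else torch.zeros(first.out_features, **factory)
        b2 = second.bias if second.bias is not None else torch.zeros(second.out_features, **factory)
        merged.bias.copy_(second.weight @ b1 + b2)
    return merged


torch.manual_seed(0)
first, second = nn.Linear(6, 4), nn.Linear(4, 3)
merged = merge_linear(first, second)
x = torch.randn(5, 6)
print(torch.allclose(merged(x), second(first(x)), atol=1e-5))  # True
```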
Another alternative would be to use the epsilon rule in those layers.
@maxdreyer Can you check your UNet architecture and see whether you have something similar?
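For reference, the epsilon rule redistributes the relevance R_j of each output z_j = sum_i x_i w_ji + b_j onto the inputs as R_i = x_i * sum_j w_ji R_j / (z_j + epsilon * sign(z_j)). The following is a minimal plain-PyTorch sketch for a single fully connected layer. It is not Zennit's implementation (Zennit's `Epsilon` rule works through its own hooks), and the function name is hypothetical. With zero bias and a tiny epsilon, the relevance is approximately conserved per sample.

```python
import torch


def epsilon_lrp_linear(x, weight, bias, relevance_out, epsilon=1e-6):
    """Redistribute relevance_out of y = x @ weight.T + bias onto x with the epsilon rule."""
    z = x @ weight.T + bias
    # push the denominator away from zero while keeping its sign (zero is treated as positive)
    sign = torch.where(z >= 0, torch.ones_like(z), -torch.ones_like(z))
    s = relevance_out / (z + epsilon * sign)
    # c_i = sum_j w_ji * s_j, then R_i = x_i * c_i
    c = s @ weight
    return x * c


torch.manual_seed(0)
x = torch.randn(4, 6, dtype=torch.float64)
weight = torch.randn(3, 6, dtype=torch.float64)
bias = torch.zeros(3, dtype=torch.float64)
r_out = torch.randn(4, 3, dtype=torch.float64)
# a tiny epsilon and zero bias, so the per-sample relevance sums should (approximately) match
r_in = epsilon_lrp_linear(x, weight, bias, r_out, epsilon=1e-9)
print(r_in.sum(1), r_out.sum(1))
```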

@chr5tphr I'd like to predict future semantic segmentation maps from multiple past frames of multi-channel satellite observations.

I tried setting all pixels of the output_relevance to 1 for the target channel, but the result did not differ significantly.

Here is the complete code, pretrained weights, and test data. Hope it helps!

@chr5tphr You are right, that might be a problem! In my implementation, a nn.MaxPool2d(2) is used instead of nn.Conv2d(in_channels=in_channels, out_channels=in_channels, kernel_size=3, stride=2, padding=1).

@bugsuse Thus, another option would be to replace the Conv2d layer with a MaxPool2d layer and retrain the model. If this is too much work, one could of course also merge the two consecutive linear layers or try different rules, as @chr5tphr proposed.
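Replacing the strided convolution could look like the following sketch, based on @maxdreyer's description of using `nn.MaxPool2d(2)`. `DoubleConv` here is a hypothetical stand-in (Conv2d, BatchNorm2d, ReLU, twice), since the actual definition lives in the Colab code. The removed `Conv2d` has no counterpart for its pretrained weights, which is why retraining is needed.

```python
import torch
from torch import nn


class DoubleConv(nn.Sequential):
    """Hypothetical stand-in for the UNet's DoubleConv: (Conv2d -> BatchNorm2d -> ReLU) twice."""
    def __init__(self, in_channels, out_channels):
        super().__init__(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        )


class DownMaxPool(nn.Module):
    """Downscaling with max pooling, then double conv, instead of a strided Conv2d."""
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.pool_conv = nn.Sequential(
            nn.MaxPool2d(2),  # halves height and width without any learned weights
            DoubleConv(in_channels, out_channels),
        )

    def forward(self, x):
        return self.pool_conv(x)


down = DownMaxPool(16, 32).eval()
with torch.no_grad():
    y = down(torch.randn(2, 16, 64, 64))
print(y.shape)  # torch.Size([2, 32, 32, 32])
```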

@maxdreyer Thanks a lot! I will test it!

@bugsuse Hi, maybe you can try out the new pull request #45
with git fetch origin pull/45/head:YOUR_BRANCH_NAME

This produces much better heatmaps with LRP

Version 0.3.3 is live on PyPI; you can just update to check.

@chr5tphr @rachtibat Cool! I will try it!