SHI-Labs / Neighborhood-Attention-Transformer

Neighborhood Attention Transformer, arXiv 2022 / CVPR 2023. Dilated Neighborhood Attention Transformer, arXiv 2022.


Salient map

Yuuuumie opened this issue

Can you share the method or code you used to draw the salient map? Thanks a lot!

Hi @Yuuuumie, we'll get this integrated into the training script soon, but here's the main function to use for now. The algorithm is fairly simple: we just look at how strongly the network's prediction responds to each input pixel, i.e. the gradient of the top-class score with respect to the input. We add noise over several rounds because it makes the map more robust; it can be seen as a mild adversarial perturbation, which is beneficial for what we want and is reasonable since we train with heavy augmentations.

import torch
from torchvision.transforms.functional import to_pil_image

# These will need to change depending on the dataset you use. Best practice is to compute the
# statistics yourself, since online resources can sometimes provide bad data (you'll often find
# ImageNet values rather than ones for your specific dataset). See the sketch further down for
# one way to compute them.
IMAGENET_DEFAULT_MEAN = torch.Tensor([0.485, 0.456, 0.406])
IMAGENET_DEFAULT_STD = torch.Tensor([0.229, 0.224, 0.225])

def batch_salient(model,
                  imgs,
                  mean=IMAGENET_DEFAULT_MEAN,
                  std=IMAGENET_DEFAULT_STD,
                  rounds=100,
                  noise_std=0.1,
                  noise_mean=0,
                  ):
    # Note: mean/std are unused inside this function; they're kept in the signature so you can
    # denormalize imgs when super-imposing the map onto the original image.
    imgs.requires_grad_()
    salient = None
    for i in range(rounds + 1):
        noise = torch.randn(imgs.size()) * noise_std + noise_mean
        noise = noise.to(imgs.device)
        # Clear the gradient from the previous round so each pass contributes exactly once.
        imgs.grad = None
        # Book-keeping so we keep the clean-pass predictions for super-imposing the salient
        # map onto the original image.
        if i == 0:
            preds = model(imgs)
            preds_orig = preds.detach().clone()
        else:
            preds = model(imgs + noise)
        scores, _ = torch.max(preds, dim=1)
        scores.backward(torch.ones_like(scores))
        # Per-pixel saliency: the strongest gradient across the channel dimension.
        if salient is None:
            salient = torch.max(imgs.grad.data, dim=1)[0]
        else:
            salient += torch.max(imgs.grad.data, dim=1)[0]
    # This next step is optional and just averages the accumulated maps
    # (the loop runs rounds + 1 times: one clean pass plus `rounds` noisy ones).
    salient /= rounds + 1
    salient.relu_()
    salients = [to_pil_image(s.cpu()).convert("RGB") for s in salient]
    return preds_orig, salients
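
For reference, here's a minimal usage sketch. The torchvision ResNet and the image path are stand-ins, not part of our setup; any classifier and any normalized NCHW batch will do:

from PIL import Image
from torchvision import models, transforms

transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD),
])
# "example.jpg" is a placeholder path; unsqueeze adds the batch dimension.
img = transform(Image.open("example.jpg").convert("RGB")).unsqueeze(0)

model = models.resnet18(pretrained=True)  # stand-in; use your own trained model
model.eval()
preds, salients = batch_salient(model, img)
salients[0].save("salient_map.png")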

I'll close this issue once we add the code to the repo, but I hope this is useful for now.
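
One more thing: since the comment in the code recommends computing the mean/std yourself, here's a rough sketch of one way to do that, assuming an ImageFolder-style dataset (the dataset path and loader settings are placeholders):

import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# "path/to/train" is a placeholder; the resize/crop should match your training pipeline.
dataset = datasets.ImageFolder("path/to/train", transform=transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
]))
loader = DataLoader(dataset, batch_size=64, num_workers=4)

n_pixels = 0
channel_sum = torch.zeros(3)
channel_sq_sum = torch.zeros(3)
for imgs, _ in loader:
    # imgs is (B, 3, H, W); accumulate sums over the batch and spatial dimensions.
    n_pixels += imgs.numel() / 3
    channel_sum += imgs.sum(dim=(0, 2, 3))
    channel_sq_sum += (imgs ** 2).sum(dim=(0, 2, 3))

mean = channel_sum / n_pixels
std = (channel_sq_sum / n_pixels - mean ** 2).sqrt()
print(mean, std)

Plug the resulting values in for mean/std wherever you normalize, in place of the ImageNet defaults above.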

Thank you for your reply! I will try it.