Compute pixel-wise bitrate allocation in latent space
danishnazir opened this issue
Hi,
Is there a way to compute per-pixel bitrate info in RGB space of the latent variable y or z during evaluation?
Regards
The likelihood map
In bmshj2018-factorized, each
I guess a more advanced method for estimating rate costs of encoding each pixel would involve training some sort of rate estimation model or maybe some Grad-CAM-like technique.
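As one possible reading of the Grad-CAM-like idea, the sketch below backpropagates a total rate estimate to the input and uses the gradient magnitude as a per-pixel sensitivity map. Everything in it is a stand-in: `encoder` is a hypothetical two-layer replacement for an analysis transform, and `toy_bits` is a made-up unit-Gaussian rate proxy rather than a learned entropy model. The result shows which pixels the rate is most sensitive to, which is not the same as the number of bits spent on them.

```python
import math

import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-in for an analysis transform such as model.g_a:
# two stride-2 convolutions, so the latent is 4x smaller than the input.
encoder = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=5, stride=2, padding=2),
    nn.ReLU(),
    nn.Conv2d(8, 8, kernel_size=5, stride=2, padding=2),
)


def toy_bits(y):
    # Hypothetical rate proxy: negative log of a unit Gaussian density at y,
    # converted from nats to bits. A real codec would use its entropy model.
    nats = 0.5 * y**2 + 0.5 * math.log(2 * math.pi)
    return nats / math.log(2)


x = torch.rand(1, 3, 32, 32, requires_grad=True)
total_bits = toy_bits(encoder(x)).sum()
total_bits.backward()

# Sensitivity of the total rate to each input pixel, summed over RGB.
saliency = x.grad.abs().sum(dim=1)  # shape (1, 32, 32)
```

With a real model, `toy_bits(encoder(x))` would presumably be replaced by the negative log2 of the likelihoods returned by the entropy model.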
Thank you for your answer. I tried to do it in the following way. Could you please point out any mistakes?
model.eval()  # model = hyperprior
model.update()
y = model.g_a(x)
y_hat, likelihood_map = model.entropy_bottleneck(y)
likelihood_map = likelihood_map[0].detach().cpu().numpy()
pixel_bitrates = likelihood_map.sum(dim=1) #channel-wise sum
Now should I just upsample pixel_bitrates to the size of x using an interpolation technique such as bilinear?
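A few details in the snippet above are worth checking before any upsampling. Likelihoods are probabilities, so they need to go through `-log2` to become bits before being summed. The channel sum with `dim=1` only works on a torch tensor that still has its batch dimension (NumPy uses `axis`, and after `[0]` the channel axis is 0). Also, as far as I know, CompressAI's hyperprior models apply `entropy_bottleneck` to z and `gaussian_conditional` to y, so the call on y fits the factorized model better. A minimal sketch of the bookkeeping, using a random stand-in tensor instead of real model output:

```python
import torch

torch.manual_seed(0)

# Stand-in for the likelihoods returned by an entropy model:
# shape (batch, channels, height, width), values in (0, 1].
likelihood_map = torch.rand(1, 4, 6, 8).clamp_min(1e-9)

# Bits per latent element, then sum over the channel axis (dim=1)
# while the batch dimension is still present.
bits = -torch.log2(likelihood_map)
latent_bits = bits.sum(dim=1)  # shape (1, 6, 8)

# Convert to NumPy only at the end, e.g. for plotting.
latent_bits_np = latent_bits[0].detach().cpu().numpy()
```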
bmshj2018-factorized NLLs (negative log likelihoods):
import matplotlib.pyplot as plt
import torch.nn.functional as F
from compressai.zoo import bmshj2018_factorized
from PIL import Image
from torchvision import transforms
device = "cuda"
for quality in [1, 2, 3, 4, 5, 6, 7, 8]:
    model = bmshj2018_factorized(quality=quality, pretrained=True).eval().to(device)
    img = Image.open("/data/datasets/kodak/test/kodim01.png").convert("RGB")
    x = transforms.ToTensor()(img).unsqueeze(0).to(device)
    _, _, H, W = x.shape
    y = model.g_a(x)
    y_hat, y_likelihoods = model.entropy_bottleneck(y)
    # Negative log likelihood in bits for each latent element.
    y_nll = -y_likelihoods.log2()
    # Spatial downsampling factor between the image and the latent.
    scale_factor = x.shape[-1] / y.shape[-1]
    # Sum bits over channels, spread them over the covered pixels, then upsample.
    x_nll = F.interpolate(
        y_nll.sum(dim=1, keepdim=True) / scale_factor**2,
        scale_factor=scale_factor,
        mode="bilinear",
        align_corners=False,
    )
    # Plot the per-pixel map and save one image per quality level.
    fig, ax = plt.subplots(tight_layout=True)
    im = ax.imshow(x_nll[0, 0].detach().cpu().numpy(), vmin=0)
    fig.colorbar(im)
    ax.set(title=f"bmshj2018-factorized q={quality}")
    fig.savefig(f"x_nll_q={quality}.png")
    plt.close(fig)
Just for fun, to animate:
ffmpeg -framerate 1 -pattern_type glob -i 'x_nll_q*.png' -f apng -plays 0 x_nll_all.png
[Figure: input image (left) and its map of negative log likelihoods in bits (right).]
Low frequency regions consume the least rate.
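The division by `scale_factor**2` in the script spreads each latent position's bits over the pixels it covers, so the upsampled map should still sum to the total estimated bits. The sketch below checks this on a random stand-in tensor (`latent_bits` is hypothetical and plays the role of `y_nll.sum(dim=1, keepdim=True)`) and derives bits per pixel from the map. It assumes the image size is an exact multiple of the latent size.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Stand-in for y_nll.sum(dim=1, keepdim=True): bits per latent position.
latent_bits = torch.rand(1, 1, 32, 48) * 10
scale_factor = 16  # image side / latent side (assumed for this sketch)

# Same normalization and upsampling as in the script above.
pixel_bits = F.interpolate(
    latent_bits / scale_factor**2,
    scale_factor=scale_factor,
    mode="bilinear",
    align_corners=False,
)

# The two totals should agree up to floating-point error.
total_latent = latent_bits.sum().item()
total_pixel = pixel_bits.sum().item()

# Bits per pixel over the whole image.
num_pixels = pixel_bits.shape[-2] * pixel_bits.shape[-1]
bpp = total_pixel / num_pixels
```

If the two totals drift apart on real data, the interpolation settings or a non-integer scale factor would be the first things to check.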