Why are an absolute value and a negative sign applied to a random variable before calculating the likelihood? Thank you very much for your answer!
zhengxinChenee opened this issue
As stated, the likelihood of a quantized latent variable can be calculated as CDF((values + half) / scales) - CDF((values - half) / scales). In CompressAI, however, the absolute value and a negative sign are applied to the quantized latent variable before the CDF is used to calculate the likelihood, i.e.,
values = torch.abs(values)
upper = self._standardized_cumulative((half - values) / scales)
lower = self._standardized_cumulative((-half - values) / scales)
likelihood = upper - lower
How should I understand this different implementation in CompressAI?
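A quick numerical check suggests the two forms give the same probabilities. This is a sketch, not CompressAI code: it assumes `values` are already centred (means subtracted), so the density is a zero-mean Gaussian; it uses Python's `math.erfc` in place of `torch.erfc`; and the helper names `textbook_likelihood` and `compressai_style_likelihood` are hypothetical.

```python
import math


def standardized_cumulative(x: float) -> float:
    # Standard normal CDF, written as 0.5 * erfc(-x / sqrt(2)).
    # math.erfc stands in for torch.erfc so the sketch has no dependencies.
    return 0.5 * math.erfc(-x / math.sqrt(2))


def textbook_likelihood(value: float, scale: float, half: float = 0.5) -> float:
    # Mass of N(0, scale^2) on [value - half, value + half].
    return standardized_cumulative((value + half) / scale) - standardized_cumulative(
        (value - half) / scale
    )


def compressai_style_likelihood(value: float, scale: float, half: float = 0.5) -> float:
    # Mirrors the four lines quoted above: take |value|, then evaluate the CDF
    # at (half - |value|) / scale and (-half - |value|) / scale.
    value = abs(value)
    upper = standardized_cumulative((half - value) / scale)
    lower = standardized_cumulative((-half - value) / scale)
    return upper - lower


for v in [-3.0, -1.0, 0.0, 2.0, 4.0]:
    a = textbook_likelihood(v, scale=1.5)
    b = compressai_style_likelihood(v, scale=1.5)
    print(f"value={v:+.1f}  textbook={a:.12f}  compressai_style={b:.12f}")
```

Negative values are included on purpose: the absolute value only works because a zero-mean Gaussian puts the same mass on the bin around v as on the bin around -v.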
Proof by symmetry
Define:
import torch
from math import sqrt
def _standardized_cumulative(x):
    return 0.5 * torch.erfc(-x / sqrt(2))
_standardized_cumulative has various symmetries. Notably:
_standardized_cumulative(x) == 1 - _standardized_cumulative(-x)
Thus:
upper
== _standardized_cumulative((0.5 - samples) / scale)
== 1 - _standardized_cumulative((samples - 0.5) / scale)
lower
== _standardized_cumulative((-0.5 - samples) / scale)
== 1 - _standardized_cumulative((samples + 0.5) / scale)
upper - lower
== _standardized_cumulative((samples + 0.5) / scale) - _standardized_cumulative((samples - 0.5) / scale)
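Since samples = abs(values) and the zero-mean Gaussian is symmetric about 0, the mass on [samples - 0.5, samples + 0.5] equals the mass on [values - 0.5, values + 0.5]. This is the correct result.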
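The proof shows the two forms are equal in exact arithmetic, but not why CompressAI prefers the negated one. A likely motivation (an assumption here; the thread does not say so) is floating-point precision: the CDF saturates to 1 in the right tail, so subtracting two values close to 1 can lose a bin's mass entirely, while the abs-and-negate form subtracts two small numbers in the left tail, which are still representable. This sketch uses float64 via `math.erfc`; in float32 the same problem appears at much smaller values.

```python
import math


def standardized_cumulative(x: float) -> float:
    # Standard normal CDF via erfc (math.erfc in place of torch.erfc).
    return 0.5 * math.erfc(-x / math.sqrt(2))


value, scale, half = 10.0, 1.0, 0.5

# Right-tail form: both CDF values round to exactly 1.0 in float64,
# so their difference collapses to 0.
right_tail = standardized_cumulative((value + half) / scale) - standardized_cumulative(
    (value - half) / scale
)

# Left-tail form (abs + negation): both CDF values are tiny but representable,
# so the bin's small probability survives.
left_tail = standardized_cumulative((half - abs(value)) / scale) - standardized_cumulative(
    (-half - abs(value)) / scale
)

print(right_tail)  # 0.0
print(left_tail)   # roughly 1e-21
```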
The same proof holds for _logits_cumulative, by the way.
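For the logistic case, the identity the proof relies on is sigmoid(x) == 1 - sigmoid(-x). A minimal sketch of that identity, with made-up logit values standing in for what a learned cumulative might output at the two bin edges (this is not CompressAI's `_logits_cumulative` itself). Note that it negates the logits rather than taking `abs` of the input: dropping the input's sign would additionally require a density symmetric about zero, which a learned density need not be.

```python
import math


def sigmoid(x: float) -> float:
    # Logistic function: turns a logit into a CDF value.
    return 1.0 / (1.0 + math.exp(-x))


# Hypothetical logits at the lower and upper bin edges.
lower_logit, upper_logit = 1.3, 2.1

# Direct difference of the two CDF values.
direct = sigmoid(upper_logit) - sigmoid(lower_logit)

# Because sigmoid(x) == 1 - sigmoid(-x), the same mass can be computed from
# the negated logits, with the roles of the two edges swapped.
negated = sigmoid(-lower_logit) - sigmoid(-upper_logit)

print(direct, negated)
```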
That said, it looks like the names are actually backwards, which makes this a bit confusing (i.e., upper and lower should be swapped). Maybe some renaming and comments are needed. I already fixed _logits_cumulative or sigmoid earlier, since sigmoid said something similar concerning symmetries (#182).
EDIT: Actually, that proof was irrelevant, since...
The names are not backwards, since:
multiplier = -self._standardized_quantile(self.tail_mass / 2)
pmf_center = torch.ceil(self.scale_table * multiplier).int()
samples = torch.abs(
    torch.arange(max_length, device=device).int() - pmf_center[:, None]
)
assert self.tail_mass < 0.5  # Usually 1e-9
assert multiplier > 0  # About 6.1 when tail_mass = 1e-9
assert (self.scale_table > 0).all()
assert (pmf_center > 0).all()
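And so samples is essentially "reversed": left of pmf_center, -samples equals the signed offset from the center, which we call unreversed_samples. Thus,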
upper
== _standardized_cumulative((0.5 - samples) / scale)
== _standardized_cumulative((0.5 + unreversed_samples) / scale)
lower
== _standardized_cumulative((-0.5 - samples) / scale)
== _standardized_cumulative((-0.5 + unreversed_samples) / scale)
That also matches our expectations for what upper and lower are.
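To see the "reversal" concretely, here is a pure-Python sketch of the sample grid for a single scale. The values tail_mass = 1e-9 and scale = 2.0, and max_length = 2 * pmf_center + 1, are assumptions for illustration; `statistics.NormalDist().inv_cdf` stands in for `_standardized_quantile`. The sketch also shows that the reversal holds only left of pmf_center: to the right, `abs` mirrors the offsets, and the symmetry of the Gaussian covers those bins.

```python
import math
from statistics import NormalDist


def standardized_cumulative(x: float) -> float:
    # Standard normal CDF via erfc.
    return 0.5 * math.erfc(-x / math.sqrt(2))


# Assumed example values: one scale instead of a whole scale table.
tail_mass = 1e-9
scale = 2.0

# multiplier = -Phi^{-1}(tail_mass / 2): about 6.1 standard deviations here.
multiplier = -NormalDist().inv_cdf(tail_mass / 2)
pmf_center = math.ceil(scale * multiplier)
max_length = 2 * pmf_center + 1  # assumption: grid centred on pmf_center

# samples = |i - pmf_center| counts down to 0 at the centre, then back up.
samples = [abs(i - pmf_center) for i in range(max_length)]

# Left of the centre (i <= pmf_center), -samples is exactly the signed
# offset i - pmf_center, i.e. the "unreversed" sample.
unreversed = [i - pmf_center for i in range(pmf_center + 1)]
assert [-s for s in samples[: pmf_center + 1]] == unreversed

# Right of the centre each bin is evaluated as its mirror image;
# the Gaussian's symmetry gives it the same mass.
pmf = [
    standardized_cumulative((0.5 - s) / scale) - standardized_cumulative((-0.5 - s) / scale)
    for s in samples
]
print(pmf_center, len(pmf), sum(pmf))
```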
Still, it's a bit tricky to prove in one's head. Perhaps it should be rewritten so that -samples → samples, for clarity.
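One way the suggested rewrite could look, as a hypothetical pure-Python sketch (`pmf_current` and `pmf_rewritten` are illustrative names, not CompressAI functions, and the toy grid is assumed): negate samples once up front, then use the familiar-looking CDF(x + 0.5) - CDF(x - 0.5). The offsets stay non-positive, so both CDF evaluations remain in the left tail.

```python
import math


def standardized_cumulative(x: float) -> float:
    # Standard normal CDF via erfc.
    return 0.5 * math.erfc(-x / math.sqrt(2))


def pmf_current(samples, scale):
    # Current form: samples are non-negative distances |i - pmf_center|.
    return [
        standardized_cumulative((0.5 - s) / scale) - standardized_cumulative((-0.5 - s) / scale)
        for s in samples
    ]


def pmf_rewritten(samples, scale):
    # Hypothetical rewrite: negate once, so upper uses x + 0.5 and lower uses
    # x - 0.5. The offsets are still <= 0, so the left tail is still used.
    offsets = [-s for s in samples]
    return [
        standardized_cumulative((x + 0.5) / scale) - standardized_cumulative((x - 0.5) / scale)
        for x in offsets
    ]


samples = [abs(i - 5) for i in range(11)]  # toy grid with pmf_center = 5
current = pmf_current(samples, 1.5)
rewritten = pmf_rewritten(samples, 1.5)
print(all(math.isclose(a, b) for a, b in zip(current, rewritten)))
```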