Regarding JointAutoRegressiveHierarchicalPriors implementation (mbt2018)
Indraa145 opened this issue
Hello, I have a question regarding the implementation of JointAutoregressiveHierarchicalPriors. I noticed something that looks off in the forward() function:
```python
y_hat = self.gaussian_conditional.quantize(  # Why do you use this as the input of self.g_s,
    y, "noise" if self.training else "dequantize"
)
ctx_params = self.context_prediction(y_hat)
gaussian_params = self.entropy_parameters(
    torch.cat((params, ctx_params), dim=1)
)
scales_hat, means_hat = gaussian_params.chunk(2, 1)
_, y_likelihoods = self.gaussian_conditional(y, scales_hat, means=means_hat)  # instead of this one? (its first return value is discarded as "_")
x_hat = self.g_s(y_hat)
```
Why do you use the y_hat quantized without means_hat as the input to g_s, instead of the y_hat quantized with means_hat (whose value you discard as "_")?
Inside self.gaussian_conditional's forward() function:

```python
outputs = self.quantize(inputs, "noise" if training else "dequantize", means)
```
Inside the self.quantize function:

```python
outputs = inputs.clone()
if means is not None:
    outputs -= means
```
Wouldn't the y_hat that should be fed to self.g_s then be different? The quantized y_hat you actually use is never shifted by means_hat, as the code above shows, even though the y_likelihoods are computed with means_hat.
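To make the discrepancy concrete, here is a minimal plain-Python illustration with hypothetical values (the real code operates on tensors, and in "dequantize" mode the full quantize() adds the means back after rounding):

```python
def dequant(y, mean=None):
    # mirrors the "dequantize" path of quantize(): subtract the mean,
    # round, then add the mean back; with no mean it is plain rounding
    if mean is None:
        return float(round(y))
    return mean + round(y - mean)

y, mean_hat = 1.3, 0.5          # hypothetical latent value and predicted mean
print(dequant(y))               # 1.0 -> what g_s currently receives at eval
print(dequant(y, mean_hat))     # 1.5 -> reconstruction consistent with y_likelihoods
```

So at inference the two reconstructions can differ by up to half a quantization bin whenever means_hat is not (close to) an integer.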
That's what I can gather from reading that part of the code; sorry if I misunderstood something. Thank you.
1. During training, both methods use noise-based quantization and ignore the means parameter.
2. During validation/inference, the first method indeed differs from the second.

Due to (2), it might be a good idea to fix this for more precise validation/inference.
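One way to realize that fix, sketched here with plain Python floats in place of tensors (the stub class and function names are hypothetical stand-ins, not CompressAI's API): once means_hat is available, re-quantize y with the means at eval time before calling g_s. A caveat: the context model would still have consumed the means-free y_hat earlier in forward(), so a complete fix may need to address that input as well.

```python
class GaussianConditionalStub:
    """Hypothetical stand-in for GaussianConditional; works on single floats."""

    def quantize(self, y, mode, means=None):
        # "dequantize" mode: round, optionally in the mean-shifted domain
        # (the real method also supports a "noise" mode used during training)
        assert mode == "dequantize"
        if means is None:
            return float(round(y))
        return means + round(y - means)


def reconstruct_y_hat(y, means_hat, training, gc):
    # current eval path: quantize without means
    y_hat = gc.quantize(y, "dequantize")
    if not training:
        # proposed fix: rebuild y_hat with the predicted means so the input
        # to g_s matches the reconstruction implied by y_likelihoods
        y_hat = gc.quantize(y, "dequantize", means=means_hat)
    return y_hat


gc = GaussianConditionalStub()
print(reconstruct_y_hat(1.3, 0.5, training=False, gc=gc))  # 1.5 instead of 1.0
```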