royerlab / cytoself

Self-supervised models for encoding protein localization patterns from microscopy images

Variance normalization

sofroniewn opened this issue

I'm trying to understand the variance normalization that happens here:

```python
self.model.mse_loss['reconstruction1_loss'] = mse_loss_fn(model_outputs[0], img) / variance
```

I can't seem to find a similar variance normalization discussed in the hierarchical VQ-VAE paper or other older papers, and it's not clear to me why normalization by a single scalar is needed here.
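What that line computes can be sketched in a few lines. This is a NumPy illustration, not the cytoself code: `normalized_mse` is a hypothetical helper, and it assumes `mse_loss_fn` is a mean-reduced MSE (as `nn.MSELoss()` is by default) and that `variance` is one scalar computed over the image data. It also shows the one property a single scalar does buy: rescaling the data by any constant leaves the normalized loss unchanged.

```python
import numpy as np


def normalized_mse(reconstruction: np.ndarray, target: np.ndarray, variance: float) -> float:
    """Mean squared error divided by a scalar data variance (hypothetical helper)."""
    return float(np.mean((reconstruction - target) ** 2) / variance)


rng = np.random.default_rng(0)
img = rng.random((4, 2, 100, 100))                   # toy batch, shape (N, C, H, W)
recon = img + 0.05 * rng.standard_normal(img.shape)  # imperfect reconstruction
variance = float(img.var())                          # one scalar for the whole dataset

loss = normalized_mse(recon, img, variance)

# Rescaling images and reconstructions by c multiplies the MSE by c**2 and the
# variance by c**2, so the normalized loss does not change.
c = 10.0
loss_scaled = normalized_mse(c * recon, c * img, float((c * img).var()))
print(loss, loss_scaled)
```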

It's also not clear to me why you only normalize reconstruction1_loss, while reconstruction2_loss, calculated here, is left unnormalized:

```python
self.mse_loss[f'reconstruction{len(self.decoders) - i}_loss'] = nn.MSELoss()(
```
Wouldn't it make sense to normalize both of them, given that they get summed into the same loss?

Also, do you think it's really important to calculate the variance for train/val/test separately? Could one just use the training-set variance for all of them? The difference doesn't seem to be large, and it makes things a little easier.

Any help here would be appreciated - thanks!!

Also, the variance value I get is ~285, which seems quite high relative to the other loss values and means that reconstruction1_loss has very little effect.

Oops, my mistake: my variance is actually 0.01677, which matches the value in your example, so the reconstruction loss won't be ignored. My initial questions still stand, though.
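Since the MSE is divided by the variance, the variance acts as an inverse weight on the reconstruction term. A quick check of the two values above (plain arithmetic, nothing cytoself-specific):

```python
# Effective weight on the raw MSE when it is divided by the variance.
weights = {variance: 1.0 / variance for variance in (285.0, 0.01677)}
for variance, weight in weights.items():
    print(f"variance={variance:g} -> weight on MSE = {weight:.4g}")
```

So a variance of ~285 would scale the reconstruction term down by roughly 0.0035, while 0.01677 scales it up by roughly 60.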

> I can't seem to find a similar variance normalization discussed in the hierarchical VQ-VAE paper or other older papers, and it's not clear to me why normalization by a single scalar is needed here.

I couldn't find a good explanation online either, although I think I read something about it early in this project. I may not have thought too much about why and simply followed the original implementation.

Now, looking back at the original VQ-VAE paper (particularly Fig. 1), I feel the purpose of variance normalization is to facilitate the optimization of the red vector in Fig. 1. I guess it might have been difficult to balance the reconstruction loss against the other losses without variance normalization (although it's not impossible to just let the model figure out the best way). If that's the case, the exact value of the variance shouldn't matter as long as the optimization proceeds smoothly.
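For reference, the objective from the original VQ-VAE paper can be written with its reconstruction term as the variance-normalized MSE used here. In this sketch $x$ is the input image, $\hat{x}$ the decoder output, $N$ the number of pixels, $\sigma^2$ the scalar data variance, $z_e(x)$ the encoder output, $e$ the nearest codebook vector, $\mathrm{sg}[\cdot]$ the stop-gradient operator and $\beta$ the commitment weight. The $1/\sigma^2$ factor is the cytoself choice under discussion, not part of the paper's equation:

$$
\mathcal{L} = \frac{1}{\sigma^2}\cdot\frac{1}{N}\lVert x - \hat{x} \rVert_2^2
+ \lVert \mathrm{sg}[z_e(x)] - e \rVert_2^2
+ \beta\,\lVert z_e(x) - \mathrm{sg}[e] \rVert_2^2
$$

The encoder receives gradients from the reconstruction term (copied across the quantizer by the straight-through estimator, the red arrow in Fig. 1) and from the commitment term, so $\sigma^2$ effectively sets the balance between those two signals. One possible reading of the factor (an interpretation, not something stated in this thread): for a Gaussian decoder with fixed variance $\sigma^2$,

$$
-\log p(x \mid z_q(x)) = \frac{1}{2\sigma^2}\lVert x - \hat{x} \rVert_2^2 + \frac{N}{2}\log\left(2\pi\sigma^2\right),
$$

so the normalized MSE equals $\tfrac{2}{N}$ times this negative log-likelihood, up to an additive constant, with $\sigma^2$ set to the data variance.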

> Wouldn't it make sense to normalize both of them, given that they get summed into the same loss?

If the purpose of variance normalization is to facilitate optimization, then yes, reconstruction2_loss should be normalized in one way or another. I guess I didn't normalize it because I didn't know what I should use for the normalization (the variance of the original input data or that of the latent vector).
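If both terms were normalized, the two options mentioned above could look like the following sketch. It is hypothetical: `recon2` and `target2` stand in for the second decoder's output and whatever tensor it is trained to reconstruct, `total_reconstruction_loss` is an illustrative helper, and the NumPy MSE stands in for `nn.MSELoss()`. The toy data shows that when the second target lives on a different scale than the images, the choice of normalizer changes how strongly the second term is weighted.

```python
import numpy as np


def mse(a: np.ndarray, b: np.ndarray) -> float:
    """Mean squared error over all elements (stand-in for nn.MSELoss())."""
    return float(np.mean((a - b) ** 2))


def total_reconstruction_loss(recon1, img, recon2, target2, img_variance, mode="input"):
    """Sum of two variance-normalized reconstruction losses (hypothetical helper).

    mode="input":  divide both terms by the variance of the input images.
    mode="target": divide each term by the variance of its own target.
    """
    loss1 = mse(recon1, img) / img_variance
    if mode == "input":
        loss2 = mse(recon2, target2) / img_variance
    elif mode == "target":
        loss2 = mse(recon2, target2) / float(target2.var())
    else:
        raise ValueError(f"unknown mode: {mode!r}")
    return loss1 + loss2


# Toy tensors; target2 is a hypothetical latent-space target on a larger scale.
rng = np.random.default_rng(1)
img = rng.random((2, 1, 32, 32))
recon1 = img + 0.1 * rng.standard_normal(img.shape)
target2 = 5.0 * rng.standard_normal((2, 64, 8, 8))
recon2 = target2 + 0.5 * rng.standard_normal(target2.shape)
img_variance = float(img.var())

loss_input = total_reconstruction_loss(recon1, img, recon2, target2, img_variance, mode="input")
loss_target = total_reconstruction_loss(recon1, img, recon2, target2, img_variance, mode="target")
print(loss_input, loss_target)
```

With these toy scales, dividing the second term by the image variance makes it dominate the sum, while dividing by its own target's variance means each term measures error relative to the spread of its own target.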

> Also, do you think it's really important to calculate the variance for train/val/test separately? Could one just use the training-set variance for all of them? The difference doesn't seem to be large, and it makes things a little easier.

If my guess above is correct, then yes, we don't need to worry about the exact value of the variance. I guess in practice we just need an approximate number on the same order of magnitude as the real training variance (this may vary depending on the dataset). And we don't need to calculate the variance for val and test.
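Following the point that only the order of magnitude matters, here is a sketch that estimates one scalar variance from a random subset of the training images and reuses it for every split. The helper `estimate_variance`, the subset size and the toy data are assumptions for illustration; in a real pipeline the images would come from the training data loader.

```python
import numpy as np


def estimate_variance(train_images: np.ndarray, n_samples: int = 1000, seed: int = 0) -> float:
    """Rough scalar variance from a random subset of training images (hypothetical helper)."""
    rng = np.random.default_rng(seed)
    n = min(n_samples, len(train_images))
    idx = rng.choice(len(train_images), size=n, replace=False)  # sample whole images
    return float(train_images[idx].var())


rng = np.random.default_rng(2)
train_images = rng.random((5000, 1, 16, 16)) * 0.4  # toy training set, shape (N, C, H, W)
full_var = float(train_images.var())
rough_var = estimate_variance(train_images, n_samples=500)
print(full_var, rough_var)

# The same precomputed scalar is then used to normalize the train, val and test losses.
```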

> feel the purpose of variance normalization is to facilitate the optimization of the red vector in Fig. 1. I guess it might have been difficult to balance the reconstruction loss against the other losses without variance normalization (although it's not impossible to just let the model figure out the best way). If that's the case, the exact value of the variance shouldn't matter as long as the optimization proceeds smoothly.

OK yeah, this makes sense: it's just a nice way to get some rough scale invariance, e.g. MSE / variance.

> then yes, reconstruction2_loss should be normalized in one way or another. I guess I didn't normalize it because I didn't know what I should use for the normalization (the variance of the original input data or that of the latent vector).

Hmm, OK yeah, it might make sense to still use the variance of the original data. I might give that a try and will let you know if I see any performance differences, though I suspect they will be minor.

> I guess in practice we just need an approximate number on the same order of magnitude as the real training variance (this may vary depending on the dataset). And we don't need to calculate the variance for val and test.

Yeah, I think I will just use one precomputed rough number as it makes things a lot faster and easier.

Thanks for all the input here, I'll let you know if I see anything important as I try things out!!