x-zho14 / ProbMask-official

Implementation of Effective Sparsification of Neural Networks with Global Sparsity Constraint

BN-tuning on test-set

vanderschuea opened this issue

Hi,

I noticed lower results after fixing the following issue in your repository:
The argument use_running_stats defaults to False, and none of the available configs override it. With it set to False, the BN statistics (moving_mean, moving_var) are tuned on the test set during each validation epoch. This introduces an unwelcome bias when evaluating model performance, since it injects test-set information during training.

I suppose this was done to combat the problem of a randomly sampled weight mask at test-time inference, which leads to near-random performance. Another way to combat it would be to simply fix the last binary mask used during training and reuse it at test time, which avoids creating such a bias. I've tested this approach, and it keeps the test-set predictions from being random without using test data during training.

Assuming inference_mask is initialized with the same shape as the weights, adding the following to the layer would be one solution:

# Compute binary Mask
if self.training:
    mask = self._get_mask()
    with torch.no_grad():
        self.inference_mask[:] = mask.clone()
else:
    mask = self.inference_mask

# Convolution with masked W
......
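
To make the idea above concrete, here is a minimal, self-contained sketch of a convolution layer that caches its last training mask and reuses it in eval mode. It is an illustration under stated assumptions, not the repository's actual layer: the class name CachedMaskConv2d, the probs parameter, and the plain Bernoulli sampling inside _get_mask are hypothetical stand-ins, and the differentiable relaxation needed to train the mask probabilities is left out.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class CachedMaskConv2d(nn.Conv2d):
    """Conv2d whose weights are gated by a sampled binary mask.

    Hypothetical sketch: `probs` and this `_get_mask` are illustrative
    stand-ins, not the repository's implementation.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Keep-probability per weight (assumed stand-in for learned probabilities).
        self.probs = nn.Parameter(torch.full_like(self.weight, 0.5))
        # Buffer with the same shape as the weights; it is stored in the
        # state_dict, so the cached mask survives checkpointing.
        self.register_buffer("inference_mask", torch.ones_like(self.weight))

    def _get_mask(self):
        # Plain Bernoulli sample, without any gradient path to `probs`.
        return torch.bernoulli(self.probs.detach().clamp(0.0, 1.0))

    def forward(self, x):
        # Compute binary mask
        if self.training:
            mask = self._get_mask()
            with torch.no_grad():
                self.inference_mask.copy_(mask)  # remember the last training mask
        else:
            mask = self.inference_mask  # deterministic at test time

        # Convolution with masked W
        return F.conv2d(x, self.weight * mask, self.bias, self.stride,
                        self.padding, self.dilation, self.groups)


torch.manual_seed(0)
layer = CachedMaskConv2d(3, 8, kernel_size=3, padding=1)
x = torch.randn(2, 3, 16, 16)

layer.train()
_ = layer(x)                           # samples a mask and caches it
cached = layer.inference_mask.clone()

layer.eval()
y1, y2 = layer(x), layer(x)            # both passes reuse the cached mask
```

Because the mask is fixed in eval mode, repeated test passes give the same output and no test data is needed to stabilize the predictions.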

Thanks for your suggestion. We set it to False because the BN statistics are unstable during training, especially at the beginning. The issue can also be resolved by setting use_running_stats to True (adding "--use_running_stats" to the command-line arguments works). The test-set information injection doesn't influence the final results, since we ultimately sample a subnet and finetune it. The BN statistics are therefore computed from training data, and any injected test-set information diminishes through exponential decay.
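
The exponential-decay argument can be illustrated with the standard running-statistic update used by PyTorch batch normalization, running = (1 - momentum) * running + momentum * batch_stat, where momentum defaults to 0.1. The sketch below is pure Python and only demonstrates the mechanism. The starting value of 5.0, the batch statistic of 0.0, and the 100 finetuning steps are made-up numbers, not values from the repository.

```python
def ema_update(running: float, batch_stat: float, momentum: float = 0.1) -> float:
    """One running-statistic update in the style of PyTorch BatchNorm."""
    return (1.0 - momentum) * running + momentum * batch_stat


def residual_weight(momentum: float, later_updates: int) -> float:
    """Weight that the old running value still carries after `later_updates` updates."""
    return (1.0 - momentum) ** later_updates


# Hypothetical scenario: test batches have pushed the running mean to 5.0,
# then finetuning sees 100 training batches whose mean is 0.0.
running = 5.0
for _ in range(100):
    running = ema_update(running, 0.0)

# The leftover contribution is the old value scaled by (1 - momentum) ** 100.
leftover = 5.0 * residual_weight(0.1, 100)
```

With these assumed numbers the leftover contribution falls below 1e-3 after 100 updates. How quickly it vanishes in practice depends on the momentum and on the number of finetuning steps actually run.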