google / uncertainty-baselines

High-quality implementations of standard and SOTA methods on a variety of tasks.

Couldn't reproduce MIMO accuracy on CIFAR-100

alexlyzhov opened this issue

Hi @dustinvtran and others, thank you for the repository! @VoronkovaDasha and I are trying to reproduce the MIMO results on WideResNet28x10 + CIFAR-100 in order to compare its performance with other methods. So far we have not been able to: the accuracy values we get are consistently a bit lower than reported. We use 1 GPU.

For CIFAR-100 the paper reports accuracy 82.0, NLL 0.690 for an ensemble of size 3.

Here is what we get:

python3 /dashavoronkova8/ens_project/mimo/cifar.py --output_dir '/dashavoronkova8/ens_project/mimo/cifar' --seed 0 --use_gpu --dataset cifar100 --per_core_batch_size 512 --num_cores 1 --batch_repetitions 4 --corruptions_interval -1 --ensemble_size 3 --width_multiplier 10 --base_learning_rate 0.1 --train_epochs 250 --lr_decay_ratio 0.1 --lr_warmup_epochs 0 --num_bins 15 --input_repetition_probability 0. --l2 3e-4 --checkpoint_interval 50

Train Loss: 1.3752, Accuracy: 99.94%
Test NLL: 0.7143, Accuracy: 80.85%
Member 0 Test Loss: 0.9081, Accuracy: 77.92%
Member 1 Test Loss: 0.9205, Accuracy: 77.65%
Member 2 Test Loss: 0.9248, Accuracy: 77.64%

The same experiment with another seed:

Train Loss: 1.3718, Accuracy: 99.95%
Test NLL: 0.7147, Accuracy: 80.73%
Member 0 Test Loss: 0.9152, Accuracy: 77.83%
Member 1 Test Loss: 0.9257, Accuracy: 77.55%
Member 2 Test Loss: 0.9209, Accuracy: 77.52%

The same experiment again, now with lr_warmup_epochs=1:

python3 /dashavoronkova8/ens_project/mimo/cifar.py --output_dir '/dashavoronkova8/ens_project/mimo/cifar' --seed 0 --use_gpu --dataset cifar100 --per_core_batch_size 512 --num_cores 1 --batch_repetitions 4 --corruptions_interval -1 --ensemble_size 3 --width_multiplier 10 --base_learning_rate 0.1 --train_epochs 250 --lr_decay_ratio 0.1 --lr_warmup_epochs 1 --num_bins 15 --input_repetition_probability 0. --l2 3e-4 --checkpoint_interval 50

Train Loss: 1.3739, Accuracy: 99.95%
Test NLL: 0.7198, Accuracy: 80.76%
Member 0 Test Loss: 0.9486, Accuracy: 77.09%
Member 1 Test Loss: 0.9144, Accuracy: 77.73%
Member 2 Test Loss: 0.9117, Accuracy: 77.74%

I wonder what the culprit is here. Are the script parameters OK?

@mhavasi

Got back from vacation (happy new year!).

Looking into the baseline, the default flag values we use in the CIFAR script are indeed the ones used for reporting the numbers in the CIFAR-100 leaderboard. The only flags we override on the command line are dataset=cifar100,cifar100_c_path=$CIFAR100_C_PATH.

You mention using only 1 GPU, which sounds like the likely culprit. Have you verified that the global batch size you're using for each gradient step is equivalent to the default script's setup? This is FLAGS.num_cores * FLAGS.per_core_batch_size.
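
Concretely, a minimal sanity check of that arithmetic (just a sketch; the default values of num_cores=8 and per_core_batch_size=64 are the ones discussed below):

# Global batch size per gradient step = FLAGS.num_cores * FLAGS.per_core_batch_size.
default_setup = {'num_cores': 8, 'per_core_batch_size': 64}
single_gpu_setup = {'num_cores': 1, 'per_core_batch_size': 512}

def global_batch_size(setup):
    return setup['num_cores'] * setup['per_core_batch_size']

print(global_batch_size(default_setup))     # 512
print(global_batch_size(single_gpu_setup))  # 512, so the two setups match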

Happy new year!

When we run

python3 /dashavoronkova8/ens_project/mimo/cifar.py --output_dir '/dashavoronkova8/ens_project/mimo/cifar' --seed 0 --use_gpu --dataset cifar100 --per_core_batch_size 512 --num_cores 1 --batch_repetitions 4 --corruptions_interval -1 --ensemble_size 3 --width_multiplier 10 --base_learning_rate 0.1 --train_epochs 250 --lr_decay_ratio 0.1 --lr_warmup_epochs 1 --num_bins 15 --input_repetition_probability 0. --l2 3e-4 --checkpoint_interval 50

and get 80.76% accuracy, our deviations from the default parameters are:

  • ensemble_size 3 (vs. default 4); according to the paper, accuracy is even better at this value
  • per_core_batch_size 512 (vs. default 64)
  • num_cores 1 (vs. default 8); together with per_core_batch_size=512 this gives a total batch size of 512, as in the paper and in the original script
  • use_gpu True (vs. default False), because we don't use a TPU
  • lr_decay_ratio 0.1 (vs. default 0.2), as reported in the paper (is 0.2 better to use?)
  • corruptions_interval -1 (vs. default 250), to skip evaluation on corrupted data
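
To make the comparison explicit, here is a small helper we used (written just for this issue, not part of the repo) that diffs our overrides against those defaults:

# Hypothetical helper: diff our command-line overrides against the script
# defaults listed above.
defaults = {
    'ensemble_size': 4,
    'per_core_batch_size': 64,
    'num_cores': 8,
    'use_gpu': False,
    'lr_decay_ratio': 0.2,
    'corruptions_interval': 250,
}
overrides = {
    'ensemble_size': 3,
    'per_core_batch_size': 512,
    'num_cores': 1,
    'use_gpu': True,
    'lr_decay_ratio': 0.1,
    'corruptions_interval': -1,
}
for name in defaults:
    if overrides[name] != defaults[name]:
        print(f'{name}: {defaults[name]} -> {overrides[name]}')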

Hi, thanks for the interest in the paper!

It could be the lr_decay_ratio. The one in the code (lr_decay_ratio=0.2) is definitely the correct one. It's possible that I put the incorrect value in the paper; I will update it ASAP.
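
For intuition, lr_decay_ratio controls how much the learning rate shrinks at each decay step, so 0.1 vs. 0.2 mostly changes how small the late-training learning rate becomes. A minimal sketch (the milestone epochs below are made up for illustration; the actual script defines its own schedule and warmup):

# Illustrative multi-step decay only; the milestone epochs are hypothetical.
def learning_rate(epoch, base_lr=0.1, decay_ratio=0.2, milestones=(100, 180, 220)):
    num_decays = sum(epoch >= m for m in milestones)
    return base_lr * decay_ratio ** num_decays

# After all three decays: 0.1 * 0.2**3 = 8e-4, whereas with decay_ratio=0.1
# it would be 0.1 * 0.1**3 = 1e-4.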

As Dustin said, the results that we report come from running the code with the default flags plus dataset=cifar100,cifar100_c_path=$CIFAR100_C_PATH, with the seed varying from 0 to 9, and averaging the 10 runs.

Did you also try cifar10? Did it work as expected?

With CIFAR-10 we got an accuracy of 96.17% vs. 96.40% in the paper, using the same parameters as for CIFAR-100, which should be outside the standard deviation range. Thanks, we'll try the 0.2 value!

Another potential issue is that in the default setup, batch norm is applied separately over the 64 data points on each core. Applying BN across the full large batch removes that implicit regularization benefit, which can sometimes be important.
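
A tiny numpy illustration of that effect (purely illustrative, not the baseline's actual BN code):

import numpy as np

# Compare BN statistics computed per 64-example sub-batch (default 8-core setup)
# with statistics computed over the full 512-example batch (single-GPU run).
rng = np.random.default_rng(0)
activations = rng.normal(size=(512, 16))  # 512 examples, 16 features

per_replica_means = activations.reshape(8, 64, 16).mean(axis=1)  # 8 noisy estimates
full_batch_mean = activations.mean(axis=0)                       # 1 smooth estimate

print(per_replica_means.std(axis=0).mean())  # nonzero spread across replicas
print(np.allclose(per_replica_means.mean(axis=0), full_batch_mean))  # True

# The per-replica noise acts as an implicit regularizer; normalizing over the
# full batch removes it.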

ensemble_size 3 (vs. default 4); according to the paper, accuracy is even better at this value
lr_decay_ratio 0.1 (vs. default 0.2), as reported in the paper (is 0.2 better to use?)

That's a good catch. I relaunched the code and here are the results, each averaged over 10 seeds:

default flags + ensemble_size=3,dataset=cifar100,cifar100_c_path=$CIFAR100_C_PATH

test_log_likelihood: -0.683279
test_accuracy: 0.8195
test_ece: 0.020577
test_nll_mean_corrupted: 2.285121
test_accuracy_mean_corrupted: 0.537657
test_ece_mean_corrupted: 0.129027

default flags + dataset=cifar100,cifar100_c_path=$CIFAR100_C_PATH

test_log_likelihood: -0.685981
test_accuracy: 0.815759
test_ece: 0.021685
test_nll_mean_corrupted: 2.25507
test_accuracy_mean_corrupted: 0.533822
test_ece_mean_corrupted: 0.112056

default flags + ensemble_size=3 (so CIFAR-10)

test_log_likelihood: -0.125419
test_accuracy: 0.96287
test_ece: 0.011104
test_nll_mean_corrupted: 0.909436
test_accuracy_mean_corrupted: 0.767745
test_ece_mean_corrupted: 0.107714

So yes, the table for both CIFAR datasets is in fact reported with ensemble_size=3 and lr_decay_ratio=0.2 (the default). I sent a PR to fix ensemble_size's default (#269).
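
For anyone landing here later, the fix just amounts to changing that flag's default along these lines (a sketch, not the actual diff in #269):

from absl import flags

# Help text here is illustrative, not copied from the script.
flags.DEFINE_integer('ensemble_size', 3, 'Number of ensemble members.')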

Closing, feel free to reopen if needed!