alexrame / fishr

Official PyTorch implementation of the Fishr regularization for out-of-distribution generalization

Expected performance for the DomainBed implementation

cwognum opened this issue

Hi @alexrame,

I've been playing around with the Fishr implementation in DomainBed for a while now, and going through the paper again today, I noticed the following section:

For example, on PACS dataset (7 classes and |ω| = 14,343) with a ResNet-50 and batch size 32, Fishr induces an overhead in memory of +0.2% and in training time of +2.7% (with a Tesla V100) compared to ERM; on the larger-scale DomainNet (345 classes and |ω| = 706,905), the overhead is +7.0% in memory and +6.5% in training time.

Unfortunately, I'm not seeing percentages anywhere near these for the overhead. Rather, Fishr is about twice as slow for me as a comparable ERM model. Now, this could be due to a variety of reasons. For example, my model is quite small, so the overhead might play a relatively larger role: I simply use an MLP with layers of size [166, 1024, 256, 64, 256, 64, 1]. The backward pass with BackPACK is only for the last layer (so I believe |ω| = 64?). My batch size is 64.
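To make the setup concrete, here is a minimal sketch (not the repository's exact code) of the pattern in question: only the last linear layer is extended with BackPACK, and per-sample gradients of that layer are collected with the BatchGrad extension. The layer sizes mirror the MLP described above; the MSE loss and random data are placeholders.

```python
import torch
import torch.nn as nn
from backpack import backpack, extend
from backpack.extensions import BatchGrad

# Feature extractor: plain PyTorch modules, not extended with BackPACK.
features = nn.Sequential(
    nn.Linear(166, 1024), nn.ReLU(),
    nn.Linear(1024, 256), nn.ReLU(),
    nn.Linear(256, 64), nn.ReLU(),
    nn.Linear(64, 256), nn.ReLU(),
    nn.Linear(256, 64), nn.ReLU(),
)
# Only the last layer is extended, so BackPACK tracks per-sample gradients for it alone.
classifier = extend(nn.Linear(64, 1))
loss_fn = extend(nn.MSELoss(reduction="sum"))  # placeholder loss

x, y = torch.randn(64, 166), torch.randn(64, 1)  # dummy batch of size 64
loss = loss_fn(classifier(features(x)), y)

with backpack(BatchGrad()):
    loss.backward()

# Per-sample gradients of the last layer: weight (1x64) and bias (1).
grads = torch.cat(
    [p.grad_batch.flatten(start_dim=1) for p in classifier.parameters()], dim=1
)
print(grads.shape)                 # (64, 65)
grad_variance = grads.var(dim=0)   # per-parameter gradient variance within this batch
```

Note that the last layer here contributes 64 weights plus 1 bias, i.e. 65 per-sample gradient components, and Fishr's penalty is built from the variances of such per-sample gradients computed in each domain.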

I am writing to verify that the numbers from the paper were obtained with this implementation. Is that indeed the case? If so, do you think a larger overhead is to be expected in my case? This would help me narrow down a possible problem. I would greatly appreciate your response.

Thank you for your interest.
Yes, these numbers were obtained with a similar implementation, with the default hyperparameters from DomainBed and - most importantly - with the default architecture, a ResNet-50. In that case, the overhead (confined to the last layer) is very cheap relative to the computation in the feature extractor, both in memory and in time on a GPU. In your case, the relatively larger overhead is certainly due to your smaller architecture - and perhaps also a little to your larger batch size (64 rather than 32 in my measurements). I hope this answer helps you.
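For a rough sense of scale, the |ω| values quoted from the paper correspond to the weights and biases of a ResNet-50's final linear layer, whose input features are 2048-dimensional in the standard DomainBed setup - a small fraction of the roughly 25 million parameters of the full network. A quick back-of-the-envelope check:

```python
# Back-of-the-envelope check of the |w| values quoted above, assuming a ResNet-50
# whose 2048-dimensional features feed a single linear classification layer.
n_features = 2048
for dataset, n_classes in [("PACS", 7), ("DomainNet", 345)]:
    n_params = n_features * n_classes + n_classes  # last-layer weights + biases
    print(dataset, n_params)  # PACS -> 14343, DomainNet -> 706905
```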

Thank you for taking the time to respond. That provides me with some more certainty that I'm doing the right thing here! Much appreciated.