GMvandeVen / continual-learning

PyTorch implementation of various methods for continual learning (XdG, EWC, SI, LwF, FROMP, DGR, BI-R, ER, A-GEM, iCaRL, Generative Classifier) in three different scenarios.

Why does batch_size have to be 1 when updating the Fisher?

yuchenlin opened this issue · comments

Hi,

Thanks for the great repo. I have a quick question about the computation of the Fisher Information Matrix: does the batch_size have to be 1 for the dataloader here?

data_loader = utils.get_data_loader(dataset, batch_size=1, cuda=self._is_on_cuda(), collate_fn=collate_fn)
My main concern is speed. Would it be equivalent if I used a larger batch size?

Thank you so much in advance! :D

The Fisher Information is the variance of the per-sample score (the gradient of the log-likelihood), estimated as an expectation over individual samples. PyTorch currently doesn't support computing per-sample gradients in a batched way, hence the batch size of 1.
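
For reference, a minimal sketch of that sample-by-sample estimate of the diagonal Fisher (the function name `estimate_diag_fisher` and the `n_samples` argument are illustrative, not the repo's actual API; it also assumes the model's own most likely prediction is used as the label, one common choice when estimating the Fisher):

```python
import torch
import torch.nn.functional as F

def estimate_diag_fisher(model, data_loader, device, n_samples=None):
    # Hypothetical helper: accumulate the diagonal Fisher one sample at a time.
    fisher = {n: torch.zeros_like(p) for n, p in model.named_parameters()
              if p.requires_grad}
    model.eval()
    count = 0
    for x, _ in data_loader:                       # data_loader built with batch_size=1
        if n_samples is not None and count >= n_samples:
            break
        x = x.to(device)
        model.zero_grad()
        output = model(x)                          # shape: [1, n_classes]
        label = output.argmax(dim=1)               # use the model's own prediction
        loss = F.nll_loss(F.log_softmax(output, dim=1), label)
        loss.backward()
        for n, p in model.named_parameters():
            if p.grad is not None:
                fisher[n] += p.grad.detach() ** 2  # squared score = diagonal Fisher
        count += 1
    return {n: f / max(count, 1) for n, f in fisher.items()}
```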

You can still batch the forward pass and get per-sample gradients: disable the reduction of the loss to its mean (i.e. use reduction='none') and call backward() on each loss value in the batch separately, as sketched below.
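
A rough sketch of that workaround, assuming a classification model and a single input batch x_batch (the helper name is hypothetical). Note that it relies on each sample's loss depending only on that sample's output, so layers that mix samples in the forward pass (e.g. batch norm in training mode) would break the equivalence:

```python
import torch
import torch.nn.functional as F

def diag_fisher_from_batch(model, x_batch):
    # Hypothetical helper: one batched forward pass, then one backward per sample.
    fisher = {n: torch.zeros_like(p) for n, p in model.named_parameters()
              if p.requires_grad}
    output = model(x_batch)                                      # [B, n_classes]
    labels = output.argmax(dim=1)
    # reduction='none' keeps one loss value per sample instead of averaging
    losses = F.nll_loss(F.log_softmax(output, dim=1), labels, reduction='none')
    for i in range(losses.size(0)):
        model.zero_grad()
        # retain the graph so the shared forward pass can be backpropped repeatedly
        losses[i].backward(retain_graph=(i < losses.size(0) - 1))
        for n, p in model.named_parameters():
            if p.grad is not None:
                fisher[n] += p.grad.detach() ** 2
    return {n: f / losses.size(0) for n, f in fisher.items()}
```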

Alternatively, you can limit the number of samples used to compute the Fisher (say 256), which works well for most use-cases.
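
In terms of the (illustrative) sketch above, that would just be a call like `estimate_diag_fisher(model, data_loader, device, n_samples=256)`.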