Passing aug_transforms to the DataBlock with CutMixUpAugment and ProgressiveResize leads to tensor size mismatch
csaroff opened this issue · comments
I'm running fastxtend version 0.0.18 and fastai version 2.7.11.
My dataloader init, model init, and training code are a bit fragmented, but I'll try to piece them back together here for reproducibility.
It looks something like this:
db = DataBlock(
blocks=blocks,
getters=getters,
n_inp=1,
item_tfms=Resize(resize),
batch_tfms=[*aug_transforms(), Normalize.from_stats(*imagenet_stats)],
splitter=splitter,
)
dls = db.dataloaders(df, bs=bs)
cbs = [ProgressiveResize(), CutMixUpAugment()]
learn = vision_learner(dls, arch=resnet50, opt_func=opt_func, loss_func=loss_func, cbs=cbs, metrics=metrics)
learn._fine_tune(epochs=epochs, base_lr=base_lr)
╭─────────────────────────────── Traceback (most recent call last) ────────────────────────────────╮
│ <ipython-input-23-4a152cf4a778>:1 in <cell line: 1> │
│ <ipython-input-8-78506829a822>:20 in mlbl_fine_tune │
│ │
│ /home/csaroff/.miniconda3/lib/python3.9/site-packages/fastai/callback/schedule.py:168 in │
│ fine_tune │
│ │
│ 165 │ self.fit_one_cycle(freeze_epochs, slice(base_lr), pct_start=0.99, **kwargs) │
│ 166 │ base_lr /= 2 │
│ 167 │ self.unfreeze() │
│ ❱ 168 │ self.fit_one_cycle(epochs, slice(base_lr/lr_mult, base_lr), pct_start=pct_start, div │
│ 169 │
│ 170 # %% ../../nbs/14_callback.schedule.ipynb 67 │
│ 171 @docs │
│ │
│ /home/csaroff/.miniconda3/lib/python3.9/site-packages/fastai/callback/schedule.py:119 in │
│ fit_one_cycle │
│ │
│ 116 │ lr_max = np.array([h['lr'] for h in self.opt.hypers]) │
│ 117 │ scheds = {'lr': combined_cos(pct_start, lr_max/div, lr_max, lr_max/div_final), │
│ 118 │ │ │ 'mom': combined_cos(pct_start, *(self.moms if moms is None else moms))} │
│ ❱ 119 │ self.fit(n_epoch, cbs=ParamScheduler(scheds)+L(cbs), reset_opt=reset_opt, wd=wd, sta │
│ 120 │
│ 121 # %% ../../nbs/14_callback.schedule.ipynb 50 │
│ 122 @patch │
│ │
│ /home/csaroff/.miniconda3/lib/python3.9/site-packages/fastai/learner.py:264 in fit │
│ │
│ 261 │ │ │ if wd is not None: self.opt.set_hypers(wd=wd) │
│ 262 │ │ │ self.opt.set_hypers(lr=self.lr if lr is None else lr) │
│ 263 │ │ │ self.n_epoch = n_epoch │
│ ❱ 264 │ │ │ self._with_events(self._do_fit, 'fit', CancelFitException, self._end_cleanup │
│ 265 │ │
│ 266 │ def _end_cleanup(self): self.dl,self.xb,self.yb,self.pred,self.loss = None,(None,),( │
│ 267 │ def __enter__(self): self(_before_epoch); return self │
│ │
│ /home/csaroff/.miniconda3/lib/python3.9/site-packages/fastai/learner.py:199 in _with_events │
│ │
│ 196 │ │ self.xb,self.yb = b[:i],b[i:] │
│ 197 │ │
│ 198 │ def _with_events(self, f, event_type, ex, final=noop): │
│ ❱ 199 │ │ try: self(f'before_{event_type}'); f() │
│ 200 │ │ except ex: self(f'after_cancel_{event_type}') │
│ 201 │ │ self(f'after_{event_type}'); final() │
│ 202 │
│ │
│ /home/csaroff/.miniconda3/lib/python3.9/site-packages/fastai/learner.py:253 in _do_fit │
│ │
│ 250 │ def _do_fit(self): │
│ 251 │ │ for epoch in range(self.n_epoch): │
│ 252 │ │ │ self.epoch=epoch │
│ ❱ 253 │ │ │ self._with_events(self._do_epoch, 'epoch', CancelEpochException) │
│ 254 │ │
│ 255 │ def fit(self, n_epoch, lr=None, wd=None, cbs=None, reset_opt=False, start_epoch=0): │
│ 256 │ │ if start_epoch != 0: │
│ │
│ /home/csaroff/.miniconda3/lib/python3.9/site-packages/fastai/learner.py:199 in _with_events │
│ │
│ 196 │ │ self.xb,self.yb = b[:i],b[i:] │
│ 197 │ │
│ 198 │ def _with_events(self, f, event_type, ex, final=noop): │
│ ❱ 199 │ │ try: self(f'before_{event_type}'); f() │
│ 200 │ │ except ex: self(f'after_cancel_{event_type}') │
│ 201 │ │ self(f'after_{event_type}'); final() │
│ 202 │
│ │
│ /home/csaroff/.miniconda3/lib/python3.9/site-packages/fastai/learner.py:247 in _do_epoch │
│ │
│ 244 │ │ with torch.no_grad(): self._with_events(self.all_batches, 'validate', CancelVali │
│ 245 │ │
│ 246 │ def _do_epoch(self): │
│ ❱ 247 │ │ self._do_epoch_train() │
│ 248 │ │ self._do_epoch_validate() │
│ 249 │ │
│ 250 │ def _do_fit(self): │
│ │
│ /home/csaroff/.miniconda3/lib/python3.9/site-packages/fastai/learner.py:239 in _do_epoch_train │
│ │
│ 236 │ │
│ 237 │ def _do_epoch_train(self): │
│ 238 │ │ self.dl = self.dls.train │
│ ❱ 239 │ │ self._with_events(self.all_batches, 'train', CancelTrainException) │
│ 240 │ │
│ 241 │ def _do_epoch_validate(self, ds_idx=1, dl=None): │
│ 242 │ │ if dl is None: dl = self.dls[ds_idx] │
│ │
│ /home/csaroff/.miniconda3/lib/python3.9/site-packages/fastai/learner.py:199 in _with_events │
│ │
│ 196 │ │ self.xb,self.yb = b[:i],b[i:] │
│ 197 │ │
│ 198 │ def _with_events(self, f, event_type, ex, final=noop): │
│ ❱ 199 │ │ try: self(f'before_{event_type}'); f() │
│ 200 │ │ except ex: self(f'after_cancel_{event_type}') │
│ 201 │ │ self(f'after_{event_type}'); final() │
│ 202 │
│ │
│ /home/csaroff/.miniconda3/lib/python3.9/site-packages/fastxtend/callback/simpleprofiler.py:85 in │
│ all_batches │
│ │
│ 82 │ │ │ self.one_batch(i, next(self.it)) │
│ 83 │ │ del(self.it) │
│ 84 │ else: │
│ ❱ 85 │ │ for o in enumerate(self.dl): self.one_batch(*o) │
│ 86 │
│ 87 # %% ../../nbs/callback.simpleprofiler.ipynb 18 │
│ 88 _loop = ['Start Fit', 'before_fit', 'Start Epoch Loop', 'before_epoch', 'Start Train', ' │
│ │
│ /home/csaroff/.miniconda3/lib/python3.9/site-packages/fastai/learner.py:235 in one_batch │
│ │
│ 232 │ │ self.iter = i │
│ 233 │ │ b = self._set_device(b) │
│ 234 │ │ self._split(b) │
│ ❱ 235 │ │ self._with_events(self._do_one_batch, 'batch', CancelBatchException) │
│ 236 │ │
│ 237 │ def _do_epoch_train(self): │
│ 238 │ │ self.dl = self.dls.train │
│ │
│ /home/csaroff/.miniconda3/lib/python3.9/site-packages/fastai/learner.py:199 in _with_events │
│ │
│ 196 │ │ self.xb,self.yb = b[:i],b[i:] │
│ 197 │ │
│ 198 │ def _with_events(self, f, event_type, ex, final=noop): │
│ ❱ 199 │ │ try: self(f'before_{event_type}'); f() │
│ 200 │ │ except ex: self(f'after_cancel_{event_type}') │
│ 201 │ │ self(f'after_{event_type}'); final() │
│ 202 │
│ │
│ /home/csaroff/.miniconda3/lib/python3.9/site-packages/fastai/learner.py:172 in __call__ │
│ │
│ 169 │ │ finally: self.add_cbs(cbs) │
│ 170 │ │
│ 171 │ def ordered_cbs(self, event): return [cb for cb in self.cbs.sorted('order') if hasat │
│ ❱ 172 │ def __call__(self, event_name): L(event_name).map(self._call_one) │
│ 173 │ │
│ 174 │ def _call_one(self, event_name): │
│ 175 │ │ if not hasattr(event, event_name): raise Exception(f'missing {event_name}') │
│ │
│ /home/csaroff/.miniconda3/lib/python3.9/site-packages/fastcore/foundation.py:156 in map │
│ │
│ 153 │ @classmethod │
│ 154 │ def range(cls, a, b=None, step=None): return cls(range_of(a, b=b, step=step)) │
│ 155 │ │
│ ❱ 156 │ def map(self, f, *args, gen=False, **kwargs): return self._new(map_ex(self, f, *args │
│ 157 │ def argwhere(self, f, negate=False, **kwargs): return self._new(argwhere(self, f, ne │
│ 158 │ def argfirst(self, f, negate=False): return first(i for i,o in self.enumerate() if f │
│ 159 │ def filter(self, f=noop, negate=False, gen=False, **kwargs): │
│ │
│ /home/csaroff/.miniconda3/lib/python3.9/site-packages/fastcore/basics.py:840 in map_ex │
│ │
│ 837 │ │ else f.__getitem__) │
│ 838 │ res = map(g, iterable) │
│ 839 │ if gen: return res │
│ ❱ 840 │ return list(res) │
│ 841 │
│ 842 # %% ../nbs/01_basics.ipynb 336 │
│ 843 def compose(*funcs, order=None): │
│ │
│ /home/csaroff/.miniconda3/lib/python3.9/site-packages/fastcore/basics.py:825 in __call__ │
│ │
│ 822 │ │ for k,v in kwargs.items(): │
│ 823 │ │ │ if isinstance(v,_Arg): kwargs[k] = args.pop(v.i) │
│ 824 │ │ fargs = [args[x.i] if isinstance(x, _Arg) else x for x in self.pargs] + args[sel │
│ ❱ 825 │ │ return self.func(*fargs, **kwargs) │
│ 826 │
│ 827 # %% ../nbs/01_basics.ipynb 326 │
│ 828 def mapt(func, *iterables): │
│ │
│ /home/csaroff/.miniconda3/lib/python3.9/site-packages/fastxtend/callback/simpleprofiler.py:72 in │
│ _call_one │
│ │
│ 69 @patch │
│ 70 def _call_one(self:Learner, event_name): │
│ 71 │ if not hasattr(event, event_name): raise Exception(f'missing {event_name}') │
│ ❱ 72 │ for cb in self.cbs.sorted('order'): cb(event_name) │
│ 73 │
│ 74 # %% ../../nbs/callback.simpleprofiler.ipynb 16 │
│ 75 @patch │
│ │
│ /home/csaroff/.miniconda3/lib/python3.9/site-packages/fastxtend/callback/simpleprofiler.py:60 in │
│ __call__ │
│ │
│ 57 │ │ │ (self.run_valid and not getattr(self, 'training', False))) │
│ 58 │ res = None │
│ 59 │ if self.run and _run: │
│ ❱ 60 │ │ try: res = getattr(self, event_name, noop)() │
│ 61 │ │ except (CancelBatchException, CancelEpochException, CancelFitException, CancelSt │
│ 62 │ │ except Exception as e: │
│ 63 │ │ │ e.args = [f'Exception occured in `{self.__class__.__name__}` when calling ev │
│ │
│ /home/csaroff/.miniconda3/lib/python3.9/site-packages/fastxtend/callback/cutmixup.py:292 in │
│ before_batch │
│ │
│ 289 │ │ │ # Apply MixUp/CutMix Augmentations to MixUp and CutMix samples │
│ 290 │ │ │ if do_mix or do_cut: │
│ 291 │ │ │ │ if self._docutmixaug: │
│ ❱ 292 │ │ │ │ │ xb2[aug_type<2] = self._cutmixaugs_pipe(xb[aug_type<2]) │
│ 293 │ │ │ │ else: │
│ 294 │ │ │ │ │ xb2[aug_type<2] = xb[aug_type<2] │
│ 295 │
│ │
│ /home/csaroff/.miniconda3/lib/python3.9/site-packages/fastai/torch_core.py:372 in │
│ __torch_function__ │
│ │
│ 369 │ def __torch_function__(cls, func, types, args=(), kwargs=None): │
│ 370 │ │ if cls.debug and func.__name__ not in ('__str__','__repr__'): print(func, types, │
│ 371 │ │ if _torch_handled(args, cls._opt, func): types = (torch.Tensor,) │
│ ❱ 372 │ │ res = super().__torch_function__(func, types, args, ifnone(kwargs, {})) │
│ 373 │ │ dict_objs = _find_args(args) if args else _find_args(list(kwargs.values())) │
│ 374 │ │ if issubclass(type(res),TensorBase) and dict_objs: res.set_meta(dict_objs[0],as_ │
│ 375 │ │ elif dict_objs and is_listy(res): [r.set_meta(dict_objs[0],as_copy=True) for r i │
│ │
│ /home/csaroff/.miniconda3/lib/python3.9/site-packages/torch/_tensor.py:1279 in │
│ __torch_function__ │
│ │
│ 1276 │ │ │ return NotImplemented │
│ 1277 │ │ │
│ 1278 │ │ with _C.DisableTorchFunction(): │
│ ❱ 1279 │ │ │ ret = func(*args, **kwargs) │
│ 1280 │ │ │ if func in get_default_nowrap_functions(): │
│ 1281 │ │ │ │ return ret │
│ 1282 │ │ │ else: │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
RuntimeError: Exception occured in `CutMixUpAugment` when calling event `before_batch`:
shape mismatch: value tensor of shape [21, 3, 224, 224] cannot be broadcast to indexing result of shape
[21, 3, 80, 80]
As a workaround, removing aug_transforms from batch_tfms in the DataBlock init fixed the error for me.