warner-benjamin / fastxtend

Train fastai models faster (and other useful tools)

Home Page: https://fastxtend.benjaminwarner.dev

Geek Repo:Geek Repo

Github PK Tool:Github PK Tool

Passing aug_transforms to the DataBlock with CutMixUpAugment and ProgressiveResize leads to tensor size mismatch

csaroff opened this issue · comments

I'm running fastxtend version 0.0.18 and fastai version 2.7.11.

My dataloader init, model init and training code is a bit fragmented, but I'll try to piece it back together here for reproducibility.

It looks something like this:

db = DataBlock(
    blocks=blocks,
    getters=getters,
    n_inp=1,
    item_tfms=Resize(resize),
    batch_tfms=[*aug_transforms(), Normalize.from_stats(*imagenet_stats)],
    splitter=splitter,
)
dls = db.dataloaders(df, bs=bs)
cbs = [ProgressiveResize(), CutMixUpAugment()]
learn = vision_learner(dls, arch=resnet50, opt_func=opt_func, loss_func=loss_func, cbs=cbs, metrics=metrics)
learn._fine_tune(epochs=epochs, base_lr=base_lr)
╭─────────────────────────────── Traceback (most recent call last) ────────────────────────────────╮
│ <ipython-input-23-4a152cf4a778>:1 in <cell line: 1>                                              │
│ <ipython-input-8-78506829a822>:20 in mlbl_fine_tune                                              │
│                                                                                                  │
│ /home/csaroff/.miniconda3/lib/python3.9/site-packages/fastai/callback/schedule.py:168 in         │
│ fine_tune                                                                                        │
│                                                                                                  │
│   165self.fit_one_cycle(freeze_epochs, slice(base_lr), pct_start=0.99, **kwargs)            │
│   166base_lr /= 2                                                                           │
│   167self.unfreeze()                                                                        │
│ ❱ 168self.fit_one_cycle(epochs, slice(base_lr/lr_mult, base_lr), pct_start=pct_start, div   │
│   169                                                                                            │
│   170 # %% ../../nbs/14_callback.schedule.ipynb 67                                               │
│   171 @docs                                                                                      │
│                                                                                                  │
│ /home/csaroff/.miniconda3/lib/python3.9/site-packages/fastai/callback/schedule.py:119 in         │
│ fit_one_cycle                                                                                    │
│                                                                                                  │
│   116lr_max = np.array([h['lr'] for h in self.opt.hypers])                                  │
│   117scheds = {'lr': combined_cos(pct_start, lr_max/div, lr_max, lr_max/div_final),         │
│   118 │   │   │     'mom': combined_cos(pct_start, *(self.moms if moms is None else moms))}      │
│ ❱ 119self.fit(n_epoch, cbs=ParamScheduler(scheds)+L(cbs), reset_opt=reset_opt, wd=wd, sta   │
│   120                                                                                            │
│   121 # %% ../../nbs/14_callback.schedule.ipynb 50                                               │
│   122 @patch                                                                                     │
│                                                                                                  │
│ /home/csaroff/.miniconda3/lib/python3.9/site-packages/fastai/learner.py:264 in fit               │
│                                                                                                  │
│   261 │   │   │   if wd is not None: self.opt.set_hypers(wd=wd)                                  │
│   262 │   │   │   self.opt.set_hypers(lr=self.lr if lr is None else lr)                          │
│   263 │   │   │   self.n_epoch = n_epoch                                                         │
│ ❱ 264 │   │   │   self._with_events(self._do_fit, 'fit', CancelFitException, self._end_cleanup   │
│   265 │                                                                                          │
│   266def _end_cleanup(self): self.dl,self.xb,self.yb,self.pred,self.loss = None,(None,),(   │
│   267def __enter__(self): self(_before_epoch); return self                                  │
│                                                                                                  │
│ /home/csaroff/.miniconda3/lib/python3.9/site-packages/fastai/learner.py:199 in _with_events      │
│                                                                                                  │
│   196 │   │   self.xb,self.yb = b[:i],b[i:]                                                      │
│   197 │                                                                                          │
│   198def _with_events(self, f, event_type, ex, final=noop):                                 │
│ ❱ 199 │   │   try: self(f'before_{event_type}');  f()                                            │
│   200 │   │   except ex: self(f'after_cancel_{event_type}')                                      │
│   201 │   │   self(f'after_{event_type}');  final()                                              │
│   202                                                                                            │
│                                                                                                  │
│ /home/csaroff/.miniconda3/lib/python3.9/site-packages/fastai/learner.py:253 in _do_fit           │
│                                                                                                  │
│   250def _do_fit(self):                                                                     │
│   251 │   │   for epoch in range(self.n_epoch):                                                  │
│   252 │   │   │   self.epoch=epoch                                                               │
│ ❱ 253 │   │   │   self._with_events(self._do_epoch, 'epoch', CancelEpochException)               │
│   254 │                                                                                          │
│   255def fit(self, n_epoch, lr=None, wd=None, cbs=None, reset_opt=False, start_epoch=0):    │
│   256 │   │   if start_epoch != 0:                                                               │
│                                                                                                  │
│ /home/csaroff/.miniconda3/lib/python3.9/site-packages/fastai/learner.py:199 in _with_events      │
│                                                                                                  │
│   196 │   │   self.xb,self.yb = b[:i],b[i:]                                                      │
│   197 │                                                                                          │
│   198def _with_events(self, f, event_type, ex, final=noop):                                 │
│ ❱ 199 │   │   try: self(f'before_{event_type}');  f()                                            │
│   200 │   │   except ex: self(f'after_cancel_{event_type}')                                      │
│   201 │   │   self(f'after_{event_type}');  final()                                              │
│   202                                                                                            │
│                                                                                                  │
│ /home/csaroff/.miniconda3/lib/python3.9/site-packages/fastai/learner.py:247 in _do_epoch         │
│                                                                                                  │
│   244 │   │   with torch.no_grad(): self._with_events(self.all_batches, 'validate', CancelVali   │
│   245 │                                                                                          │
│   246def _do_epoch(self):                                                                   │
│ ❱ 247 │   │   self._do_epoch_train()                                                             │
│   248 │   │   self._do_epoch_validate()                                                          │
│   249 │                                                                                          │
│   250def _do_fit(self):                                                                     │
│                                                                                                  │
│ /home/csaroff/.miniconda3/lib/python3.9/site-packages/fastai/learner.py:239 in _do_epoch_train   │
│                                                                                                  │
│   236 │                                                                                          │
│   237def _do_epoch_train(self):                                                             │
│   238 │   │   self.dl = self.dls.train                                                           │
│ ❱ 239 │   │   self._with_events(self.all_batches, 'train', CancelTrainException)                 │
│   240 │                                                                                          │
│   241def _do_epoch_validate(self, ds_idx=1, dl=None):                                       │
│   242 │   │   if dl is None: dl = self.dls[ds_idx]                                               │
│                                                                                                  │
│ /home/csaroff/.miniconda3/lib/python3.9/site-packages/fastai/learner.py:199 in _with_events      │
│                                                                                                  │
│   196 │   │   self.xb,self.yb = b[:i],b[i:]                                                      │
│   197 │                                                                                          │
│   198def _with_events(self, f, event_type, ex, final=noop):                                 │
│ ❱ 199 │   │   try: self(f'before_{event_type}');  f()                                            │
│   200 │   │   except ex: self(f'after_cancel_{event_type}')                                      │
│   201 │   │   self(f'after_{event_type}');  final()                                              │
│   202                                                                                            │
│                                                                                                  │
│ /home/csaroff/.miniconda3/lib/python3.9/site-packages/fastxtend/callback/simpleprofiler.py:85 in │
│ all_batches                                                                                      │
│                                                                                                  │
│    82 │   │   │   self.one_batch(i, next(self.it))                                               │
│    83 │   │   del(self.it)                                                                       │
│    84else:                                                                                  │
│ ❱  85 │   │   for o in enumerate(self.dl): self.one_batch(*o)                                    │
│    86                                                                                            │
│    87 # %% ../../nbs/callback.simpleprofiler.ipynb 18                                            │
│    88 _loop = ['Start Fit', 'before_fit', 'Start Epoch Loop', 'before_epoch', 'Start Train', '   │
│                                                                                                  │
│ /home/csaroff/.miniconda3/lib/python3.9/site-packages/fastai/learner.py:235 in one_batch         │
│                                                                                                  │
│   232 │   │   self.iter = i                                                                      │
│   233 │   │   b = self._set_device(b)                                                            │
│   234 │   │   self._split(b)                                                                     │
│ ❱ 235 │   │   self._with_events(self._do_one_batch, 'batch', CancelBatchException)               │
│   236 │                                                                                          │
│   237def _do_epoch_train(self):                                                             │
│   238 │   │   self.dl = self.dls.train                                                           │
│                                                                                                  │
│ /home/csaroff/.miniconda3/lib/python3.9/site-packages/fastai/learner.py:199 in _with_events      │
│                                                                                                  │
│   196 │   │   self.xb,self.yb = b[:i],b[i:]                                                      │
│   197 │                                                                                          │
│   198def _with_events(self, f, event_type, ex, final=noop):                                 │
│ ❱ 199 │   │   try: self(f'before_{event_type}');  f()                                            │
│   200 │   │   except ex: self(f'after_cancel_{event_type}')                                      │
│   201 │   │   self(f'after_{event_type}');  final()                                              │
│   202                                                                                            │
│                                                                                                  │
│ /home/csaroff/.miniconda3/lib/python3.9/site-packages/fastai/learner.py:172 in __call__          │
│                                                                                                  │
│   169 │   │   finally: self.add_cbs(cbs)                                                         │
│   170 │                                                                                          │
│   171def ordered_cbs(self, event): return [cb for cb in self.cbs.sorted('order') if hasat   │
│ ❱ 172def __call__(self, event_name): L(event_name).map(self._call_one)                      │
│   173 │                                                                                          │
│   174def _call_one(self, event_name):                                                       │
│   175 │   │   if not hasattr(event, event_name): raise Exception(f'missing {event_name}')        │
│                                                                                                  │
│ /home/csaroff/.miniconda3/lib/python3.9/site-packages/fastcore/foundation.py:156 in map          │
│                                                                                                  │
│   153 │   @classmethod                                                                           │
│   154def range(cls, a, b=None, step=None): return cls(range_of(a, b=b, step=step))          │
│   155 │                                                                                          │
│ ❱ 156def map(self, f, *args, gen=False, **kwargs): return self._new(map_ex(self, f, *args   │
│   157def argwhere(self, f, negate=False, **kwargs): return self._new(argwhere(self, f, ne   │
│   158def argfirst(self, f, negate=False): return first(i for i,o in self.enumerate() if f   │
│   159def filter(self, f=noop, negate=False, gen=False, **kwargs):                           │
│                                                                                                  │
│ /home/csaroff/.miniconda3/lib/python3.9/site-packages/fastcore/basics.py:840 in map_ex           │
│                                                                                                  │
│    837 │   │    else f.__getitem__)                                                              │
│    838res = map(g, iterable)                                                                │
│    839if gen: return res                                                                    │
│ ❱  840return list(res)                                                                      │
│    841                                                                                           │
│    842 # %% ../nbs/01_basics.ipynb 336                                                           │
│    843 def compose(*funcs, order=None):                                                          │
│                                                                                                  │
│ /home/csaroff/.miniconda3/lib/python3.9/site-packages/fastcore/basics.py:825 in __call__         │
│                                                                                                  │
│    822 │   │   for k,v in kwargs.items():                                                        │
│    823 │   │   │   if isinstance(v,_Arg): kwargs[k] = args.pop(v.i)                              │
│    824 │   │   fargs = [args[x.i] if isinstance(x, _Arg) else x for x in self.pargs] + args[sel  │
│ ❱  825 │   │   return self.func(*fargs, **kwargs)                                                │
│    826                                                                                           │
│    827 # %% ../nbs/01_basics.ipynb 326                                                           │
│    828 def mapt(func, *iterables):                                                               │
│                                                                                                  │
│ /home/csaroff/.miniconda3/lib/python3.9/site-packages/fastxtend/callback/simpleprofiler.py:72 in │
│ _call_one                                                                                        │
│                                                                                                  │
│    69 @patch                                                                                     │
│    70 def _call_one(self:Learner, event_name):                                                   │
│    71if not hasattr(event, event_name): raise Exception(f'missing {event_name}')            │
│ ❱  72for cb in self.cbs.sorted('order'): cb(event_name)                                     │
│    73                                                                                            │
│    74 # %% ../../nbs/callback.simpleprofiler.ipynb 16                                            │
│    75 @patch                                                                                     │
│                                                                                                  │
│ /home/csaroff/.miniconda3/lib/python3.9/site-packages/fastxtend/callback/simpleprofiler.py:60 in │
│ __call__                                                                                         │
│                                                                                                  │
│    57 │   │   │   (self.run_valid and not getattr(self, 'training', False)))                     │
│    58res = None                                                                             │
│    59if self.run and _run:                                                                  │
│ ❱  60 │   │   try: res = getattr(self, event_name, noop)()                                       │
│    61 │   │   except (CancelBatchException, CancelEpochException, CancelFitException, CancelSt   │
│    62 │   │   except Exception as e:                                                             │
│    63 │   │   │   e.args = [f'Exception occured in `{self.__class__.__name__}` when calling ev   │
│                                                                                                  │
│ /home/csaroff/.miniconda3/lib/python3.9/site-packages/fastxtend/callback/cutmixup.py:292 in      │
│ before_batch                                                                                     │
│                                                                                                  │
│   289 │   │   │   # Apply MixUp/CutMix Augmentations to MixUp and CutMix samples                 │
│   290 │   │   │   if do_mix or do_cut:                                                           │
│   291 │   │   │   │   if self._docutmixaug:                                                      │
│ ❱ 292 │   │   │   │   │   xb2[aug_type<2] = self._cutmixaugs_pipe(xb[aug_type<2])                │
│   293 │   │   │   │   else:                                                                      │
│   294 │   │   │   │   │   xb2[aug_type<2] = xb[aug_type<2]                                       │
│   295                                                                                            │
│                                                                                                  │
│ /home/csaroff/.miniconda3/lib/python3.9/site-packages/fastai/torch_core.py:372 in                │
│ __torch_function__                                                                               │
│                                                                                                  │
│   369def __torch_function__(cls, func, types, args=(), kwargs=None):                        │
│   370 │   │   if cls.debug and func.__name__ not in ('__str__','__repr__'): print(func, types,   │
│   371 │   │   if _torch_handled(args, cls._opt, func): types = (torch.Tensor,)                   │
│ ❱ 372 │   │   res = super().__torch_function__(func, types, args, ifnone(kwargs, {}))            │
│   373 │   │   dict_objs = _find_args(args) if args else _find_args(list(kwargs.values()))        │
│   374 │   │   if issubclass(type(res),TensorBase) and dict_objs: res.set_meta(dict_objs[0],as_   │
│   375 │   │   elif dict_objs and is_listy(res): [r.set_meta(dict_objs[0],as_copy=True) for r i   │
│                                                                                                  │
│ /home/csaroff/.miniconda3/lib/python3.9/site-packages/torch/_tensor.py:1279 in                   │
│ __torch_function__                                                                               │
│                                                                                                  │
│   1276 │   │   │   return NotImplemented                                                         │
│   1277 │   │                                                                                     │
│   1278 │   │   with _C.DisableTorchFunction():                                                   │
│ ❱ 1279 │   │   │   ret = func(*args, **kwargs)                                                   │
│   1280 │   │   │   if func in get_default_nowrap_functions():                                    │
│   1281 │   │   │   │   return ret                                                                │
│   1282 │   │   │   else:                                                                         │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
RuntimeError: Exception occured in `CutMixUpAugment` when calling event `before_batch`:
        shape mismatch: value tensor of shape [21, 3, 224, 224] cannot be broadcast to indexing result of shape [21, 3, 80, 80]

Removing aug_transforms from batch_tfms in the DataBlock init fixed it for me.