albumentations-team / autoalbument

AutoML for image augmentation. AutoAlbument uses the Faster AutoAugment algorithm to find optimal augmentation policies. Documentation - https://albumentations.ai/docs/autoalbument/

ImportError within FasterAutoAugmentSearcher

adam-mehdi opened this issue

Hello,

AutoAlbument's FasterAutoAugmentSearcher fails with the import error "cannot import name 'Batch' from 'torchtext.data'" when searching for a policy. This error recently appeared in PyTorch Lightning in general, and it is fixed by installing Lightning from GitHub (pip install git+https://github.com/PyTorchLightning/pytorch-lightning) instead of from PyPI (pip install pytorch-lightning). I suspect that the version of Lightning used by FasterAutoAugmentSearcher must be upgraded.
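The failing statement is the unconditional from torchtext.data import Batch in Lightning's apply_func.py (see the traceback below). A common way for a library to survive this kind of relocation is to try the old and the new module path in turn. The sketch below shows that pattern generically; import_first is a hypothetical helper, not part of Lightning, torchtext or AutoAlbument, and the assumption that torchtext 0.9 and later expose Batch under torchtext.legacy.data should be checked against the installed release.

```python
import importlib
from typing import Any, Sequence


def import_first(attr: str, module_paths: Sequence[str]) -> Any:
    """Return `attr` from the first module in `module_paths` that provides it.

    Hypothetical helper. Raises ImportError listing every location tried
    if none of them works.
    """
    errors = []
    for path in module_paths:
        try:
            module = importlib.import_module(path)
        except ImportError as exc:
            # The module itself is missing (e.g. a "legacy" package on an old release)
            errors.append(f"{path}: {exc}")
            continue
        if hasattr(module, attr):
            return getattr(module, attr)
        # The module exists but no longer defines the name (the situation in this issue)
        errors.append(f"{path}: no attribute {attr!r}")
    raise ImportError(f"cannot import {attr!r}; tried " + "; ".join(errors))
```

Under that assumption, a call would look like Batch = import_first("Batch", ["torchtext.data", "torchtext.legacy.data"]). Trying each location and catching ImportError, rather than comparing version numbers, keeps the helper working with whichever layout is installed.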

Here's an example of the problem using the CIFAR100 dataset.

AutoAlbument Search

!pip install -U git+https://github.com/albumentations-team/autoalbument
!autoalbument-create --config-dir /content/ --task classification --num-classes 100
!autoalbument-search --config-dir /content

dataset.py

import torchvision


class SearchDataset(torchvision.datasets.CIFAR100):
    # The generated config refers to this class as dataset.SearchDataset,
    # so the class name here must match it.
    def __init__(self, root="content/cifar100", train=True, download=True, transform=None):
        super().__init__(root=root, train=train, download=download, transform=transform)

    def __len__(self):
        return len(self.targets)

    def __getitem__(self, index):
        # self.data holds uint8 HWC images; self.targets holds integer class labels
        image, label = self.data[index], self.targets[index]

        if self.transform is not None:
            # Albumentations transforms are called with image= and return a dict
            transformed = self.transform(image=image)
            image = transformed["image"]

        return image, label
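As far as this dataset is concerned, the search only needs transform to be called with an image= keyword and to return a dict holding the result under "image", which is how Albumentations transforms are invoked. The sketch below exercises that contract without torchvision or a download; InMemorySearchDataset and reverse_rows are hypothetical names used only for illustration.

```python
from typing import Any, Callable, Dict, List, Optional, Tuple


class InMemorySearchDataset:
    """Hypothetical stand-in for SearchDataset that keeps images in plain lists."""

    def __init__(self, data: List[Any], targets: List[int],
                 transform: Optional[Callable[..., Dict[str, Any]]] = None):
        if len(data) != len(targets):
            raise ValueError("expected one target per image")
        self.data = data
        self.targets = targets
        self.transform = transform

    def __len__(self) -> int:
        return len(self.targets)

    def __getitem__(self, index: int) -> Tuple[Any, int]:
        image, label = self.data[index], self.targets[index]
        if self.transform is not None:
            # Albumentations-style call: keyword argument in, dict out
            image = self.transform(image=image)["image"]
        return image, label


def reverse_rows(image, **kwargs):
    """Toy transform following the same calling convention: it reverses each
    row of a nested-list "image", i.e. a horizontal flip."""
    return {"image": [row[::-1] for row in image]}
```

For example, InMemorySearchDataset([[[1, 2], [3, 4]]], [7], transform=reverse_rows)[0] returns ([[2, 1], [4, 3]], 7).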

Search Output

_version: 2
task: classification
policy_model:
  task_factor: 0.1
  gp_factor: 10
  temperature: 0.05
  num_sub_policies: 40
  num_chunks: 4
  operation_count: 4
  operations:
  - _target_: autoalbument.faster_autoaugment.models.policy_operations.ShiftRGB
    shift_r: true
  - _target_: autoalbument.faster_autoaugment.models.policy_operations.ShiftRGB
    shift_g: true
  - _target_: autoalbument.faster_autoaugment.models.policy_operations.ShiftRGB
    shift_b: true
  - _target_: autoalbument.faster_autoaugment.models.policy_operations.RandomBrightness
  - _target_: autoalbument.faster_autoaugment.models.policy_operations.RandomContrast
  - _target_: autoalbument.faster_autoaugment.models.policy_operations.Solarize
  - _target_: autoalbument.faster_autoaugment.models.policy_operations.HorizontalFlip
  - _target_: autoalbument.faster_autoaugment.models.policy_operations.VerticalFlip
  - _target_: autoalbument.faster_autoaugment.models.policy_operations.Rotate
  - _target_: autoalbument.faster_autoaugment.models.policy_operations.ShiftX
  - _target_: autoalbument.faster_autoaugment.models.policy_operations.ShiftY
  - _target_: autoalbument.faster_autoaugment.models.policy_operations.Scale
  - _target_: autoalbument.faster_autoaugment.models.policy_operations.CutoutFixedNumberOfHoles
  - _target_: autoalbument.faster_autoaugment.models.policy_operations.CutoutFixedSize
classification_model:
  _target_: autoalbument.faster_autoaugment.models.ClassificationModel
  num_classes: 100
  architecture: resnet18
  pretrained: false
data:
  dataset:
    _target_: dataset.SearchDataset
  input_dtype: uint8
  preprocessing: null
  normalization:
    mean:
    - 0.485
    - 0.456
    - 0.406
    std:
    - 0.229
    - 0.224
    - 0.225
  dataloader:
    _target_: torch.utils.data.DataLoader
    batch_size: 16
    shuffle: true
    num_workers: 8
    pin_memory: true
    drop_last: true
searcher:
  _target_: autoalbument.faster_autoaugment.search.FasterAutoAugmentSearcher
trainer:
  _target_: pytorch_lightning.Trainer
  gpus: 1
  benchmark: true
  max_epochs: 20
  resume_from_checkpoint: null
optim:
  main:
    _target_: torch.optim.Adam
    lr: 0.001
    betas:
    - 0
    - 0.999
  policy:
    _target_: torch.optim.Adam
    lr: 0.001
    betas:
    - 0
    - 0.999
callbacks:
- _target_: autoalbument.callbacks.MonitorAverageParameterChange
- _target_: autoalbument.callbacks.SavePolicy
- _target_: pytorch_lightning.callbacks.ModelCheckpoint
  save_last: true
  dirpath: checkpoints
logger:
  _target_: pytorch_lightning.loggers.TensorBoardLogger
  save_dir: /content/outputs/2021-04-18/13-18-49/tensorboard_logs
seed: 42
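Every _target_ in the config above is a dotted import path that Hydra resolves at run time. That is why the traceback starts in Hydra's _locate, and why a broken import deep inside pytorch_lightning surfaces as an error about loading FasterAutoAugmentSearcher. The following is a simplified sketch of that kind of resolution, not Hydra's actual code: it imports the longest importable module prefix and then walks the remaining parts with getattr; locate is a hypothetical name.

```python
import importlib
from typing import Any


def locate(path: str) -> Any:
    """Resolve a dotted target such as "pkg.module.Class" (simplified sketch)."""
    parts = path.split(".")
    for split in range(len(parts), 0, -1):
        module_name = ".".join(parts[:split])
        try:
            obj = importlib.import_module(module_name)
        except ModuleNotFoundError as exc:
            missing = exc.name or ""
            if module_name == missing or module_name.startswith(missing + "."):
                continue  # this prefix is not a module; retry with a shorter one
            raise  # the module exists but one of its own imports is missing
        # Any other ImportError raised while importing the module (such as the
        # torchtext "cannot import name" error) is not caught and propagates.
        for attr in parts[split:]:
            obj = getattr(obj, attr)  # walk classes/functions inside the module
        return obj
    raise ImportError(f"could not import any module prefix of {path!r}")
```

With this approach, locate("autoalbument.faster_autoaugment.search.FasterAutoAugmentSearcher") has to execute search.py, and therefore its from pytorch_lightning import seed_everything, before the class can be returned.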

Working directory: /content/outputs/2021-04-18/13-18-49
Traceback (most recent call last):
  File "/usr/local/lib/python3.7/dist-packages/hydra/_internal/utils.py", line 544, in _locate
    import_module(mod)
  File "/usr/lib/python3.7/importlib/__init__.py", line 127, in import_module
    return _bootstrap._gcd_import(name[level:], package, level)
  File "<frozen importlib._bootstrap>", line 1006, in _gcd_import
  File "<frozen importlib._bootstrap>", line 983, in _find_and_load
  File "<frozen importlib._bootstrap>", line 967, in _find_and_load_unlocked
  File "<frozen importlib._bootstrap>", line 677, in _load_unlocked
  File "<frozen importlib._bootstrap_external>", line 728, in exec_module
  File "<frozen importlib._bootstrap>", line 219, in _call_with_frames_removed
  File "/usr/local/lib/python3.7/dist-packages/autoalbument/faster_autoaugment/search.py", line 2, in <module>
    from pytorch_lightning import seed_everything
  File "/usr/local/lib/python3.7/dist-packages/pytorch_lightning/__init__.py", line 66, in <module>
    from pytorch_lightning import metrics
  File "/usr/local/lib/python3.7/dist-packages/pytorch_lightning/metrics/__init__.py", line 14, in <module>
    from pytorch_lightning.metrics.metric import Metric
  File "/usr/local/lib/python3.7/dist-packages/pytorch_lightning/metrics/metric.py", line 23, in <module>
    from pytorch_lightning.metrics.utils import _flatten, dim_zero_cat, dim_zero_mean, dim_zero_sum
  File "/usr/local/lib/python3.7/dist-packages/pytorch_lightning/metrics/utils.py", line 18, in <module>
    from pytorch_lightning.utilities import rank_zero_warn
  File "/usr/local/lib/python3.7/dist-packages/pytorch_lightning/utilities/__init__.py", line 24, in <module>
    from pytorch_lightning.utilities.apply_func import move_data_to_device
  File "/usr/local/lib/python3.7/dist-packages/pytorch_lightning/utilities/apply_func.py", line 25, in <module>
    from torchtext.data import Batch
ImportError: cannot import name 'Batch' from 'torchtext.data' (/usr/local/lib/python3.7/dist-packages/torchtext/data/__init__.py)

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/usr/local/lib/python3.7/dist-packages/hydra/utils.py", line 61, in call
    type_or_callable = _locate(cls)
  File "/usr/local/lib/python3.7/dist-packages/hydra/_internal/utils.py", line 548, in _locate
    ) from e
ImportError: Encountered error: `cannot import name 'Batch' from 'torchtext.data' (/usr/local/lib/python3.7/dist-packages/torchtext/data/__init__.py)` when loading module 'autoalbument.faster_autoaugment.search.FasterAutoAugmentSearcher'

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/usr/local/lib/python3.7/dist-packages/autoalbument/cli/search.py", line 54, in main
    searcher = instantiate(cfg.searcher, cfg=cfg)
  File "/usr/local/lib/python3.7/dist-packages/hydra/utils.py", line 70, in call
    raise HydraException(f"Error calling '{cls}' : {e}") from e
hydra.errors.HydraException: Error calling 'autoalbument.faster_autoaugment.search.FasterAutoAugmentSearcher' : Encountered error: `cannot import name 'Batch' from 'torchtext.data' (/usr/local/lib/python3.7/dist-packages/torchtext/data/__init__.py)` when loading module 'autoalbument.faster_autoaugment.search.FasterAutoAugmentSearcher'

Set the environment variable HYDRA_FULL_ERROR=1 for a complete stack trace.
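Each "The above exception was the direct cause of the following exception" line marks a raise ... from e, so the three tracebacks form one chain whose innermost link, the torchtext ImportError, is the one to fix; the HydraException and the second ImportError only wrap it. When such an error is caught in code, for example around the call that instantiates the searcher, a small helper can recover that innermost exception. exception_chain and root_cause below are hypothetical helpers that follow only __cause__ (explicit chaining), not __context__.

```python
from typing import List


def exception_chain(exc: BaseException) -> List[BaseException]:
    """Return exc followed by each exception it was raised `from`, outermost first."""
    chain = [exc]
    seen = {id(exc)}
    # Follow explicit `raise ... from e` links; stop on a cycle just in case
    while exc.__cause__ is not None and id(exc.__cause__) not in seen:
        exc = exc.__cause__
        seen.add(id(exc))
        chain.append(exc)
    return chain


def root_cause(exc: BaseException) -> BaseException:
    """The innermost explicitly chained exception, i.e. the one raised first."""
    return exception_chain(exc)[-1]
```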

pip install torchtext==0.8.1 solved the issue for me.
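Pinning works because the older torchtext release still defines Batch in torchtext.data. A quick check of whether an installed version is affected can be sketched as below, under the assumption (not stated in this thread) that the move to torchtext.legacy.data happened in 0.9.0; parse_version and batch_moved_to_legacy are hypothetical helpers.

```python
import re
from typing import Tuple


def parse_version(version: str) -> Tuple[int, ...]:
    """Parse the leading numeric part of a version string, e.g. "0.9.1+cpu" -> (0, 9, 1)."""
    match = re.match(r"\d+(?:\.\d+)*", version)
    if match is None:
        raise ValueError(f"unrecognised version string: {version!r}")
    return tuple(int(part) for part in match.group(0).split("."))


def batch_moved_to_legacy(torchtext_version: str) -> bool:
    """True if this torchtext release is assumed to no longer provide torchtext.data.Batch.

    Assumption: Batch moved to torchtext.legacy.data in 0.9.0.
    """
    return parse_version(torchtext_version) >= (0, 9)
```

On Python 3.8 or newer, the installed version can be passed in via importlib.metadata.version("torchtext").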