ImportError within FasterAutoAugmentSearcher
adam-mehdi opened this issue
Hello,
AutoAlbument's FasterAutoAugmentSearcher raises an import error when searching for a policy: cannot import name 'Batch' from 'torchtext.data'. The same bug recently appeared in PyTorch Lightning in general, and it is fixed by installing Lightning from GitHub:
pip install git+https://github.com/PyTorchLightning/pytorch-lightning
instead of:
pip install pytorch-lightning
I suspect that the version of Lightning that FasterAutoAugmentSearcher depends on must be upgraded.
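The failing line in the traceback below is Lightning's own "from torchtext.data import Batch". torchtext 0.9.0 moved Batch and the other legacy classes under torchtext.legacy.data, so that import breaks on newer torchtext. The following is a minimal sketch of a version-tolerant import of the kind a newer Lightning needs; it is an illustration, not Lightning's actual patch.

```python
# Sketch of a version-tolerant import of torchtext's Batch class
# (illustrative only, not the code that ships in PyTorch Lightning).
try:
    # torchtext 0.9 moved Batch and the other legacy classes here
    from torchtext.legacy.data import Batch
except ImportError:
    try:
        # Older torchtext (< 0.9) still exposes Batch at its original location
        from torchtext.data import Batch
    except ImportError:
        # torchtext is not installed, or this release no longer ships Batch
        Batch = None

print("torchtext Batch available:", Batch is not None)
```

Falling back to None lets code that only needs Batch for an isinstance check keep working when torchtext is absent.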
Here's an example of the problem using the CIFAR100 dataset.
AutoAlbument Search
!pip install -U git+https://github.com/albumentations-team/autoalbument
!autoalbument-create --config-dir /content/ --task classification --num-classes 100
!autoalbument-search --config-dir /content
dataset.py
import torchvision
from torchvision.datasets import CIFAR100

# The class name must match data.dataset._target_ in the config (dataset.SearchDataset)
class SearchDataset(CIFAR100):
    def __init__(self, root="content/cifar100", train=True, download=True, transform=None):
        super().__init__(root=root, train=train, download=download, transform=transform)

    def __len__(self):
        return len(self.targets)

    def __getitem__(self, index):
        # self.data holds HxWxC uint8 NumPy arrays, which Albumentations expects
        image, label = self.data[index], self.targets[index]
        if self.transform is not None:
            # Albumentations transforms take keyword arguments and return a dict
            transformed = self.transform(image=image)
            image = transformed["image"]
        return image, label
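The dataset only has to honor one contract: return (image, label) and call the transform Albumentations-style, with a keyword image argument and a dict result. A minimal sketch that checks that contract without downloading CIFAR100; SyntheticSearchDataset and hflip are hypothetical names used only for illustration.

```python
class SyntheticSearchDataset:
    """Hypothetical stand-in for SearchDataset: same __getitem__ logic, no download."""

    def __init__(self, data, targets, transform=None):
        self.data = data          # list of H x W images (nested lists here)
        self.targets = targets    # list of integer class labels
        self.transform = transform

    def __len__(self):
        return len(self.targets)

    def __getitem__(self, index):
        image, label = self.data[index], self.targets[index]
        if self.transform is not None:
            # Albumentations-style call: keyword argument in, dict out
            image = self.transform(image=image)["image"]
        return image, label


def hflip(image):
    """Hypothetical transform following the Albumentations calling convention."""
    return {"image": [row[::-1] for row in image]}


ds = SyntheticSearchDataset(data=[[[1, 2], [3, 4]]], targets=[7], transform=hflip)
image, label = ds[0]
print(image, label)  # [[2, 1], [4, 3]] 7
```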
Search Output
_version: 2
task: classification
policy_model:
  task_factor: 0.1
  gp_factor: 10
  temperature: 0.05
  num_sub_policies: 40
  num_chunks: 4
  operation_count: 4
  operations:
  - _target_: autoalbument.faster_autoaugment.models.policy_operations.ShiftRGB
    shift_r: true
  - _target_: autoalbument.faster_autoaugment.models.policy_operations.ShiftRGB
    shift_g: true
  - _target_: autoalbument.faster_autoaugment.models.policy_operations.ShiftRGB
    shift_b: true
  - _target_: autoalbument.faster_autoaugment.models.policy_operations.RandomBrightness
  - _target_: autoalbument.faster_autoaugment.models.policy_operations.RandomContrast
  - _target_: autoalbument.faster_autoaugment.models.policy_operations.Solarize
  - _target_: autoalbument.faster_autoaugment.models.policy_operations.HorizontalFlip
  - _target_: autoalbument.faster_autoaugment.models.policy_operations.VerticalFlip
  - _target_: autoalbument.faster_autoaugment.models.policy_operations.Rotate
  - _target_: autoalbument.faster_autoaugment.models.policy_operations.ShiftX
  - _target_: autoalbument.faster_autoaugment.models.policy_operations.ShiftY
  - _target_: autoalbument.faster_autoaugment.models.policy_operations.Scale
  - _target_: autoalbument.faster_autoaugment.models.policy_operations.CutoutFixedNumberOfHoles
  - _target_: autoalbument.faster_autoaugment.models.policy_operations.CutoutFixedSize
classification_model:
  _target_: autoalbument.faster_autoaugment.models.ClassificationModel
  num_classes: 100
  architecture: resnet18
  pretrained: false
data:
  dataset:
    _target_: dataset.SearchDataset
  input_dtype: uint8
  preprocessing: null
  normalization:
    mean:
    - 0.485
    - 0.456
    - 0.406
    std:
    - 0.229
    - 0.224
    - 0.225
  dataloader:
    _target_: torch.utils.data.DataLoader
    batch_size: 16
    shuffle: true
    num_workers: 8
    pin_memory: true
    drop_last: true
searcher:
  _target_: autoalbument.faster_autoaugment.search.FasterAutoAugmentSearcher
trainer:
  _target_: pytorch_lightning.Trainer
  gpus: 1
  benchmark: true
  max_epochs: 20
  resume_from_checkpoint: null
optim:
  main:
    _target_: torch.optim.Adam
    lr: 0.001
    betas:
    - 0
    - 0.999
  policy:
    _target_: torch.optim.Adam
    lr: 0.001
    betas:
    - 0
    - 0.999
callbacks:
- _target_: autoalbument.callbacks.MonitorAverageParameterChange
- _target_: autoalbument.callbacks.SavePolicy
- _target_: pytorch_lightning.callbacks.ModelCheckpoint
  save_last: true
  dirpath: checkpoints
logger:
  _target_: pytorch_lightning.loggers.TensorBoardLogger
  save_dir: /content/outputs/2021-04-18/13-18-49/tensorboard_logs
seed: 42
Working directory: /content/outputs/2021-04-18/13-18-49
Traceback (most recent call last):
File "/usr/local/lib/python3.7/dist-packages/hydra/_internal/utils.py", line 544, in _locate
import_module(mod)
File "/usr/lib/python3.7/importlib/__init__.py", line 127, in import_module
return _bootstrap._gcd_import(name[level:], package, level)
File "<frozen importlib._bootstrap>", line 1006, in _gcd_import
File "<frozen importlib._bootstrap>", line 983, in _find_and_load
File "<frozen importlib._bootstrap>", line 967, in _find_and_load_unlocked
File "<frozen importlib._bootstrap>", line 677, in _load_unlocked
File "<frozen importlib._bootstrap_external>", line 728, in exec_module
File "<frozen importlib._bootstrap>", line 219, in _call_with_frames_removed
File "/usr/local/lib/python3.7/dist-packages/autoalbument/faster_autoaugment/search.py", line 2, in <module>
from pytorch_lightning import seed_everything
File "/usr/local/lib/python3.7/dist-packages/pytorch_lightning/__init__.py", line 66, in <module>
from pytorch_lightning import metrics
File "/usr/local/lib/python3.7/dist-packages/pytorch_lightning/metrics/__init__.py", line 14, in <module>
from pytorch_lightning.metrics.metric import Metric
File "/usr/local/lib/python3.7/dist-packages/pytorch_lightning/metrics/metric.py", line 23, in <module>
from pytorch_lightning.metrics.utils import _flatten, dim_zero_cat, dim_zero_mean, dim_zero_sum
File "/usr/local/lib/python3.7/dist-packages/pytorch_lightning/metrics/utils.py", line 18, in <module>
from pytorch_lightning.utilities import rank_zero_warn
File "/usr/local/lib/python3.7/dist-packages/pytorch_lightning/utilities/__init__.py", line 24, in <module>
from pytorch_lightning.utilities.apply_func import move_data_to_device
File "/usr/local/lib/python3.7/dist-packages/pytorch_lightning/utilities/apply_func.py", line 25, in <module>
from torchtext.data import Batch
ImportError: cannot import name 'Batch' from 'torchtext.data' (/usr/local/lib/python3.7/dist-packages/torchtext/data/__init__.py)
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/usr/local/lib/python3.7/dist-packages/hydra/utils.py", line 61, in call
type_or_callable = _locate(cls)
File "/usr/local/lib/python3.7/dist-packages/hydra/_internal/utils.py", line 548, in _locate
) from e
ImportError: Encountered error: `cannot import name 'Batch' from 'torchtext.data' (/usr/local/lib/python3.7/dist-packages/torchtext/data/__init__.py)` when loading module 'autoalbument.faster_autoaugment.search.FasterAutoAugmentSearcher'
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/usr/local/lib/python3.7/dist-packages/autoalbument/cli/search.py", line 54, in main
searcher = instantiate(cfg.searcher, cfg=cfg)
File "/usr/local/lib/python3.7/dist-packages/hydra/utils.py", line 70, in call
raise HydraException(f"Error calling '{cls}' : {e}") from e
hydra.errors.HydraException: Error calling 'autoalbument.faster_autoaugment.search.FasterAutoAugmentSearcher' : Encountered error: `cannot import name 'Batch' from 'torchtext.data' (/usr/local/lib/python3.7/dist-packages/torchtext/data/__init__.py)` when loading module 'autoalbument.faster_autoaugment.search.FasterAutoAugmentSearcher'
Set the environment variable HYDRA_FULL_ERROR=1 for a complete stack trace.
pip install torchtext==0.8.1 solved the issue for me.
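Pinning torchtext works because 0.8.1 predates the 0.9.0 release that moved Batch into torchtext.legacy.data, so Lightning's "from torchtext.data import Batch" still succeeds. A small sketch for checking an installed version against that boundary; release_tuple and torchtext_data_has_batch are hypothetical helper names, and in a live environment you would pass torchtext.__version__.

```python
import re


def release_tuple(version: str) -> tuple:
    """Return the numeric (major, minor, patch) part of a version such as '0.9.1+cpu'."""
    parts = []
    for piece in version.split("+")[0].split(".")[:3]:
        match = re.match(r"\d+", piece)
        parts.append(int(match.group()) if match else 0)
    parts += [0] * (3 - len(parts))  # pad short versions like "0.10"
    return tuple(parts)


def torchtext_data_has_batch(version: str) -> bool:
    """True if 'from torchtext.data import Batch' is expected to work for this version."""
    # torchtext 0.9.0 moved Batch into torchtext.legacy.data
    return release_tuple(version) < (0, 9, 0)


print(torchtext_data_has_batch("0.8.1"))      # True
print(torchtext_data_has_batch("0.9.1+cpu"))  # False
```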