Lightning-Universe / lightning-flash

Your PyTorch AI Factory - Flash enables you to easily configure and run complex AI recipes for over 15 tasks across 7 data domains

Home Page: https://lightning-flash.readthedocs.io

`ImageClassificationDataFrameInput` object has no attribute `target_formatter`

giwook-david opened this issue

πŸ› Bug

When I try to create a DataModule from a DataFrame, I get this error message:

Input In [13], in <cell line: 1>()
----> 1 datamodule = ImageClassificationData.from_data_frame(input_field="img_local_path", target_field="labels", train_data_frame=df, target_formatter=CommaDelimitedMultiLabelTargetFormatter(labels=labels))

File ~/.local/share/virtualenvs/yesplz-model-zoo-eXaYZlXU/lib/python3.10/site-packages/flash/image/classification/data.py:698, in ImageClassificationData.from_data_frame(cls, input_field, target_fields, train_data_frame, train_images_root, train_resolver, val_data_frame, val_images_root, val_resolver, test_data_frame, test_images_root, test_resolver, predict_data_frame, predict_images_root, predict_resolver, target_formatter, input_cls, transform, transform_kwargs, **data_module_kwargs)
    695 test_data = (test_data_frame, input_field, target_fields, test_images_root, test_resolver)
    696 predict_data = (predict_data_frame, input_field, None, predict_images_root, predict_resolver)
--> 698 train_input = input_cls(RunningStage.TRAINING, *train_data, **ds_kw)
    699 ds_kw["target_formatter"] = getattr(train_input, "target_formatter", None)
    701 return cls(
    702     train_input,
    703     input_cls(RunningStage.VALIDATING, *val_data, **ds_kw),
   (...)
    708     **data_module_kwargs,
    709 )

File ~/.local/share/virtualenvs/yesplz-model-zoo-eXaYZlXU/lib/python3.10/site-packages/flash/core/data/io/input.py:135, in _wrap_init.<locals>.wrapper(self, *args, **kwargs)
    133 @functools.wraps(fn)
    134 def wrapper(self, *args, **kwargs):
--> 135     fn(self, *args, **kwargs)
    136     _validate_input(self)

File ~/.local/share/virtualenvs/yesplz-model-zoo-eXaYZlXU/lib/python3.10/site-packages/flash/core/data/io/input.py:176, in InputBase.__init__(self, running_stage, *args, **kwargs)
    174 self.data = None
    175 if len(args) >= 1 and args[0] is not None:
--> 176     self.data = getattr(self, f"{_STAGES_PREFIX[running_stage]}_load_data")(*args, **kwargs)

File ~/.local/share/virtualenvs/yesplz-model-zoo-eXaYZlXU/lib/python3.10/site-packages/flash/core/data/io/input.py:212, in InputBase.train_load_data(self, *args, **kwargs)
    205 def train_load_data(self, *args: Any, **kwargs: Any) -> Union[Sequence, Iterable]:
    206     """Override the ``train_load_data`` hook with data loading logic that is only required during training.
    207
    208     Args:
    209         *args: Any arguments that the input requires.
    210         **kwargs: Any additional keyword arguments that the input requires.
    211     """
--> 212     return self.load_data(*args, **kwargs)

File ~/.local/share/virtualenvs/yesplz-model-zoo-eXaYZlXU/lib/python3.10/site-packages/flash/image/classification/input.py:163, in ImageClassificationDataFrameInput.load_data(self, data_frame, input_key, target_keys, root, resolver, target_formatter)
    158 result = super().load_data(files, targets, target_formatter=target_formatter)
    160 # If we had binary multi-class targets then we also know the labels (column names)
    161 if (
    162     self.training
--> 163     and isinstance(self.target_formatter, MultiBinaryTargetFormatter)
    164     and isinstance(target_keys, List)
    165 ):
    166     self.labels = target_keys
    168 return result

AttributeError: 'ImageClassificationDataFrameInput' object has no attribute 'target_formatter'
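The final `AttributeError` is ordinary Python attribute lookup failing: an instance attribute only exists after something assigns it. The sketch below reproduces that pattern without Flash, using hypothetical stand-in classes (`LoaderBase`, `DataFrameLoader`) that are not Flash's actual classes. It also shows why `getattr(train_input, "target_formatter", None)` on line 699 of data.py would not raise on its own: `getattr` with a default returns the default when the attribute is missing.

```python
class LoaderBase:
    """Hypothetical base input: accepts a formatter but never stores it on self."""

    def __init__(self, *args):
        self.training = True
        self.labels = None
        self.data = self.load_data(*args)

    def load_data(self, targets, target_formatter=None):
        # target_formatter is only a local argument; self.target_formatter is never set
        return list(targets)


class DataFrameLoader(LoaderBase):
    """Hypothetical subclass that reads an attribute nothing has assigned."""

    def load_data(self, targets, target_formatter=None):
        result = super().load_data(targets, target_formatter=target_formatter)
        # Instance attribute lookup fails here, mirroring line 163 in the traceback
        if self.training and self.target_formatter is not None:
            self.labels = list(targets)
        return result


def build_error_message():
    """Return the AttributeError text raised while constructing DataFrameLoader."""
    try:
        DataFrameLoader(["cat", "dog"])
    except AttributeError as err:
        return str(err)
    return None


print(build_error_message())
# 'DataFrameLoader' object has no attribute 'target_formatter'

# getattr with a default (the data.py line 699 pattern) returns None instead of raising
print(getattr(LoaderBase(["cat"]), "target_formatter", None))
# None
```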

I then checked `ImageClassificationDataFrameInput` in flash/image/classification/input.py.

Starting at line 160, the code uses `self.target_formatter`, which is not referenced anywhere else in input.py outside `ImageClassificationDataFrameInput`.
I think this check should use the `target_formatter` argument instead of `self.target_formatter`:

# If we had binary multi-class targets then we also know the labels (column names)
if (
    self.training
    and isinstance(self.target_formatter, MultiBinaryTargetFormatter)
    and isinstance(target_keys, List)
):
    self.labels = target_keys
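A minimal sketch of the proposed change, assuming the local `target_formatter` argument is the value the check is meant to inspect. `MultiBinaryTargetFormatter` and `DataFrameInputSketch` below are simplified hypothetical stand-ins, not Flash's actual classes, so this only illustrates the control flow of the fix:

```python
from typing import List, Optional


class MultiBinaryTargetFormatter:
    """Hypothetical stand-in for Flash's formatter of the same name."""


class DataFrameInputSketch:
    """Hypothetical stand-in for ImageClassificationDataFrameInput with the proposed fix."""

    def __init__(self, training: bool = True):
        self.training = training
        self.labels: Optional[List[str]] = None

    def load_data(self, files, target_keys, target_formatter=None):
        result = list(files)  # stands in for super().load_data(...)
        # Proposed fix: test the local argument, not self.target_formatter
        if (
            self.training
            and isinstance(target_formatter, MultiBinaryTargetFormatter)
            and isinstance(target_keys, list)
        ):
            # Binary multi-label columns: the column names are the labels
            self.labels = target_keys
        return result


sketch = DataFrameInputSketch()
sketch.load_data(["a.jpg"], ["cat", "dog"], MultiBinaryTargetFormatter())
print(sketch.labels)  # ['cat', 'dog']
```

Using the builtin `list` in `isinstance` matches the original `typing.List` check for plain lists. One caveat, stated as an assumption: if the installed Flash version infers a formatter when none is passed, checking only the argument would miss the inferred one, so the fix should be verified against the version in use.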

@giwook-david what version are you using?