Validate class names that are not strings
fmigneault opened this issue
When passing a `dict[int, int]`, thelper ingests it without raising anything (`ClassNamesHandler` simply stores the dict), but this later raises an error during dataloader execution, which specifically checks for `dict[str, int]`.
There should be a pre-check (or at least a warning) immediately at `Task` creation that reports the invalid format, instead of an assertion much later. The later the error is raised, the harder it is to track where the erroneous value came from.
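A minimal sketch of such an early check, assuming class names arrive either as a list of names or as a `dict` mapping names to indices (the `dict[str, int]` format mentioned above). `validate_class_names` is hypothetical and not part of thelper; it only illustrates failing fast with a message that names the offending entries.

```python
from typing import Any, Dict, List, Union


def validate_class_names(class_names: Union[List[Any], Dict[Any, int]]) -> None:
    """Hypothetical pre-check: raise immediately if any class name is not a string."""
    # Iterating over a dict yields its keys, i.e. the class names themselves.
    bad = [name for name in class_names if not isinstance(name, str)]
    if bad:
        types = sorted({type(name).__name__ for name in bad})
        raise TypeError(
            "all class names must be strings, got non-string names %r (types: %s)"
            % (bad, types)
        )


# Passes: names are strings mapped to indices.
validate_class_names({"background": 0, "foreground": 1})

# Fails right away, at task-construction time, instead of much later.
try:
    validate_class_names({235: 0, 231: 1})
except TypeError as err:
    print(err)
```

Raising a `TypeError` with the offending values at construction time keeps the error next to the config entry that produced it.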
In the special case of `dict[int, int]`, thelper could also convert the labels automatically, since having IDs instead of plain names is a pretty common scenario.
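A minimal sketch of that conversion, assuming the integer keys are label IDs whose string form is an acceptable class name. `normalize_class_names` is hypothetical; it warns so the implicit conversion stays visible, and it rejects collisions such as `"235"` and `235` both being present.

```python
import warnings
from typing import Dict, Hashable


def normalize_class_names(class_names: Dict[Hashable, int]) -> Dict[str, int]:
    """Hypothetical helper: convert integer class IDs to string names, keeping indices."""
    converted = {}
    for key, index in class_names.items():
        if isinstance(key, str):
            name = key
        elif isinstance(key, int) and not isinstance(key, bool):
            name = str(key)  # e.g. label ID 235 becomes class name "235"
        else:
            raise TypeError("unsupported class name type: %s" % type(key).__name__)
        if name in converted:
            # A str key "235" and an int key 235 would otherwise silently merge.
            raise ValueError("duplicate class name after conversion: %r" % name)
        converted[name] = index
    if any(not isinstance(key, str) for key in class_names):
        warnings.warn("integer class IDs were converted to string class names")
    return converted


print(normalize_class_names({235: 0, 231: 1, 240: 2}))  # {'235': 0, '231': 1, '240': 2}
```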
Here is an example of `Task` creation working with `dict[int, int]` and failing right after, at dataloader instantiation:
```
[2020-10-14 15:18:02,554 - thelper.data.utils] DEBUG : loading dataset 'deepglobe' configuration...
[2020-10-14 15:18:04,161 - thelper.data.utils] WARNING : 'task' field detected in dataset 'deepglobe' config; dataset's default task will be ignored
[2020-10-14 15:18:04,163 - thelper.data.utils] INFO : parsed dataset: ginmodelrepo.util.BatchTestPatchesBaseSegDatasetLoader(transforms=thelper.transforms.composers.Compose(transforms=[
thelper.transforms.wrappers.TransformWrapper(operation=thelper.transforms.operations.SelectChannels(channels={0: 0, 1: 1, 2: 2}), params={}, probability=1, convert_pil=False, target_keys=['data'], linked_fate=True),
thelper.transforms.wrappers.TransformWrapper(operation=thelper.transforms.operations.CenterCrop(size=(128, 128), bordertype=0, borderval=0), params={}, probability=1, convert_pil=False, target_keys=None, linked_fate=True),
thelper.transforms.wrappers.TransformWrapper(operation=thelper.transforms.operations.NormalizeMinMax(min=[0.], max=[255.], out_type=<class 'numpy.float32'>), params={}, probability=1, convert_pil=False, target_keys=['data'], linked_fate=True),
thelper.transforms.wrappers.TransformWrapper(operation=thelper.transforms.operations.NormalizeZeroMeanUnitVar(mean=[0.485 0.456 0.406], std=[0.229 0.224 0.225], out_type=<class 'numpy.float32'>), params={}, probability=1, convert_pil=False, target_keys=['data'], linked_fate=True),
thelper.transforms.wrappers.TransformWrapper(operation=thelper.transforms.operations.Transpose(axes=[2 0 1]), params={}, probability=1, convert_pil=False, target_keys=['data'], linked_fate=True)
]), deepcopy=False)
[2020-10-14 15:18:04,164 - thelper.data.utils] INFO : task info: thelper.tasks.segm.Segmentation(class_names={235: 0, 231: 1, 240: 2, 203: 3, 255: 4, 223: 5, 242: 6, 83: 7, 63: 8, 161: 9}, input_key='data', label_map_key='mask', meta_keys=[], dontcare=-1, color_map={})
[2020-10-14 15:18:04,164 - thelper.data.utils] DEBUG : splitting datasets and creating loaders...
[2020-10-14 15:18:04,164 - thelper.data.loaders] INFO : splitting datasets with parsed sizes = {'deepglobe': 631}
[2020-10-14 15:18:04,171 - thelper.data.loaders] INFO : initialized loaders with batch counts:
  train = 32
  valid = 8
```
```
Traceback (most recent call last):
  File "/usr/local/bin/thelper", line 8, in <module>
    sys.exit(main())
  File "/usr/local/lib/python3.6/dist-packages/thelper/cli.py", line 604, in main
    resume_session(ckptdata, save_dir, config=override_config, eval_only=args.eval_only, task_compat=args.task_compat)
  File "/usr/local/lib/python3.6/dist-packages/thelper/cli.py", line 113, in resume_session
    old_task = thelper.tasks.create_task(ckptdata["task"]) if isinstance(ckptdata["task"], str) else ckptdata["task"]
  File "/usr/local/lib/python3.6/dist-packages/thelper/tasks/utils.py", line 58, in create_task
    task = eval(config)
  File "<string>", line 1, in <module>
  File "/usr/local/lib/python3.6/dist-packages/thelper/tasks/segm.py", line 62, in __init__
    ClassNamesHandler.__init__(self, class_names=class_names)
  File "/usr/local/lib/python3.6/dist-packages/thelper/ifaces.py", line 92, in __init__
    self.class_names = class_names
  File "/usr/local/lib/python3.6/dist-packages/thelper/ifaces.py", line 128, in class_names
    assert all([isinstance(name, str) for name in class_names]), "all classes must be named with strings"
AssertionError: all classes must be named with strings
```
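To illustrate the failing check in isolation: `names_are_strings` is a hypothetical stand-in that copies the assertion expression from `ifaces.py` line 128 in the traceback, and `ast.literal_eval` is used here as a safe substitute for the `eval(config)` call in `create_task`. Re-parsing the `class_names` repr printed in the task info keeps the integer keys, and since iterating over a `dict` yields its keys, the check fails.

```python
import ast

# First three entries of the class_names printed in the "task info" log line.
class_names_repr = "{235: 0, 231: 1, 240: 2}"

# Re-parsing the repr (a safe stand-in for the eval() in create_task) keeps int keys.
class_names = ast.literal_eval(class_names_repr)


def names_are_strings(names) -> bool:
    # Same expression as the assertion in the traceback: iterating a dict yields its keys.
    return all([isinstance(name, str) for name in names])


print(names_are_strings(class_names))  # False: keys are ints, so the assertion fails
print(names_are_strings({str(k): v for k, v in class_names.items()}))  # True
```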