plstcharles / thelper

Training framework & tools for PyTorch-based machine learning projects.

Validate class names that are not strings

fmigneault opened this issue

When passing a dict[int, int], thelper will ingest it without raising anything (ClassNamesHandler will simply store the dict), but a later step that specifically checks for dict[str, int] will then raise. In the example below, this happens when the task is re-created while resuming the session, right after the dataloaders are initialized.

There should be a pre-check (or at least a warning) at Task creation that reports the invalid format immediately, instead of an assert much later. The later the error is raised, the harder it is to track down where the erroneous value came from.
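As a rough illustration of the proposed pre-check, here is a minimal sketch of a validation function that could be called at Task creation. The name validate_class_names is hypothetical (it is not part of thelper's API), and the sketch assumes the accepted formats are a list of string names or a dict[str, int] mapping names to indices.

```python
from typing import Any


def validate_class_names(class_names: Any) -> None:
    """Hypothetical pre-check to run at Task creation (not thelper's API).

    Raises immediately with a message listing the offending entries,
    instead of deferring to an assert much later in the pipeline.
    """
    if isinstance(class_names, dict):
        # For a dict, the names are the keys and the indices are the values.
        bad_names = [k for k in class_names if not isinstance(k, str)]
        bad_indices = [v for v in class_names.values() if not isinstance(v, int)]
        if bad_names or bad_indices:
            raise TypeError(
                "class_names must be dict[str, int]; "
                f"non-string names: {bad_names!r}, non-int indices: {bad_indices!r}"
            )
    elif isinstance(class_names, (list, tuple)):
        # Assumption: a plain sequence of names is also an accepted format.
        bad_names = [n for n in class_names if not isinstance(n, str)]
        if bad_names:
            raise TypeError(f"class_names must be strings; got non-string names: {bad_names!r}")
    else:
        raise TypeError(f"class_names must be a list or dict, got {type(class_names).__name__}")


# A valid mapping passes silently; integer IDs are rejected right away.
validate_class_names({"building": 0, "road": 1})
try:
    validate_class_names({235: 0, 231: 1})
except TypeError as err:
    print(err)
```

Raising a TypeError at construction time keeps the error next to the config entry that produced it, which is the point of the pre-check; a softer variant could call warnings.warn instead.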

In the special case of dict[int, int], thelper could also convert the labels automatically, since it is a pretty common scenario to have IDs instead of plain names.
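A minimal sketch of such an automatic conversion, assuming that stringifying the integer IDs with str() is an acceptable way to name the classes. convert_id_class_names is a hypothetical helper, not an existing thelper function; it warns instead of converting silently, and refuses any conversion that would merge two classes under the same name.

```python
import warnings
from typing import Dict, Mapping


def convert_id_class_names(class_names: Mapping) -> Dict[str, int]:
    """Hypothetical helper: turn integer class IDs into string class names.

    String keys are kept as-is; integer keys are converted with str() and a
    warning is emitted so the conversion is visible in the logs.
    """
    converted: Dict[str, int] = {}
    int_keys = []
    for key, index in class_names.items():
        # bool is a subclass of int, but True/False are not meaningful IDs.
        if isinstance(key, int) and not isinstance(key, bool):
            int_keys.append(key)
            key = str(key)
        if key in converted:
            # e.g. {1: 0, "1": 1} would collapse two classes into one name.
            raise ValueError(f"duplicate class name after conversion: {key!r}")
        converted[key] = index
    if int_keys:
        warnings.warn(f"converted {len(int_keys)} integer class IDs to string names", UserWarning)
    return converted
```

For instance, convert_id_class_names({235: 0, 231: 1}) returns {'235': 0, '231': 1}. Keys of other non-string types are left untouched, so a validation step run afterwards can still reject them.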

Here is an example of Task creation succeeding with dict[int, int], followed by a failure right after the dataloaders are instantiated.

[2020-10-14 15:18:02,554 - thelper.data.utils] DEBUG : loading dataset 'deepglobe' configuration...
[2020-10-14 15:18:04,161 - thelper.data.utils] WARNING : 'task' field detected in dataset 'deepglobe' config; dataset's default task will be ignored
[2020-10-14 15:18:04,163 - thelper.data.utils] INFO : parsed dataset: ginmodelrepo.util.BatchTestPatchesBaseSegDatasetLoader(transforms=thelper.transforms.composers.Compose(transforms=[
	thelper.transforms.wrappers.TransformWrapper(operation=thelper.transforms.operations.SelectChannels(channels={0: 0, 1: 1, 2: 2}), params={}, probability=1, convert_pil=False, target_keys=['data'], linked_fate=True),
	thelper.transforms.wrappers.TransformWrapper(operation=thelper.transforms.operations.CenterCrop(size=(128, 128), bordertype=0, borderval=0), params={}, probability=1, convert_pil=False, target_keys=None, linked_fate=True),
	thelper.transforms.wrappers.TransformWrapper(operation=thelper.transforms.operations.NormalizeMinMax(min=[0.], max=[255.], out_type=<class 'numpy.float32'>), params={}, probability=1, convert_pil=False, target_keys=['data'], linked_fate=True),
	thelper.transforms.wrappers.TransformWrapper(operation=thelper.transforms.operations.NormalizeZeroMeanUnitVar(mean=[0.485 0.456 0.406], std=[0.229 0.224 0.225], out_type=<class 'numpy.float32'>), params={}, probability=1, convert_pil=False, target_keys=['data'], linked_fate=True),
	thelper.transforms.wrappers.TransformWrapper(operation=thelper.transforms.operations.Transpose(axes=[2 0 1]), params={}, probability=1, convert_pil=False, target_keys=['data'], linked_fate=True)
]), deepcopy=False)
[2020-10-14 15:18:04,164 - thelper.data.utils] INFO : task info: thelper.tasks.segm.Segmentation(class_names={235: 0, 231: 1, 240: 2, 203: 3, 255: 4, 223: 5, 242: 6, 83: 7, 63: 8, 161: 9}, input_key='data', label_map_key='mask', meta_keys=[], dontcare=-1, color_map={})
[2020-10-14 15:18:04,164 - thelper.data.utils] DEBUG : splitting datasets and creating loaders...
[2020-10-14 15:18:04,164 - thelper.data.loaders] INFO : splitting datasets with parsed sizes = {'deepglobe': 631}
[2020-10-14 15:18:04,171 - thelper.data.loaders] INFO : initialized loaders with batch counts:
	train = 32
	valid = 8
Traceback (most recent call last):
  File "/usr/local/bin/thelper", line 8, in <module>
    sys.exit(main())
  File "/usr/local/lib/python3.6/dist-packages/thelper/cli.py", line 604, in main
    resume_session(ckptdata, save_dir, config=override_config, eval_only=args.eval_only, task_compat=args.task_compat)
  File "/usr/local/lib/python3.6/dist-packages/thelper/cli.py", line 113, in resume_session
    old_task = thelper.tasks.create_task(ckptdata["task"]) if isinstance(ckptdata["task"], str) else ckptdata["task"]
  File "/usr/local/lib/python3.6/dist-packages/thelper/tasks/utils.py", line 58, in create_task
    task = eval(config)
  File "<string>", line 1, in <module>
  File "/usr/local/lib/python3.6/dist-packages/thelper/tasks/segm.py", line 62, in __init__
    ClassNamesHandler.__init__(self, class_names=class_names)
  File "/usr/local/lib/python3.6/dist-packages/thelper/ifaces.py", line 92, in __init__
    self.class_names = class_names
  File "/usr/local/lib/python3.6/dist-packages/thelper/ifaces.py", line 128, in class_names
    assert all([isinstance(name, str) for name in class_names]), "all classes must be named with strings"
AssertionError: all classes must be named with strings
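The failing assert (thelper/ifaces.py, line 128 in the traceback) iterates over class_names directly. Iterating over a dict yields its keys, so with the integer IDs shown in the task info above, the check fails. The snippet below reproduces that expression standalone (it does not import thelper). It also shows that the integer keys survive a round trip through repr() and eval(), which is consistent with the task being rebuilt from its string form via eval(config) in create_task.

```python
# Class IDs as logged in the task info for the 'deepglobe' dataset.
class_names = {235: 0, 231: 1, 240: 2, 203: 3, 255: 4, 223: 5, 242: 6, 83: 7, 63: 8, 161: 9}

# Same expression as the assert in thelper/ifaces.py (line 128 in the traceback).
# Iterating a dict yields its keys, i.e. the integer IDs here.
all_str = all([isinstance(name, str) for name in class_names])
print(all_str)  # False

# With string keys, the same check passes.
str_class_names = {str(k): v for k, v in class_names.items()}
print(all([isinstance(name, str) for name in str_class_names]))  # True

# Rebuilding from the string form keeps the int keys, so the check fails again.
print(eval(repr(class_names)) == class_names)  # True
```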