Project-MONAI / MONAI

AI Toolkit for Healthcare Imaging

Home Page: https://monai.io/

RandWeightedCropd in batch

7oud opened this issue

Describe the bug
When using RandWeightedCropd to randomly crop patches from images of different sizes, the program crashes because the weight maps (weight_map) have different shapes.

To Reproduce

  1. After resampling with Spacingd, image 1 has shape (1, 258, 358, 358) and image 2 has shape (1, 245, 424, 424).
  2. Load the weight maps, which have the same size as their corresponding images: (1, 258, 358, 358) and (1, 245, 424, 424).
  3. Use RandWeightedCropd to crop patches of fixed size (224, 224, 224).
  4. Use DataLoader(ds, batch_size=2, collate_fn=pad_list_data_collate); the program crashes because of the different sizes of weight_map.
  5. Using RandSpatialCropSamplesd without a weight map works fine.
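Steps 3 and 4 depend on how a weighted crop picks its location. As a rough illustration only, here is a pure-Python 2-D sketch of weighted patch-center sampling: a center is drawn with probability proportional to the weight map, restricted to positions where the whole patch fits. The helper names (sample_weighted_center, crop) are hypothetical; MONAI's actual RandWeightedCropd implementation (3-D, border handling, random state) is not reproduced here.

```python
import random
from typing import List, Tuple

Grid = List[List[float]]


def sample_weighted_center(weights: Grid, patch: Tuple[int, int], rng: random.Random) -> Tuple[int, int]:
    """Draw a patch center with probability proportional to `weights`,
    considering only centers where the whole patch stays inside the grid."""
    rows, cols = len(weights), len(weights[0])
    ph, pw = patch
    candidates: List[Tuple[int, int]] = []
    probs: List[float] = []
    # A center c covers [c - p // 2, c - p // 2 + p), so it must satisfy
    # p // 2 <= c <= size - p + p // 2 along each axis.
    for r in range(ph // 2, rows - ph + ph // 2 + 1):
        for c in range(pw // 2, cols - pw + pw // 2 + 1):
            candidates.append((r, c))
            probs.append(weights[r][c])
    if not candidates or sum(probs) <= 0:
        raise ValueError("no admissible center with positive weight")
    return rng.choices(candidates, weights=probs, k=1)[0]


def crop(grid: Grid, center: Tuple[int, int], patch: Tuple[int, int]) -> Grid:
    """Cut the patch-sized window around `center` out of `grid`."""
    r0 = center[0] - patch[0] // 2
    c0 = center[1] - patch[1] // 2
    return [row[c0:c0 + patch[1]] for row in grid[r0:r0 + patch[0]]]


# Toy 6x6 "image" and a weight map with a single non-zero entry.
image = [[float(10 * r + c) for c in range(6)] for r in range(6)]
weights = [[0.0] * 6 for _ in range(6)]
weights[4][2] = 1.0

center = sample_weighted_center(weights, (3, 3), random.Random(0))
print(center)                       # (4, 2): the only admissible center with weight > 0
print(crop(image, center, (3, 3)))  # rows 3-5, columns 1-3 of the image
```

Because one center is drawn per sample, the same window can be cut from every array that must stay aligned with the image, for example crop(weights, center, (3, 3)).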
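
import os

from monai.data import DataLoader, Dataset, pad_list_data_collate
from monai.transforms import (
    Compose,
    EnsureChannelFirstd,
    LoadImaged,
    Orientationd,
    RandSpatialCropSamplesd,
    RandWeightedCropd,
    Spacingd,
)

root = "/path/to/data"  # placeholder: the data directory is not given in the report


def get_transforms():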
    transforms = [
        LoadImaged(keys=['image']),
        EnsureChannelFirstd(keys=['image']),
        Orientationd(keys=['image'], axcodes='SPL'),
        Spacingd(keys=['image'], pixdim=[0.5, 0.5, 0.5], mode=["bilinear"]),
        LoadImaged(keys=['wgtmap']),
        EnsureChannelFirstd(keys=['wgtmap']),
        RandWeightedCropd(
            keys=['image'],
            spatial_size=(224, 224, 224),
            num_samples=1,
            w_key='wgtmap',
        ),
        # RandSpatialCropSamplesd(
        #     keys=['image'],
        #     roi_size=(224, 224, 224),
        #     num_samples=1,
        # ),
    ]

    return Compose(transforms)

data1 = {
    'image': os.path.join(root, 'image035.nii.gz'),
    'wgtmap': os.path.join(root, 'image035-wgt.npy'),
}
data2 = {
    'image': os.path.join(root, '10106129_img-orig.nii.gz'),
    'wgtmap': os.path.join(root, '10106129_img-orig-wgt.npy'),
}

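def main():
    trans = get_transforms()
    ds = Dataset(data=[data1, data2], transform=trans)
    # num_workers > 0 spawns worker processes, so the entry point needs a __main__ guard on macOS/Windows
    dl = DataLoader(ds, batch_size=2, num_workers=2, collate_fn=pad_list_data_collate)

    for i, batch_data in enumerate(dl):
        inputs = batch_data["image"]
        print(inputs.shape)
        wgtmap = batch_data["wgtmap"]
        print(wgtmap.shape)


if __name__ == "__main__":
    main()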
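
In the script above, only 'image' is listed in keys= of RandWeightedCropd, while 'wgtmap' is passed only as w_key. MONAI dictionary transforms modify only the entries named in keys, which is consistent with the log below, where the shapes that fail to stack are the full, uncropped weight-map shapes. A minimal pure-Python sketch of that shape bookkeeping, using the shapes from the report; crop_listed_keys and keys_that_cannot_stack are hypothetical helpers that track shapes only and are not MONAI APIs. The second case assumes, without having been checked against MONAI, what would happen if 'wgtmap' were also listed in keys.

```python
from typing import Dict, List, Tuple

Shape = Tuple[int, ...]


def crop_listed_keys(sample: Dict[str, Shape], keys: List[str], roi: Shape) -> Dict[str, Shape]:
    """Shape-only stand-in for a dictionary crop: only entries named in `keys` change."""
    out = dict(sample)
    for key in keys:
        channels = sample[key][0]  # the channel dim is kept, spatial dims become `roi`
        out[key] = (channels, *roi)
    return out


def keys_that_cannot_stack(batch: List[Dict[str, Shape]]) -> List[str]:
    """Keys whose shapes differ across the batch, i.e. where torch.stack would fail."""
    return [key for key in batch[0] if len({item[key] for item in batch}) > 1]


# Shapes after Spacingd, from the report; each weight map matches its image.
loaded = [{"image": s, "wgtmap": s} for s in [(1, 258, 358, 358), (1, 245, 424, 424)]]

cropped = [crop_listed_keys(item, ["image"], (224, 224, 224)) for item in loaded]
print(keys_that_cannot_stack(cropped))       # ['wgtmap']

cropped_both = [crop_listed_keys(item, ["image", "wgtmap"], (224, 224, 224)) for item in loaded]
print(keys_that_cannot_stack(cropped_both))  # []
```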

OUTPUT

collate dict key "image" out of 2 keys

collate/stack a list of tensors
collate dict key "wgtmap" out of 2 keys
collate/stack a list of tensors
E: stack expects each tensor to be equal size, but got [1, 258, 358, 358] at entry 0 and [1, 245, 424, 424] at entry 1, shape [torch.Size([1, 258, 358, 358]), torch.Size([1, 245, 424, 424])] in collate([metatensor([[[[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],

Traceback (most recent call last):
File "/Users/z/repo/github/7oud/tst_py_store/monai_store/transform copy.py", line 83, in <module>
main()
File "/Users/xxx/repo/github/7oud/tst_py_store/monai_store/transform copy.py", line 68, in main
for i, batch_data in enumerate(dl):
File "/Users/xxx/opt/miniconda3/envs/monai/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 628, in __next__
data = self._next_data()
File "/Users/xxx/opt/miniconda3/envs/monai/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1333, in _next_data
return self._process_data(data)
File "/Users/xxx/opt/miniconda3/envs/monai/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1359, in _process_data
data.reraise()
File "/Users/xxx/opt/miniconda3/envs/monai/lib/python3.9/site-packages/torch/_utils.py", line 543, in reraise
raise exception
RuntimeError: Caught RuntimeError in DataLoader worker process 0.
Original Traceback (most recent call last):
File "/Users/xxx/opt/miniconda3/envs/monai/lib/python3.9/site-packages/monai/data/utils.py", line 516, in list_data_collate
ret = collate_fn(data)
File "/Users/xxx/opt/miniconda3/envs/monai/lib/python3.9/site-packages/torch/utils/data/_utils/collate.py", line 265, in default_collate
return collate(batch, collate_fn_map=default_collate_fn_map)
File "/Users/xxx/opt/miniconda3/envs/monai/lib/python3.9/site-packages/torch/utils/data/_utils/collate.py", line 128, in collate
return elem_type({key: collate([d[key] for d in batch], collate_fn_map=collate_fn_map) for key in elem})
File "/Users/xxx/opt/miniconda3/envs/monai/lib/python3.9/site-packages/torch/utils/data/_utils/collate.py", line 128, in <dictcomp>
return elem_type({key: collate([d[key] for d in batch], collate_fn_map=collate_fn_map) for key in elem})
File "/Users/xxx/opt/miniconda3/envs/monai/lib/python3.9/site-packages/torch/utils/data/_utils/collate.py", line 120, in collate
return collate_fn_map[elem_type](batch, collate_fn_map=collate_fn_map)
File "/Users/xxx/opt/miniconda3/envs/monai/lib/python3.9/site-packages/monai/data/utils.py", line 458, in collate_meta_tensor_fn
collated = collate_fn(batch) # type: ignore
File "/Users/xxx/opt/miniconda3/envs/monai/lib/python3.9/site-packages/torch/utils/data/_utils/collate.py", line 163, in collate_tensor_fn
return torch.stack(batch, 0, out=out)
File "/Users/xxx/opt/miniconda3/envs/monai/lib/python3.9/site-packages/monai/data/meta_tensor.py", line 282, in __torch_function__
ret = super().__torch_function__(func, types, args, kwargs)
File "/Users/xxx/opt/miniconda3/envs/monai/lib/python3.9/site-packages/torch/_tensor.py", line 1279, in __torch_function__
ret = func(*args, **kwargs)
RuntimeError: stack expects each tensor to be equal size, but got [1, 258, 358, 358] at entry 0 and [1, 245, 424, 424] at entry 1

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
File "/Users/xxx/opt/miniconda3/envs/monai/lib/python3.9/site-packages/torch/utils/data/_utils/worker.py", line 302, in _worker_loop
data = fetcher.fetch(index)
File "/Users/xxx/opt/miniconda3/envs/monai/lib/python3.9/site-packages/torch/utils/data/_utils/fetch.py", line 61, in fetch
return self.collate_fn(data)
File "/Users/xxx/opt/miniconda3/envs/monai/lib/python3.9/site-packages/monai/data/utils.py", line 696, in pad_list_data_collate
return PadListDataCollate(method=method, mode=mode, **kwargs)(batch)
File "/Users/xxx/opt/miniconda3/envs/monai/lib/python3.9/site-packages/monai/transforms/croppad/batch.py", line 114, in __call__
return list_data_collate(batch)
File "/Users/xxx/opt/miniconda3/envs/monai/lib/python3.9/site-packages/monai/data/utils.py", line 529, in list_data_collate
raise RuntimeError(re_str) from re
RuntimeError: stack expects each tensor to be equal size, but got [1, 258, 358, 358] at entry 0 and [1, 245, 424, 424] at entry 1

MONAI hint: if your transforms intentionally create images of different shapes, creating your DataLoader with collate_fn=pad_list_data_collate might solve this problem (check its documentation).
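
The hint refers to padding every array of a key to a common spatial shape before stacking. A minimal pure-Python sketch of that computation for the two weight-map shapes in the error, assuming symmetric padding with any odd remainder placed at the end of the axis; common_spatial_shape and symmetric_pad_widths are hypothetical helpers, not MONAI functions.

```python
from typing import List, Tuple

Shape = Tuple[int, ...]


def common_spatial_shape(shapes: List[Shape]) -> Shape:
    """Element-wise maximum over the spatial dims (channel dim 0 is left out)."""
    return tuple(max(dims) for dims in zip(*(s[1:] for s in shapes)))


def symmetric_pad_widths(shape: Shape, target: Shape) -> List[Tuple[int, int]]:
    """(before, after) padding per spatial dim; an odd remainder goes to the end."""
    widths = []
    for size, goal in zip(shape[1:], target):
        total = goal - size
        widths.append((total // 2, total - total // 2))
    return widths


shapes = [(1, 258, 358, 358), (1, 245, 424, 424)]
target = common_spatial_shape(shapes)
print(target)  # (258, 424, 424)
for s in shapes:
    print(s, symmetric_pad_widths(s, target))
# (1, 258, 358, 358) [(0, 0), (33, 33), (33, 33)]
# (1, 245, 424, 424) [(6, 7), (0, 0), (0, 0)]
```

Applied to the two uncropped weight maps, padding of this kind would make them stackable, but at the full (258, 424, 424) size rather than the (224, 224, 224) patch size of the cropped images.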

Environment
MONAI version: 1.3.1
Numpy version: 1.23.5
Pytorch version: 1.13.1
MONAI flags: HAS_EXT = False, USE_COMPILED = False, USE_META_DICT = False
MONAI rev id: 96bfda0
MONAI file: /Users//opt/miniconda3/envs/monai/lib/python3.9/site-packages/monai/__init__.py