RandWeightedCropd in batch
Describe the bug
When using RandWeightedCropd to randomly crop patches from images of different sizes, the program crashed because the weight maps (weight_map) have different shapes.
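The mechanism can be illustrated with a simplified NumPy sketch. This is not MONAI's implementation: `sample_center`, `crop_around` and `crop_dict` are hypothetical helpers. The sketch assumes, as with MONAI's dictionary transforms, that keys not listed in `keys` are passed through unchanged. Because the reproduction passes keys=['image'] and uses 'wgtmap' only as w_key, the image patches come out at the requested size while each weight map keeps its own full size, so stacking 'image' succeeds and stacking 'wgtmap' fails.

```python
import numpy as np


def sample_center(weight_map, rng):
    """Pick a voxel index with probability proportional to its weight.

    weight_map is channel-first (1, *spatial); only channel 0 is used.
    """
    w = weight_map[0].astype(np.float64).ravel()
    idx = rng.choice(w.size, p=w / w.sum())
    return np.unravel_index(idx, weight_map.shape[1:])


def crop_around(arr, center, spatial_size):
    """Crop arr (C, *spatial) to spatial_size, shifting the window to stay in bounds.

    Assumes every spatial dimension is at least as large as the crop size.
    """
    slices = [slice(None)]  # keep all channels
    for c, size, dim in zip(center, spatial_size, arr.shape[1:]):
        start = min(max(c - size // 2, 0), dim - size)
        slices.append(slice(start, start + size))
    return arr[tuple(slices)]


def crop_dict(data, keys, w_key, spatial_size, rng):
    """Hypothetical stand-in for a dictionary crop: only `keys` are cropped."""
    center = sample_center(data[w_key], rng)
    out = dict(data)  # keys not listed in `keys` pass through unchanged
    for k in keys:
        out[k] = crop_around(data[k], center, spatial_size)
    return out


rng = np.random.default_rng(0)
batch = []
for shape in [(1, 10, 12, 12), (1, 9, 14, 14)]:
    sample = {"image": np.zeros(shape), "wgtmap": np.ones(shape)}
    batch.append(crop_dict(sample, keys=["image"], w_key="wgtmap",
                           spatial_size=(4, 4, 4), rng=rng))

images = np.stack([d["image"] for d in batch])  # works: every patch is (1, 4, 4, 4)
print(images.shape)  # (2, 1, 4, 4, 4)
try:
    np.stack([d["wgtmap"] for d in batch])  # fails: weight maps were never cropped
except ValueError as err:
    print("wgtmap:", err)
```

Passing keys=["image", "wgtmap"] to crop_dict crops both arrays around the same centre, so both keys stack. By analogy, listing 'wgtmap' in the keys of RandWeightedCropd (or dropping that key before batching) would presumably avoid the error in the reported pipeline; this is an assumption, not a confirmed fix.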
To Reproduce
- after spacing, the shape of image 1 is (1, 258, 358, 358) and image 2 is (1, 245, 424, 424)
- load the weight maps, which have the same size as the corresponding images: (1, 258, 358, 358) and (1, 245, 424, 424)
- use RandWeightedCropd to crop patches with a fixed size of (224, 224, 224)
- use DataLoader(ds, batch_size=2, collate_fn=pad_list_data_collate); the program crashes due to the different sizes of weight_map
- if RandSpatialCropSamplesd is used instead, without a weight map, it works
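import os

from monai.data import DataLoader, Dataset, pad_list_data_collate
from monai.transforms import (
    Compose,
    EnsureChannelFirstd,
    LoadImaged,
    Orientationd,
    RandSpatialCropSamplesd,
    RandWeightedCropd,
    Spacingd,
)


def get_transforms():
    transforms = [
        # image: load, make channel-first, reorient, resample to pixdim 0.5 x 0.5 x 0.5
        LoadImaged(keys=['image']),
        EnsureChannelFirstd(keys=['image']),
        Orientationd(keys=['image'], axcodes='SPL'),
        Spacingd(keys=['image'], pixdim=[0.5, 0.5, 0.5], mode=["bilinear"]),
        # weight map: loaded from .npy, same shape as the resampled image
        LoadImaged(keys=['wgtmap']),
        EnsureChannelFirstd(keys=['wgtmap']),
        # only 'image' is listed in keys; 'wgtmap' is used as w_key for sampling
        RandWeightedCropd(
            keys=['image'],
            spatial_size=(224, 224, 224),
            num_samples=1,
            w_key='wgtmap',
        ),
        # RandSpatialCropSamplesd(
        #     keys=['image'],
        #     roi_size=(224, 224, 224),
        #     num_samples=1,
        # ),
    ]
    return Compose(transforms)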
# root: directory holding the data files (not shown in the report)
data1 = {
    'image': os.path.join(root, 'image035.nii.gz'),
    'wgtmap': os.path.join(root, 'image035-wgt.npy'),
}
data2 = {
    'image': os.path.join(root, '10106129_img-orig.nii.gz'),
    'wgtmap': os.path.join(root, '10106129_img-orig-wgt.npy'),
}
trans = get_transforms()
ds = Dataset(data=[data1, data2], transform=trans)
dl = DataLoader(ds, batch_size=2, num_workers=2, collate_fn=pad_list_data_collate)
for i, batch_data in enumerate(dl):
    inputs = batch_data["image"]
    print(inputs.shape)
    wgtmap = batch_data["wgtmap"]
    print(wgtmap.shape)
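Before the collate step, a quick per-key shape check shows which entry will fail to stack. A minimal sketch in plain Python; `report_shape_mismatch` is a hypothetical helper, not part of MONAI. It assumes each dataset item is either a dict or a list of dicts (RandWeightedCropd returns a list of num_samples dicts), flattens the lists, and reports every key whose shapes differ. The example uses stand-in objects that only carry a `.shape`; the wgtmap shapes are taken from the error message in the OUTPUT section and the image shape is assumed from spatial_size.

```python
from collections import defaultdict
from types import SimpleNamespace as Fake  # stand-in objects that only carry .shape


def report_shape_mismatch(items):
    """Return {key: [shapes]} for keys whose shapes differ across items.

    Each item may be a dict or a list of dicts (e.g. a multi-sample crop);
    lists are flattened first. Values without a .shape attribute are ignored.
    """
    flat = []
    for item in items:
        flat.extend(item if isinstance(item, list) else [item])
    shapes = defaultdict(list)
    for d in flat:
        for key, value in d.items():
            if hasattr(value, "shape"):
                shapes[key].append(tuple(value.shape))
    return {k: v for k, v in shapes.items() if len(set(v)) > 1}


example = [
    [{"image": Fake(shape=(1, 224, 224, 224)), "wgtmap": Fake(shape=(1, 258, 358, 358))}],
    [{"image": Fake(shape=(1, 224, 224, 224)), "wgtmap": Fake(shape=(1, 245, 424, 424))}],
]
mismatch = report_shape_mismatch(example)
print(mismatch)  # {'wgtmap': [(1, 258, 358, 358), (1, 245, 424, 424)]}
```

With the objects from the reproduction, the corresponding call would be something like report_shape_mismatch([ds[0], ds[1]]).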
OUTPUT
collate dict key "image" out of 2 keys
collate/stack a list of tensors
collate dict key "wgtmap" out of 2 keys
collate/stack a list of tensors
E: stack expects each tensor to be equal size, but got [1, 258, 358, 358] at entry 0 and [1, 245, 424, 424] at entry 1, shape [torch.Size([1, 258, 358, 358]), torch.Size([1, 245, 424, 424])] in collate([metatensor([[[[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
Traceback (most recent call last):
File "/Users/z/repo/github/7oud/tst_py_store/monai_store/transform copy.py", line 83, in <module>
main()
File "/Users/xxx/repo/github/7oud/tst_py_store/monai_store/transform copy.py", line 68, in main
for i, batch_data in enumerate(dl):
File "/Users/xxx/opt/miniconda3/envs/monai/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 628, in __next__
data = self._next_data()
File "/Users/xxx/opt/miniconda3/envs/monai/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1333, in _next_data
return self._process_data(data)
File "/Users/xxx/opt/miniconda3/envs/monai/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1359, in _process_data
data.reraise()
File "/Users/xxx/opt/miniconda3/envs/monai/lib/python3.9/site-packages/torch/_utils.py", line 543, in reraise
raise exception
RuntimeError: Caught RuntimeError in DataLoader worker process 0.
Original Traceback (most recent call last):
File "/Users/xxx/opt/miniconda3/envs/monai/lib/python3.9/site-packages/monai/data/utils.py", line 516, in list_data_collate
ret = collate_fn(data)
File "/Users/xxx/opt/miniconda3/envs/monai/lib/python3.9/site-packages/torch/utils/data/_utils/collate.py", line 265, in default_collate
return collate(batch, collate_fn_map=default_collate_fn_map)
File "/Users/xxx/opt/miniconda3/envs/monai/lib/python3.9/site-packages/torch/utils/data/_utils/collate.py", line 128, in collate
return elem_type({key: collate([d[key] for d in batch], collate_fn_map=collate_fn_map) for key in elem})
File "/Users/xxx/opt/miniconda3/envs/monai/lib/python3.9/site-packages/torch/utils/data/_utils/collate.py", line 128, in <dictcomp>
return elem_type({key: collate([d[key] for d in batch], collate_fn_map=collate_fn_map) for key in elem})
File "/Users/xxx/opt/miniconda3/envs/monai/lib/python3.9/site-packages/torch/utils/data/_utils/collate.py", line 120, in collate
return collate_fn_map[elem_type](batch, collate_fn_map=collate_fn_map)
File "/Users/xxx/opt/miniconda3/envs/monai/lib/python3.9/site-packages/monai/data/utils.py", line 458, in collate_meta_tensor_fn
collated = collate_fn(batch) # type: ignore
File "/Users/xxx/opt/miniconda3/envs/monai/lib/python3.9/site-packages/torch/utils/data/_utils/collate.py", line 163, in collate_tensor_fn
return torch.stack(batch, 0, out=out)
File "/Users/xxx/opt/miniconda3/envs/monai/lib/python3.9/site-packages/monai/data/meta_tensor.py", line 282, in __torch_function__
ret = super().__torch_function__(func, types, args, kwargs)
File "/Users/xxx/opt/miniconda3/envs/monai/lib/python3.9/site-packages/torch/_tensor.py", line 1279, in __torch_function__
ret = func(*args, **kwargs)
RuntimeError: stack expects each tensor to be equal size, but got [1, 258, 358, 358] at entry 0 and [1, 245, 424, 424] at entry 1
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/Users/xxx/opt/miniconda3/envs/monai/lib/python3.9/site-packages/torch/utils/data/_utils/worker.py", line 302, in _worker_loop
data = fetcher.fetch(index)
File "/Users/xxx/opt/miniconda3/envs/monai/lib/python3.9/site-packages/torch/utils/data/_utils/fetch.py", line 61, in fetch
return self.collate_fn(data)
File "/Users/xxx/opt/miniconda3/envs/monai/lib/python3.9/site-packages/monai/data/utils.py", line 696, in pad_list_data_collate
return PadListDataCollate(method=method, mode=mode, **kwargs)(batch)
File "/Users/xxx/opt/miniconda3/envs/monai/lib/python3.9/site-packages/monai/transforms/croppad/batch.py", line 114, in __call__
return list_data_collate(batch)
File "/Users/xxx/opt/miniconda3/envs/monai/lib/python3.9/site-packages/monai/data/utils.py", line 529, in list_data_collate
raise RuntimeError(re_str) from re
RuntimeError: stack expects each tensor to be equal size, but got [1, 258, 358, 358] at entry 0 and [1, 245, 424, 424] at entry 1
MONAI hint: if your transforms intentionally create images of different shapes, creating your DataLoader with collate_fn=pad_list_data_collate might solve this problem (check its documentation).
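The hint relies on padding every array in the batch to a common shape before stacking. A simplified NumPy sketch of that idea follows; `pad_to_max_and_stack` is a hypothetical helper, not MONAI's code, and its symmetric zero padding is an assumption that only approximates what pad_list_data_collate does by default. In the reported run the error is still raised for 'wgtmap', so that padding evidently was not applied to it; one possible, unverified explanation is that each dataset item was the list of dicts returned by RandWeightedCropd rather than a plain dict.

```python
import numpy as np


def pad_to_max_and_stack(arrays):
    """Zero-pad channel-first arrays to the per-dimension max shape, then stack.

    Padding is split between both ends of each spatial dimension
    (the extra voxel goes to the end when the difference is odd).
    """
    max_spatial = np.max([a.shape[1:] for a in arrays], axis=0)
    padded = []
    for a in arrays:
        pad = [(0, 0)]  # never pad the channel dimension
        for size, target in zip(a.shape[1:], max_spatial):
            diff = int(target - size)
            pad.append((diff // 2, diff - diff // 2))
        padded.append(np.pad(a, pad, mode="constant", constant_values=0))
    return np.stack(padded)


a = np.ones((1, 10, 12, 12))
b = np.ones((1, 9, 14, 14))
stacked = pad_to_max_and_stack([a, b])
print(stacked.shape)  # (2, 1, 10, 14, 14)
```

Padding keeps the full weight maps but makes each batch much larger than the 224-voxel patches; cropping the weight map together with the image avoids that cost.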
Environment
MONAI version: 1.3.1
Numpy version: 1.23.5
Pytorch version: 1.13.1
MONAI flags: HAS_EXT = False, USE_COMPILED = False, USE_META_DICT = False
MONAI rev id: 96bfda0
MONAI file: /Users//opt/miniconda3/envs/monai/lib/python3.9/site-packages/monai/__init__.py