In pytorchvideo.data.labeled_video_dataset, is transform wrongly defined, or am I using it wrongly?
Tenglon opened this issue
In pytorchvideo.data.labeled_video_dataset, the transform is applied as:
class LabeledVideoDataset(torch.utils.data.IterableDataset):
    def __next__(self) -> dict:
        ...
        if self._transform is not None:
            sample_dict = self._transform(sample_dict)  # Is this line correct?
so the transform receives the whole sample dictionary as input.
But the transforms in pytorchvideo.transforms, UniformTemporalSubsample for example, take a tensor as input, which is inconsistent with a dictionary. Am I missing something?
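To make the mismatch concrete, here is a dependency-free sketch, assuming a toy model in which a "video" is a plain list of frame indices instead of a tensor. The names `subsample_frames` and `next_sample` are hypothetical stand-ins, not pytorchvideo APIs: `next_sample` imitates the `__next__` pattern quoted in the question by handing the whole dict to the transform, and `subsample_frames` plays the role of a tensor-only transform such as UniformTemporalSubsample.

```python
from typing import Any, Callable, Dict, List

# Hypothetical toy types: a sample is a dict (like sample_dict),
# and a "video" is just a list of frame indices instead of a tensor.
Sample = Dict[str, Any]


def subsample_frames(video: List[int], num_samples: int) -> List[int]:
    """Toy tensor-style transform: expects the video itself, not a dict.

    Keeps num_samples evenly spaced frames, similar in spirit to
    UniformTemporalSubsample (which operates on a real tensor).
    """
    # Toy guard so the failure is explicit; a real tensor transform
    # would fail in its own way when handed a dict.
    if not isinstance(video, list):
        raise TypeError(f"expected a video, got {type(video).__name__}")
    if num_samples == 1:
        return [video[0]]
    step = (len(video) - 1) / (num_samples - 1)
    return [video[int(i * step)] for i in range(num_samples)]


def next_sample(sample: Sample, transform: Callable[[Sample], Sample]) -> Sample:
    """Imitates the dataset's __next__: the WHOLE dict goes to the transform."""
    return transform(sample)


sample = {"video": list(range(10)), "label": 3}
try:
    # Passing a tensor-style transform directly reproduces the mismatch.
    next_sample(sample, lambda v: subsample_frames(v, 4))
except TypeError as err:
    print("mismatch:", err)  # prints: mismatch: expected a video, got dict
```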
I figured it out myself.
The transform should be defined as:
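import pytorchvideo.transforms
import torchvision.transforms

# The outer Compose receives the whole sample dict from LabeledVideoDataset;
# ApplyTransformToKey passes only sample["video"] (a C, T, H, W tensor) to the inner chain.
self._TRANSFORM = torchvision.transforms.Compose([
    pytorchvideo.transforms.ApplyTransformToKey(
        key="video",
        transform=torchvision.transforms.Compose([
            pytorchvideo.transforms.UniformTemporalSubsample(self._CLIP_FRAMES),
            # Decoded frames are typically in [0, 255]; scale to [0, 1] before normalizing.
            torchvision.transforms.Lambda(lambda x: x / 255.0),
            pytorchvideo.transforms.Normalize((0.45, 0.45, 0.45), (0.225, 0.225, 0.225)),
            pytorchvideo.transforms.RandomShortSideScale(min_size=256, max_size=320),
            torchvision.transforms.RandomCrop(224),
            torchvision.transforms.RandomHorizontalFlip(p=0.5),
        ]),
    ),
])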
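Here ApplyTransformToKey takes the sample dictionary as input and passes only the tensor stored under the given key ("video") to the inner Compose, so the transform as a whole takes a dictionary as input, which is what LabeledVideoDataset expects.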
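The dict-in, dict-out idea can be sketched without torch, assuming a plain list of frame indices stands in for the video tensor. `compose`, `ApplyToKey`, and `subsample_frames` are hypothetical names that imitate, but are not, torchvision.transforms.Compose, pytorchvideo.transforms.ApplyTransformToKey, and UniformTemporalSubsample.

```python
from typing import Any, Callable, Dict, List

Sample = Dict[str, Any]


def compose(*fns: Callable[[Any], Any]) -> Callable[[Any], Any]:
    """Apply fns left to right (toy analogue of torchvision's Compose)."""
    def run(x: Any) -> Any:
        for fn in fns:
            x = fn(x)
        return x
    return run


class ApplyToKey:
    """Toy analogue of the ApplyTransformToKey idea: dict in, dict out,
    with only sample[key] transformed. This toy updates the dict in place
    and returns it."""

    def __init__(self, key: str, transform: Callable[[Any], Any]) -> None:
        self.key = key
        self.transform = transform

    def __call__(self, sample: Sample) -> Sample:
        sample[self.key] = self.transform(sample[self.key])
        return sample


def subsample_frames(video: List[int], num_samples: int) -> List[int]:
    """Toy tensor-style transform: keep num_samples evenly spaced frames."""
    if num_samples == 1:
        return [video[0]]
    step = (len(video) - 1) / (num_samples - 1)
    return [video[int(i * step)] for i in range(num_samples)]


# Same shape as the fix: an outer compose over a key-wrapper around
# an inner compose of video-only (tensor-style) steps.
transform = compose(
    ApplyToKey(
        key="video",
        transform=compose(
            lambda v: subsample_frames(v, 4),      # toy temporal subsampling
            lambda v: [frame * 2 for frame in v],  # toy per-frame scaling
        ),
    )
)

out = transform({"video": list(range(10)), "label": 3})
print(out)  # {'video': [0, 6, 12, 18], 'label': 3}
```

Only the value under "video" changes; other keys such as "label" pass through untouched, which is why tensor-only transforms can live inside the wrapper while the dataset still hands the transform a dict.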