microsoft / infinibatch

Efficient, check-pointed data loading for deep learning with massive data sets.

Implement wrapper iterator that inherits from PyTorch's IterableDataset

gmyr opened this issue

We might want to implement an iterator that inherits from PyTorch's IterableDataset, so that infinibatch iterators plug directly into PyTorch's DataLoader.

Here is some prototype code we wrote earlier along these lines:

```python
from typing import Iterable, Optional, Union

import torch
from infinibatch.datasets import chunked_dataset_iterator
from infinibatch.iterators import CheckpointableIterator


class IterableCheckpointedDataset(torch.utils.data.IterableDataset):
    """
    Wraps a CheckpointableIterator into a PyTorch IterableDataset, which PyTorch's
    DataLoader recognizes by its type.
    """
    def __init__(self, source: CheckpointableIterator):
        super().__init__()
        self._source = source

    def __iter__(self):  # runs inside the DataLoader worker, i.e. in the forked copy
        worker_info = torch.utils.data.get_worker_info()
        # More than one worker is not supported: each worker iterates its own copy,
        # and we cannot retrieve that copy's checkpoint from the main process.
        assert worker_info is None or worker_info.num_workers == 1
        return iter(self._source)


class IterableChunkedDataset(torch.utils.data.IterableDataset):
    def __init__(self, paths: Union[str, Iterable[str]], shuffle: bool = True, buffer_size: int = 2**20,
                 transform=None, seed: Optional[int] = None, world_size: int = 1, rank: int = 0,
                 num_workers_per_rank: int = 1):
        super().__init__()
        self.rank = rank
        self.num_workers_per_rank = num_workers_per_rank
        # instance_rank is set assuming num_workers_per_rank == 1 and is adapted dynamically in __iter__.
        # Note: this uses the paths-based signature from when the prototype was written;
        # newer infinibatch versions may expect different arguments.
        self.dataset = chunked_dataset_iterator(paths, shuffle=shuffle, buffer_size=buffer_size,
                                                transform=transform, seed=seed,
                                                num_instances=world_size * num_workers_per_rank,
                                                instance_rank=rank)

    def __iter__(self):
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is None:  # single-process data loading
            self.dataset._instance_rank = self.rank
        else:
            # Each DataLoader worker of this rank reads its own shard of the data.
            assert worker_info.num_workers == self.num_workers_per_rank
            self.dataset._instance_rank = self.rank * self.num_workers_per_rank + worker_info.id
        return iter(self.dataset)
```
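To make the shard assignment in `IterableChunkedDataset.__iter__` concrete, here is a standalone sketch of the index arithmetic `rank * num_workers_per_rank + worker_id`. It uses only the standard library. The helper names `global_instance_rank` and `shard` are hypothetical, and the round-robin split over single items is a simplification (infinibatch distributes chunks among instances, not individual items).

```python
from typing import Dict, List, Sequence, Tuple


def global_instance_rank(rank: int, num_workers_per_rank: int, worker_id: int) -> int:
    """Map (process rank, DataLoader worker id) to a unique instance index,
    using the same formula as IterableChunkedDataset.__iter__."""
    if not 0 <= worker_id < num_workers_per_rank:
        raise ValueError("worker_id must be in [0, num_workers_per_rank)")
    return rank * num_workers_per_rank + worker_id


def shard(items: Sequence[int], num_instances: int, instance_rank: int) -> List[int]:
    """Hypothetical round-robin split: instance k gets items k, k + n, k + 2n, ..."""
    return list(items[instance_rank::num_instances])


world_size, num_workers_per_rank = 2, 3
num_instances = world_size * num_workers_per_rank
data = list(range(12))

shards: Dict[Tuple[int, int], List[int]] = {}
for rank in range(world_size):
    for worker_id in range(num_workers_per_rank):
        r = global_instance_rank(rank, num_workers_per_rank, worker_id)
        shards[(rank, worker_id)] = shard(data, num_instances, r)

# Every item is read by exactly one (rank, worker) pair.
assert sorted(x for s in shards.values() for x in s) == data
print(shards[(1, 2)])  # rank 1, worker 2 -> instance 5 -> [5, 11]
```

Because `worker_id` ranges over `[0, num_workers_per_rank)`, the instance indices of different ranks never overlap, which is why `num_instances` must be `world_size * num_workers_per_rank`.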
commented

Found an issue in the example above: the iter(self.dataset) call doesn't reset the iterator. So once a finite validation iterator has been exhausted, every later __iter__ call returns an empty iterator. I guess we have to call set_state in __iter__ for validation datasets.

For training, we can assume the iterator is infinite and just return it without setting the state.
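A minimal, standard-library-only sketch of the behavior described above. `ListCheckpointableIterator` is a hypothetical toy, not infinibatch's class; its `get_state`/`set_state` methods and the convention that `set_state(None)` rewinds to the start are assumptions made for illustration. The hypothetical `reset_on_iter` flag stands in for the validation/training distinction.

```python
from typing import Iterator, List, Optional


class ListCheckpointableIterator:
    """Hypothetical finite checkpointable iterator; a checkpoint is just the next index."""

    def __init__(self, items: List[int]):
        self._items = items
        self._pos = 0

    def get_state(self) -> dict:
        return {"pos": self._pos}

    def set_state(self, checkpoint: Optional[dict]) -> None:
        # Assumed convention: None means "rewind to the beginning".
        self._pos = 0 if checkpoint is None else checkpoint["pos"]

    def __iter__(self) -> "ListCheckpointableIterator":
        return self  # like any iterator, iter() returns itself without resetting

    def __next__(self) -> int:
        if self._pos >= len(self._items):
            raise StopIteration
        item = self._items[self._pos]
        self._pos += 1
        return item


class ResettingDataset:
    """Rewinds the source on every __iter__ call when reset_on_iter is True
    (validation); otherwise returns it as-is (training, assumed infinite)."""

    def __init__(self, source: ListCheckpointableIterator, reset_on_iter: bool):
        self._source = source
        self._reset_on_iter = reset_on_iter

    def __iter__(self) -> Iterator[int]:
        if self._reset_on_iter:
            self._source.set_state(None)
        return iter(self._source)


naive = ResettingDataset(ListCheckpointableIterator([1, 2, 3]), reset_on_iter=False)
print(list(naive), list(naive))  # [1, 2, 3] []  (second epoch is empty)

fixed = ResettingDataset(ListCheckpointableIterator([1, 2, 3]), reset_on_iter=True)
print(list(fixed), list(fixed))  # [1, 2, 3] [1, 2, 3]
```

With `reset_on_iter=False`, the second pass over a finite source is empty, which is the reported bug. For an infinite training stream, skipping the reset is harmless because the stream never runs out.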