louis-she / exhaustive-weighted-random-sampler

The missing distributed weighted random sampler for PyTorch

ExhaustiveWeightedRandomSampler

ExhaustiveWeightedRandomSampler samples indices according to the given weights, like WeightedRandomSampler, but additionally guarantees that every index is yielded over successive epochs.

Installation

pip install exhaustive-weighted-random-sampler

Usage & Comparison

import torch
from torch.utils.data import WeightedRandomSampler
from exhaustive_weighted_random_sampler import ExhaustiveWeightedRandomSampler

sampler = WeightedRandomSampler([1, 1, 1, 1, 1, 1, 1, 1, 1, 10], num_samples=5)
for i in range(5):
    print(list(sampler))

"""
output:
[4, 3, 9, 3, 4]
[0, 5, 0, 9, 8]
[9, 9, 0, 9, 2]
[9, 9, 7, 9, 9]
[9, 9, 9, 9, 9]

explain: indices 1 and 6 never appear across the five epochs, while 0 appears three times.
"""

sampler = ExhaustiveWeightedRandomSampler([1, 1, 1, 1, 1, 1, 1, 1, 1, 10], num_samples=5)
for i in range(5):
    print(list(sampler))

"""
output:
[4, 6, 9, 9, 9]
[1, 0, 9, 9, 5]
[9, 7, 3, 8, 9]
[9, 2, 1, 9, 9]
[8, 9, 7, 3, 2]

explain: every index from 0 to 8 appears at least once in the yielded results.
"""

Use in DDP

It can be used with DistributedDataParallel (DDP) if pytorch-ignite is installed:

from ignite.distributed import DistributedProxySampler
from torch.utils.data import DataLoader
from exhaustive_weighted_random_sampler import ExhaustiveWeightedRandomSampler

dataset = ...
weights = ...

sampler = DistributedProxySampler(
    ExhaustiveWeightedRandomSampler(weights, num_samples=10000)
)
loader = DataLoader(dataset, sampler=sampler, ...)
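Conceptually, a distributed proxy sampler gives each rank a disjoint, strided slice of one global sample drawn by the wrapped sampler. The `rank_indices` helper below is a hypothetical illustration of that sharding idea, not ignite's actual code:

```python
# Hypothetical helper illustrating strided sharding across DDP ranks:
# each rank keeps every world_size-th index of the global sample,
# starting at its own rank, so the slices are disjoint and together
# cover the whole sample.
def rank_indices(global_sample, rank, world_size):
    return global_sample[rank::world_size]

global_sample = [4, 6, 9, 9, 9, 1, 0, 9]  # one epoch of weighted draws
print(rank_indices(global_sample, 0, 2))  # rank 0 -> [4, 9, 9, 0]
print(rank_indices(global_sample, 1, 2))  # rank 1 -> [6, 9, 1, 9]
```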

About

License: MIT

