ExhaustiveWeightedRandomSampler
ExhaustiveWeightedRandomSampler samples indices according to the given weights, while guaranteeing that every index is eventually yielded over successive epochs.
Installation
pip install exhaustive-weighted-random-sampler
Usage & Comparison
import torch
from torch.utils.data import WeightedRandomSampler
from exhaustive_weighted_random_sampler import ExhaustiveWeightedRandomSampler
sampler = WeightedRandomSampler([1, 1, 1, 1, 1, 1, 1, 1, 1, 10], num_samples=5)
for i in range(5):
    print(list(sampler))
"""
output:
[4, 3, 9, 3, 4]
[0, 5, 0, 9, 8]
[9, 9, 0, 9, 2]
[9, 9, 7, 9, 9]
[9, 9, 9, 9, 9]
explanation: indices 1 and 6 never appear, while index 0 appears three times.
"""
sampler = ExhaustiveWeightedRandomSampler([1, 1, 1, 1, 1, 1, 1, 1, 1, 10], num_samples=5)
for i in range(5):
    print(list(sampler))
"""
output:
[4, 6, 9, 9, 9]
[1, 0, 9, 9, 5]
[9, 7, 3, 8, 9]
[9, 2, 1, 9, 9]
[8, 9, 7, 3, 2]
explanation: every index from 0 to 8 appears in the yielded results.
"""
Use in DDP
It can be used in DDP if pytorch-ignite is installed.
from ignite.distributed import DistributedProxySampler
from torch.utils.data import DataLoader
dataset = ...
sampler = DistributedProxySampler(
    ExhaustiveWeightedRandomSampler(weights, num_samples=10000)
)
loader = DataLoader(dataset, sampler=sampler, ...)