ufoym / imbalanced-dataset-sampler

A (PyTorch) imbalanced dataset sampler for oversampling low-frequency classes and undersampling high-frequency ones.
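The idea behind such a sampler can be sketched with PyTorch's built-in `WeightedRandomSampler`: give each sample a weight of 1 / (count of its class) and draw with replacement. This is a minimal sketch, assuming inverse-class-frequency weighting; the labels are hypothetical, and it is not torchsampler's own code.

```python
from collections import Counter

import torch
from torch.utils.data import WeightedRandomSampler

# Hypothetical, heavily imbalanced labels: eight of class 0, two of class 1.
labels = [0] * 8 + [1] * 2

# Inverse class frequency: each sample is weighted by 1 / (count of its label).
counts = Counter(labels)
weights = torch.tensor([1.0 / counts[y] for y in labels], dtype=torch.float64)

# Sampling with replacement lets rare samples repeat (oversampling) while
# common samples are drawn less often than their raw share (undersampling).
sampler = WeightedRandomSampler(weights, num_samples=len(labels), replacement=True)

# Each class now carries the same total weight (1.0), so draws are roughly balanced.
drawn = [labels[i] for i in sampler]
```

Because every class sums to the same total weight, each draw picks either class with roughly equal probability regardless of how many samples it has.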



ERROR label_to_count not callable

samsja opened this issue

Hi!

I noticed that some bugs were introduced in the last commit, ad50e22.

Steps to reproduce

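```python
import torch
import torchvision
from torchvision import transforms
from torchsampler import ImbalancedDatasetSampler

# Any transform works here; ToTensor keeps the snippet self-contained.
transform = transforms.ToTensor()

mnist = torchvision.datasets.MNIST('.', train=True, download=True, transform=transform)
train_loader_b = torch.utils.data.DataLoader(
    mnist,
    sampler=ImbalancedDatasetSampler(mnist),
    batch_size=64,  # the original snippet used args.batch_size
)
```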

```
TypeError                                 Traceback (most recent call last)
in
      1 train_loader= torch.utils.data.DataLoader(
      2     mnist,
----> 3     sampler=ImbalancedDatasetSampler(mnist),
      4     batch_size=args.batch_size,
      5 )

~/.local/lib/python3.8/site-packages/torchsampler/imbalanced.py in __init__(self, dataset, indices, num_samples, callback_get_label)
     34         label_to_count = df["label"].value_counts()
     35
---> 36         weights = 1.0 / label_to_count(df["label"])
     37
     38         self.weights = torch.DoubleTensor(weights)

TypeError: 'Series' object is not callable
```

I think `label_to_count` is now a pandas `Series` (the result of `value_counts()`), so it cannot be called like a function.
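Since `value_counts()` returns a `Series` indexed by label, the per-sample counts can be looked up from it rather than calling it. This is a minimal sketch, assuming the intended weight is 1 / (count of each sample's label), as line 36 of the traceback suggests; it uses `Series.map` for the lookup, the labels are hypothetical, and it is not necessarily the fix that was merged upstream.

```python
import pandas as pd
import torch

# Hypothetical labels standing in for a dataset's targets.
df = pd.DataFrame({"label": [0, 0, 0, 1, 2, 2]})

# value_counts() returns a Series mapping label -> count; it is looked up, not called.
label_to_count = df["label"].value_counts()

# Map each sample's label to its class count, then invert it.
weights = 1.0 / df["label"].map(label_to_count)

# One weight per sample, in the same order as df.
sample_weights = torch.tensor(weights.to_numpy(), dtype=torch.float64)
```

Mapping over `df["label"]` keeps one weight per sample in dataset order, which is what a sampler needs when it later draws indices.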

Any idea how to fix it? (I will give it a try soon.)

Hello! First, sorry for the inconvenience caused by my updated code.

I made a few small mistakes, so I will open a new PR!