ERROR label_to_count not callable
samsja opened this issue
Hi!
I noticed that some bugs were introduced by the latest commit, ad50e22.
Steps to reproduce:
```python
import torch
import torchvision
from torchsampler import ImbalancedDatasetSampler

# `transform` and `args` are assumed to be defined earlier (not shown in this report).
mnist = torchvision.datasets.MNIST('.', train=True, download=True, transform=transform)
train_loader_b = torch.utils.data.DataLoader(
    mnist,
    sampler=ImbalancedDatasetSampler(mnist),
    batch_size=args.batch_size,
)
```
```
TypeError                                 Traceback (most recent call last)
in
      1 train_loader= torch.utils.data.DataLoader(
      2     mnist,
----> 3     sampler=ImbalancedDatasetSampler(mnist),
      4     batch_size=args.batch_size,
      5 )

~/.local/lib/python3.8/site-packages/torchsampler/imbalanced.py in __init__(self, dataset, indices, num_samples, callback_get_label)
     34         label_to_count = df["label"].value_counts()
     35
---> 36         weights = 1.0 / label_to_count(df["label"])
     37
     38         self.weights = torch.DoubleTensor(weights)

TypeError: 'Series' object is not callable
```
I think the cause is that label_to_count is now a pandas Series (the result of value_counts()), so it can't be called like a function.
Any idea how to fix it? (I will give it a try soon.)
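For reference, line 36 appears to be aiming for per-sample weights equal to the inverse of each label's count. Below is a minimal, dependency-free sketch of that computation. It uses `collections.Counter` as a stand-in for the pandas Series, and the function name `inverse_frequency_weights` and the toy labels are hypothetical, not part of torchsampler. The point is that the counts are looked up by label (indexed), not called:

```python
from collections import Counter


def inverse_frequency_weights(labels):
    """Return one weight per sample: 1 / (number of samples sharing its label)."""
    # Mapping from label to number of occurrences (stand-in for value_counts()).
    label_to_count = Counter(labels)
    # Index the mapping with each sample's label; do not call it like a function.
    return [1.0 / label_to_count[label] for label in labels]


labels = [0, 0, 0, 1]  # hypothetical toy labels: three of class 0, one of class 1
weights = inverse_frequency_weights(labels)
print(weights)
```

In the pandas code, the equivalent lookup would presumably be something like `df["label"].map(label_to_count)` in place of the call on line 36. I have not tested that against the library itself.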
Hello! First, sorry for the inconvenience caused by my updated code.
I made a few small mistakes, so I will open a new PR!