NotImplemented Error while running ImbalancedDatasetSampler
aryamansriram opened this issue · comments
I followed the steps in the README exactly, yet I am getting a NotImplementedError. The error comes with no explanation either.
Here's my code:
`import numpy as np
import torch
from torchvision import transforms
from torchsampler import ImbalancedDatasetSampler
batch_size = 128
val_split = 0.2
shuffle_dataset = True
random_seed = 42
# melanoma_dataset is a custom dataset defined elsewhere (not shown in this post)
dataset_size = len(melanoma_dataset)
indices = list(range(dataset_size))
split = int(np.floor(val_split * dataset_size))
if shuffle_dataset:
    np.random.seed(random_seed)
    np.random.shuffle(indices)
train_indices, test_indices = indices[split:], indices[:split]
train_loader = torch.utils.data.DataLoader(melanoma_dataset, batch_size=batch_size, sampler=ImbalancedDatasetSampler(melanoma_dataset))
# test_sampler is not defined in the posted snippet
test_loader = torch.utils.data.DataLoader(melanoma_dataset, batch_size=batch_size, sampler=test_sampler)`
Having the same issue: no detailed error message, just a NotImplementedError.
It looks like the code only supports MNIST, Subsets, or an ImageFolder out of the box. If you have a custom dataset, you need to implement callback_get_label.
If anyone reads this, this worked for me:
def callback_get_label(dataset, idx):
    # Callback function used in imbalanced dataset loader.
    # Assumes target is a one-hot tensor (exactly one nonzero entry).
    input, target = dataset[idx]
    return np.argwhere(target.numpy()).item()
Edit: I suspect it would be faster to not cast the tensor to numpy, so the following change should do the same within the tensor framework:
def callback_get_label(dataset, idx):
    # Callback function used in imbalanced dataset loader.
    input, target = dataset[idx]
    return target.nonzero().item()
If you have an int label, try using this:
def callback_get_label(dataset, idx):
    # Callback function used in imbalanced dataset loader.
    i, target = dataset[idx]
    return int(target)
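For context on why the sampler needs a label for every index, here is a minimal sketch of inverse-frequency weighting in plain Python, using a callback shaped like the ones above. It is a simplified stand-in and not torchsampler's actual code; `ToyDataset` and `sample_weights` are hypothetical names introduced only for illustration.

```python
import random
from collections import Counter


class ToyDataset:
    """Hypothetical dataset that returns (input, target) pairs with int targets."""

    def __init__(self, targets):
        self.targets = targets

    def __len__(self):
        return len(self.targets)

    def __getitem__(self, idx):
        return f"sample-{idx}", self.targets[idx]


def callback_get_label(dataset, idx):
    # Same shape as the callbacks in this thread: fetch one item, return its label.
    _, target = dataset[idx]
    return int(target)


def sample_weights(dataset, get_label):
    # Hypothetical helper: weight each sample by 1 / (size of its class),
    # so every class carries the same total weight.
    labels = [get_label(dataset, i) for i in range(len(dataset))]
    counts = Counter(labels)
    return [1.0 / counts[label] for label in labels]


# Six samples of class 0 and two of class 1.
dataset = ToyDataset([0, 0, 0, 0, 0, 0, 1, 1])
weights = sample_weights(dataset, callback_get_label)

# Draw indices with replacement in proportion to the weights.
rng = random.Random(0)
drawn = rng.choices(range(len(dataset)), weights=weights, k=1000)
```

Without a way to read the label at each index, the weights cannot be computed, which is presumably why a custom dataset needs the callback.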
For those who are new to Python like me: define the 'callback_get_label' function before you initialize the train_loader, then pass it to the sampler like this (the sampler has to go in as a keyword argument, because the second positional argument of DataLoader is batch_size):
train_loader = DataLoader(dataset, sampler=ImbalancedDatasetSampler(dataset, callback_get_label=callback_get_label), batch_size=batch_size)
@ufoym this is solved, can be closed 🐰
I got this error:
TypeError: callback_get_label() missing 1 required positional argument: 'idx'
Could you tell me where to define the callback_get_label() function?
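The TypeError above means the caller passed one argument fewer than the function requires. One possible cause, which is an assumption and not confirmed anywhere in this thread, is that the installed torchsampler version calls the callback with only the dataset and expects the labels for all samples back. A minimal sketch of a callback that tolerates both calling styles is shown below; `ToyDataset` is a hypothetical stand-in for your own dataset.

```python
class ToyDataset:
    """Hypothetical dataset that returns (input, target) pairs with int targets."""

    def __init__(self, targets):
        self.targets = targets

    def __len__(self):
        return len(self.targets)

    def __getitem__(self, idx):
        return f"sample-{idx}", self.targets[idx]


def callback_get_label(dataset, idx=None):
    # Called as callback_get_label(dataset): return the labels of all samples.
    if idx is None:
        return [int(dataset[i][1]) for i in range(len(dataset))]
    # Called as callback_get_label(dataset, idx): return a single label.
    _, target = dataset[idx]
    return int(target)


dataset = ToyDataset([2, 0, 1])
all_labels = callback_get_label(dataset)
one_label = callback_get_label(dataset, 1)
```

As for where to define it: at module level, anywhere above the line that constructs the sampler, so the name exists by the time it is passed in.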