Invalid sampler error in torch when running magnet_loss_test.py
faizwhb opened this issue
I am using torch 1.0 and Python 3.6.
When I run the magnet loss test with `python magnet_loss_test.py --lr 1e-4 --batch-size 64 --mnist --magnet-loss`, I get an error saying the sampler is invalid: `ValueError: sampler should be an instance of torch.utils.data.Sampler, but got sampler=<utils.sampler.SubsetSequentialSampler object at 0x7f6c2780def0>`
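The wording of the error suggests a nominal `isinstance` check: an object is rejected unless it derives from `torch.utils.data.Sampler`, even if it already implements `__iter__` and `__len__`. The sketch below is a torch-free approximation of that behaviour, assuming the check works roughly this way; `Sampler`, `check_sampler`, `DuckSampler` and `SubclassedSampler` are hypothetical stand-ins, not PyTorch code.

```python
class Sampler:
    """Hypothetical stand-in for torch.utils.data.Sampler."""


def check_sampler(sampler):
    """Hypothetical check modelled on the error message: nominal, not duck-typed."""
    if not isinstance(sampler, Sampler):
        raise ValueError(
            "sampler should be an instance of torch.utils.data.Sampler, "
            "but got sampler={}".format(sampler)
        )
    return sampler


class DuckSampler:
    """Implements the sampler methods but does not inherit from Sampler."""

    def __init__(self, indices):
        self.indices = indices

    def __iter__(self):
        return iter(self.indices)

    def __len__(self):
        return len(self.indices)


class SubclassedSampler(DuckSampler, Sampler):
    """Same methods as DuckSampler, but now also an instance of Sampler."""


try:
    check_sampler(DuckSampler([0, 1, 2]))
except ValueError as err:
    print("rejected:", type(err).__name__)

print("accepted:", type(check_sampler(SubclassedSampler([0, 1, 2]))).__name__)
```

Under this assumption, only the base class matters for the check, which is consistent with the fix posted in the thread: making the class inherit from `torch.utils.data.Sampler`.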
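```python
import torch


class SubsetSequentialSampler(torch.utils.data.Sampler):
    """Samples elements sequentially from a given list of indices, without replacement.

    Arguments:
        indices (list): a list of indices
        batch_indices (list): positions in ``indices`` to yield, in order
    """

    def __init__(self, indices, batch_indices):
        self.indices = indices
        self.batch_indices = batch_indices

    def __iter__(self):
        return (self.indices[i] for i in self.batch_indices)

    def __len__(self):
        # __iter__ yields one element per entry in batch_indices.
        return len(self.batch_indices)
```

Making `SubsetSequentialSampler` inherit from `torch.utils.data.Sampler`, as above, fixes the issue.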
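To make the iteration order concrete: the sampler yields `indices[i]` for each position `i` in `batch_indices`, so it produces exactly `len(batch_indices)` elements. The following torch-free sketch assumes `Sampler` is a hypothetical stand-in for `torch.utils.data.Sampler`, and the index lists are made-up data chosen only for illustration.

```python
class Sampler:
    """Hypothetical stand-in for torch.utils.data.Sampler (no torch needed)."""


class SubsetSequentialSampler(Sampler):
    """Yields indices[i] for each position i in batch_indices, in order."""

    def __init__(self, indices, batch_indices):
        self.indices = indices
        self.batch_indices = batch_indices

    def __iter__(self):
        return (self.indices[i] for i in self.batch_indices)

    def __len__(self):
        # Matches the number of items __iter__ produces.
        return len(self.batch_indices)


# Hypothetical data: dataset indices and the positions chosen for one batch.
dataset_indices = [10, 11, 12, 13, 14, 15]
batch_positions = [4, 0, 2]
sampler = SubsetSequentialSampler(dataset_indices, batch_positions)
print(list(sampler))  # [14, 10, 12]
print(len(sampler))   # 3
```

The batch is drawn in the order given by `batch_indices`, not in the order of `indices`, which is why the class is "sequential" only with respect to the chosen positions.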
Hi, I got the same error. I didn't quite follow your solution; could you explain it in more detail? I went to the file mentioned in the error output, `/usr/local/lib/python3.6/dist-packages/torch/utils/data/sampler.py`, but could not find anything like the code you posted.
Do you mean I should copy your code and paste it into `sampler.py`?
Hoping for your reply, thank you!
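Change the repository's own `utils/sampler.py` (the module shown in the error as `utils.sampler.SubsetSequentialSampler`), not PyTorch's `torch/utils/data/sampler.py`.
The change is to declare the class as a subclass of `torch.utils.data.Sampler`, with `import torch` at the top of the file: `class SubsetSequentialSampler(torch.utils.data.Sampler):`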