vithursant / MagnetLoss-PyTorch

PyTorch implementation of a deep metric learning technique called "Magnet Loss", from Facebook AI Research (FAIR), presented at ICLR 2016.

Invalid Sampler in torch while using magnet-loss-test

faizwhb opened this issue

I am using PyTorch 1.0 and Python 3.6.
When running the test script with `python magnet_loss_test.py --lr 1e-4 --batch-size 64 --mnist --magnet-loss`,

I get an invalid-sampler error: `ValueError: sampler should be an instance of torch.utils.data.Sampler, but got sampler=<utils.sampler.SubsetSequentialSampler object at 0x7f6c2780def0>`
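
A torch-free sketch of why this happens: as far as we can tell, in PyTorch 1.0 `DataLoader` wraps the sampler in a `BatchSampler` when `batch_size` is set, and `BatchSampler` rejects any sampler that is not an instance of `torch.utils.data.Sampler`, even one that implements `__iter__` and `__len__`. The `Sampler`, `make_batch_sampler`, `PlainSampler`, `FixedSampler` and `passes_check` names below are hypothetical stand-ins, not PyTorch APIs.

```python
class Sampler:
    """Hypothetical stand-in for torch.utils.data.Sampler (no torch needed)."""


def make_batch_sampler(sampler, batch_size, drop_last=False):
    """Hypothetical helper mimicking the type check BatchSampler runs on its sampler."""
    if not isinstance(sampler, Sampler):
        raise ValueError(
            "sampler should be an instance of torch.utils.data.Sampler, "
            "but got sampler={}".format(sampler)
        )
    return sampler, batch_size, drop_last


class PlainSampler(object):
    """Iterable with __len__, but not a Sampler subclass, so it is rejected."""

    def __init__(self, indices):
        self.indices = indices

    def __iter__(self):
        return iter(self.indices)

    def __len__(self):
        return len(self.indices)


class FixedSampler(PlainSampler, Sampler):
    """Same behaviour, but also inherits from Sampler, so it is accepted."""


def passes_check(sampler):
    """Return True if make_batch_sampler accepts the sampler."""
    try:
        make_batch_sampler(sampler, batch_size=64)
        return True
    except ValueError:
        return False


print(passes_check(PlainSampler([0, 1, 2])))  # False
print(passes_check(FixedSampler([0, 1, 2])))  # True
```

The `FixedSampler` case mirrors the fix discussed in this thread: the iteration logic stays the same and only the base class changes.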

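Making `SubsetSequentialSampler` a subclass of `torch.utils.data.Sampler` fixes the issue:

```python
import torch


class SubsetSequentialSampler(torch.utils.data.Sampler):
    """Samples elements sequentially from a given list of indices, without replacement.

    Arguments:
        indices (list): a list of indices
        batch_indices (list): positions in ``indices`` to yield, in order
    """

    def __init__(self, indices, batch_indices):
        self.indices = indices
        self.batch_indices = batch_indices

    def __iter__(self):
        # Yield the dataset indices selected by batch_indices, in order.
        return (self.indices[i] for i in self.batch_indices)

    def __len__(self):
        # Report the number of items __iter__ actually yields.
        return len(self.batch_indices)
```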
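
For reference, a torch-free sketch of what such a sampler yields and how a `DataLoader` would group that output into batches. `SubsetSequentialSamplerSketch` and `group_into_batches` are hypothetical names; the grouping loop follows the logic of PyTorch's `BatchSampler` (fill a batch, emit it when full, emit the remainder unless `drop_last`). In this sketch `__len__` counts `batch_indices`, i.e. the number of items the iterator actually yields.

```python
from typing import Iterator, List, Sequence


class SubsetSequentialSamplerSketch:
    """Torch-free copy of the sampler's iteration logic (illustrative only)."""

    def __init__(self, indices: Sequence[int], batch_indices: Sequence[int]):
        self.indices = indices
        self.batch_indices = batch_indices

    def __iter__(self) -> Iterator[int]:
        # Each entry of batch_indices is a position into indices.
        return (self.indices[i] for i in self.batch_indices)

    def __len__(self) -> int:
        # Number of items __iter__ yields.
        return len(self.batch_indices)


def group_into_batches(sampler, batch_size: int, drop_last: bool = False) -> List[List[int]]:
    """Group sampler output into lists of batch_size, as BatchSampler does."""
    batches, batch = [], []
    for idx in sampler:
        batch.append(idx)
        if len(batch) == batch_size:
            batches.append(batch)
            batch = []
    # Keep the final, shorter batch unless drop_last is set.
    if batch and not drop_last:
        batches.append(batch)
    return batches


sampler = SubsetSequentialSamplerSketch(indices=[5, 3, 8, 1, 9], batch_indices=[0, 2, 4])
print(list(sampler))                   # [5, 8, 9]
print(len(sampler))                    # 3
print(group_into_batches(sampler, 2))  # [[5, 8], [9]]
```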

commented

Hi, I got the same error. I don't quite follow your solution; could you explain it a bit more? I went to the file named in the traceback, `/usr/local/lib/python3.6/dist-packages/torch/utils/data/sampler.py`, but could not find anything like the code you posted.

Did you mean copying your code and pasting it into `sampler.py`?

Hoping for your reply, thank you!

commented

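Edit `SubsetSequentialSampler` in this repository's own `utils/sampler.py`, not PyTorch's `sampler.py` under `dist-packages`.

THE CHANGE: make the class inherit from `torch.utils.data.Sampler`, i.e. `class SubsetSequentialSampler(torch.utils.data.Sampler):`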