Expected object of scalar type Long but got scalar type Float for argument #3 'index'
bkowshik opened this issue
- Chapter: 4
- Page: 82
- Section: One-hot encoding
```python
target_onehot = torch.zeros(target.shape[0], 10)
target_onehot.scatter_(1, target.unsqueeze(1), 1.0)
```
```
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
in
1 target_onehot = torch.zeros(target.shape[0], 10)
----> 2 target_onehot.scatter_(1, target.unsqueeze(1), 1.0)
RuntimeError: Expected object of scalar type Long but got scalar type Float for argument #3 'index' in call to _th_scatter_
```
Hi @bkowshik, thanks for the issue. The error says that `target` is a Float tensor, while it should be a Long (int64) tensor, since it is being used as an index.
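To see why the index has to be an integer type, here is a pure-Python sketch of what `out.scatter_(1, index, value)` does for a 2-D `out` and a scalar `value`: for every position `(i, j)` in `index`, it writes `value` into `out[i][index[i][j]]`. This only illustrates the semantics under those assumptions; it is not how PyTorch implements `scatter_`, and the helper name `scatter_rows` is hypothetical.

```python
def scatter_rows(out, index, value):
    """Mimic out.scatter_(1, index, value) on nested lists (dim=1, scalar value)."""
    for i, row in enumerate(index):
        for col in row:
            # col is used as a list subscript, so it must be an int;
            # a float here raises TypeError, much like PyTorch rejects a Float index.
            out[i][col] = value
    return out


# One-hot encode three labels into 10 classes.
onehot = scatter_rows([[0.0] * 10 for _ in range(3)], [[6], [6], [7]], 1.0)
```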
Towards the end of page 81 we do convert `target` to a Long tensor:
```python
# In[7]:
target = wineq[:, -1].long()
target

# Out[7]:
tensor([6, 6, ..., 7, 6])
```
Are you sure you ran that part too?
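For reference, a minimal sketch of the corrected one-hot encoding. It assumes a small made-up float `target` in place of the wine-quality labels (the values are hypothetical, not taken from the dataset), casts it with `.long()` before calling `scatter_`, and cross-checks the result against `torch.nn.functional.one_hot`, which is an alternative to `scatter_` rather than the approach used in the book.

```python
import torch

# Hypothetical stand-in for the labels: a float tensor, as you get when the
# last column is sliced from the float data without calling .long().
target = torch.tensor([6.0, 6.0, 7.0, 5.0])

# scatter_ requires an int64 (Long) index, so cast the labels first.
target = target.long()

target_onehot = torch.zeros(target.shape[0], 10)
target_onehot.scatter_(1, target.unsqueeze(1), 1.0)

# Alternative: one_hot returns an int64 tensor, so convert it to float
# to compare it with the float target_onehot built above.
alt = torch.nn.functional.one_hot(target, num_classes=10).float()
```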