deep-learning-with-pytorch / dlwpt-code

Code for the book Deep Learning with PyTorch by Eli Stevens, Luca Antiga, and Thomas Viehmann.

Home Page: https://www.manning.com/books/deep-learning-with-pytorch

Expected object of scalar type Long but got scalar type Float for argument #3 'index'

bkowshik opened this issue

  • Chapter: 4
  • Page: 82
  • Section: One-hot encoding
target_onehot = torch.zeros(target.shape[0], 10)
target_onehot.scatter_(1, target.unsqueeze(1), 1.0)

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
 in 
      1 target_onehot = torch.zeros(target.shape[0], 10)
----> 2 target_onehot.scatter_(1, target.unsqueeze(1), 1.0)

RuntimeError: Expected object of scalar type Long but got scalar type Float for argument #3 'index' in call to _th_scatter_

Hi @bkowshik, thanks for the issue. The error says that target is a Float tensor, while it should be a Long (int64) tensor, because scatter_ uses it as an index.
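To make the dtype requirement concrete, here is a minimal sketch. The target values below are made up as a stand-in for the wine-quality labels (an assumption, not the book's data). Passing a float index to scatter_ raises a RuntimeError, while casting the labels with .long() first works. The exact wording of the error message varies between PyTorch versions.

```python
import torch

# Hypothetical float labels standing in for the last column of wineq
target = torch.tensor([6.0, 6.0, 7.0, 5.0])

target_onehot = torch.zeros(target.shape[0], 10)

# A float index tensor is rejected by scatter_
float_index_failed = False
try:
    target_onehot.scatter_(1, target.unsqueeze(1), 1.0)
except RuntimeError as err:
    float_index_failed = True
    print("float index rejected:", err)

# Casting the labels to int64 (Long) makes them valid indices
target = target.long()
target_onehot.scatter_(1, target.unsqueeze(1), 1.0)
print(target_onehot)
```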

Towards the end of page 81 we do convert target to a Long tensor:

# In[7]:
target = wineq[:, -1].long()
target

# Out[7]:
tensor([6, 6, ..., 7, 6])

Are you sure you ran that part too?
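For reference, here is a sketch of the whole sequence, from slicing the labels to building the one-hot tensor. The small wineq tensor below is a hypothetical stand-in for the book's wine-quality data (the real one is loaded from a CSV file and has more columns). The last lines show torch.nn.functional.one_hot, an alternative to scatter_ that also expects an int64 input.

```python
import torch
import torch.nn.functional as F

# Hypothetical stand-in for wineq: 3 feature columns plus a quality score in the last column
wineq = torch.tensor([
    [7.0, 0.27, 0.36, 6.0],
    [6.3, 0.30, 0.34, 6.0],
    [8.1, 0.28, 0.40, 7.0],
], dtype=torch.float32)

# Slice the labels and convert them to Long in one step, as in In[7]
target = wineq[:, -1].long()

# One-hot encode with scatter_ (10 classes)
target_onehot = torch.zeros(target.shape[0], 10)
target_onehot.scatter_(1, target.unsqueeze(1), 1.0)

# Alternative: F.one_hot returns an int64 tensor, so cast to float to match
target_onehot_alt = F.one_hot(target, num_classes=10).float()
print(torch.equal(target_onehot, target_onehot_alt))  # True
```

Skipping the .long() call leaves target as float32, which reproduces the error reported above with either approach.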