philtabor / Deep-Q-Learning-Paper-To-Code

IndexError: tensors used as indices must be long, byte or bool tensors

Ningizza opened this issue · comments

When I run main_dqn.py to train the DQN for Atari Pong, I get the following error:

(learn36) D:\deepleren\Learn36>python main_dqn.py
Traceback (most recent call last):
File "main_dqn.py", line 40, in <module>
agent.learn()
File "D:\deepleren\Learn36\dqn_agent.py", line 89, in learn
q_pred = self.q_eval.forward(states)[indices, actions]
IndexError: tensors used as indices must be long, byte or bool tensors
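A likely root cause (a guess, since the exact environment isn't stated): on 64-bit Windows, numpy's default integer type is int32, not int64, so `np.arange(batch_size)` and default-dtype action arrays come out as int32, and PyTorch 1.4 refuses int32 tensors as fancy-indexing indices. A minimal demonstration of the dtype difference:

```python
import numpy as np

# np.arange without an explicit dtype uses the platform's default C long:
# int32 on 64-bit Windows, int64 on most Linux/macOS builds.
implicit = np.arange(4)

# Requesting int64 explicitly gives a dtype that PyTorch always accepts
# as an index tensor, regardless of platform.
explicit = np.arange(4, dtype=np.int64)
```

On the Windows machine in the traceback, `implicit.dtype` would be int32, which is exactly what triggers the IndexError once it is wrapped in a torch tensor.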

Which versions of numpy and pytorch are you using? This is certainly an issue with package versions.

My numpy version is 1.18.1
My torch version is 1.4.0

OK, I can't replicate the error, but it looks like there are two possibilities:

  1. indices are not the correct dtype
  2. actions are not the correct dtype

Please verify the dtype of indices and actions by printing indices.dtype and actions.dtype to the terminal.

Make sure they are both np.int64.

If one (or both) are not, then change the declaration of indices on line 87 to include dtype=np.int64,
and also change the declaration of action_memory in replay_memory.py to dtype=np.int64.

Also, for good measure (because it generates a warning with this combination of packages), change the dtype of terminal_memory to be np.bool (also in the replay_memory.py file).
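The instructions above boil down to three dtype declarations. A hypothetical sketch of what they might look like (the names mirror the thread, not the exact file contents; note that the `np.bool` alias mentioned above was removed in later numpy releases, so `np.bool_` is the durable spelling):

```python
import numpy as np

mem_size, batch_size = 1000, 32

# In replay_memory.py: store actions as int64 so torch gets long index tensors.
action_memory = np.zeros(mem_size, dtype=np.int64)

# In replay_memory.py: store done flags as bool, not uint8, to avoid the
# deprecated-uint8-indexing warning. (np.bool_ here; np.bool was later removed.)
terminal_memory = np.zeros(mem_size, dtype=np.bool_)

# In dqn_agent.py's learn(): make the batch indices explicitly int64
# instead of relying on the platform default.
indices = np.arange(batch_size, dtype=np.int64)
```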

Let me know if that helps!

The datatype of "indices" was int32

I changed
indices = np.arange(self.batch_size)
into
indices = np.arange(self.batch_size, dtype=np.int64)

The datatype of actions was torch.int32

I changed
actions = T.tensor(action).to(self.q_eval.device)
into
actions = T.tensor(action, dtype=T.long).to(self.q_eval.device)

After these changes I got a new warning: "indexing with dtype torch.uint8 is now deprecated, please use a dtype torch.bool instead."

This problem seems to be solved by changing:
self.terminal_memory = np.zeros(self.mem_size, dtype=np.uint8)
into
self.terminal_memory = np.zeros(self.mem_size, dtype=np.bool)
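Putting the three fixes together, the indexing in learn() then satisfies both rules: integer fancy indexing wants long tensors, and mask indexing wants bool. A minimal end-to-end sketch with hypothetical stand-in tensors (not the actual network output):

```python
import numpy as np
import torch as T

batch_size, n_actions = 4, 2

# Stand-in for self.q_eval.forward(states): one Q-value per (sample, action).
q_values = T.zeros(batch_size, n_actions)
q_values[2, 1] = 5.0

indices = np.arange(batch_size, dtype=np.int64)           # explicit int64
actions = T.tensor([0, 1, 1, 0], dtype=T.long)            # long, not int32
dones = T.tensor(np.zeros(batch_size, dtype=np.bool_))    # bool, not uint8

# Fancy indexing with long indices: picks the Q-value of the taken action.
q_pred = q_values[indices, actions]

# Boolean mask indexing: zeroes out terminal transitions, no deprecation warning.
q_pred[dones] = 0.0
```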

Now the agent is training to play Pong :)

Thanks very much for your help, I appreciate it a lot!

Perfect. I have to set up a requirements file for this, as these errors didn't happen on earlier versions of PyTorch.