SivilTaram / Persona-Dialogue-Generation

The code of ACL 2020 paper "You Impress Me: Dialogue Generation via Mutual Persona Perception"



ERROR occurs when running train_psquare.py

Lireanstar opened this issue

Hi, after training the receiver and transmitter models, I ran train_psquare.py.
My local machine has two GPUs, and I launched the script from the terminal with the following command:

CUDA_VISIBLE_DEVICES=0,1 python train_psquare.py

Then the following error occurs:

[loading fbdialog data:/home/Persona-Dialogue-Generation/data/ConvAI2/train_self_original_no_cands.txt]
[loading fbdialog data:/home/Persona-Dialogue-Generation/data/ConvAI2/train_self_original_selfplay.txt]
[ training... ]
.> [ Saving tensorboard logs here: ./tmp/psquare/tensorboard ]
/pytorch/aten/src/ATen/native/IndexingUtils.h:20: UserWarning: indexing with dtype torch.uint8 is now deprecated, please use a dtype torch.bool instead.

Traceback (most recent call last):
  File "train_psquare.py", line 105, in <module>
    TrainLoop(opt).train()
  File "/home/Persona-Dialogue-Generation/scripts/train_model_selfplay.py", line 270, in train
    world.parley_episode(is_training=True, is_display=is_display)
  File "/home/Persona-Dialogue-Generation/worlds/selfplay.py", line 186, in parley_episode
    self.parley(is_display)
  File "/home/Persona-Dialogue-Generation/worlds/selfplay.py", line 90, in parley
    acts[0] = agents[0].act(is_display)
  File "/home/Persona-Dialogue-Generation/agents/psquare/psquare.py", line 604, in act
    act = self.batch_act(self.observation)
  File "/home/Persona-Dialogue-Generation/agents/psquare/psquare.py", line 624, in batch_act
    cand_inds, is_training)
  File "/home/Persona-Dialogue-Generation/agents/psquare/psquare.py", line 814, in transmitter_predict
    raise e
  File "/home/Persona-Dialogue-Generation/agents/psquare/psquare.py", line 787, in transmitter_predict
    sampling=True)
  File "/home/Persona-Dialogue-Generation/agents/transmitter/gpt/model.py", line 120, in forward
    predictions, scores, hidden_states = self.sample_decoding(batch_size, prior_context, prior_dis, self.topk)
  File "/home/Persona-Dialogue-Generation/agents/transmitter/gpt/model.py", line 383, in sample_decoding
    is_end = is_end | (predict_tok == self.end_idx).view(-1)

RuntimeError: Expected object of scalar type Byte but got scalar type Bool for argument #2 'other' in call to _th_or

How can I fix this error? Thanks!


@Libincn-HNU Which version of PyTorch are you using? We recommend PyTorch 1.0.

@SivilTaram I'm using torch 1.3.0. I'll try switching versions to fix it.
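For context (a hedged note, not from the maintainers): starting with PyTorch 1.2, comparison operators such as `==` return `torch.bool` tensors instead of `torch.uint8`. On 1.0 both operands of `is_end | (predict_tok == self.end_idx)` are uint8, but on 1.3.0 the right-hand side is bool while `is_end` is still uint8 (Byte), which is exactly what the RuntimeError reports. A minimal, stdlib-only sketch of a version check one might run before training; `comparisons_return_bool` is a hypothetical helper, not part of the repository:

```python
def comparisons_return_bool(version: str) -> bool:
    """Hypothetical helper: True if a PyTorch version string is 1.2 or newer,
    i.e. a release where comparison ops return torch.bool instead of torch.uint8."""
    # Drop local build tags such as "+cu100" before parsing.
    core = version.split("+")[0]
    numbers = []
    for piece in core.split(".")[:2]:
        # Keep only the leading digits so pre-release tags like "0a0" still parse.
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        numbers.append(int(digits) if digits else 0)
    major, minor = (numbers + [0, 0])[:2]
    return (major, minor) >= (1, 2)


if __name__ == "__main__":
    for v in ["1.0.0", "1.1.0", "1.2.0", "1.3.0+cu100", "1.10.0"]:
        print(v, comparisons_return_bool(v))
```

With PyTorch installed, `comparisons_return_bool(torch.__version__)` indicates whether uint8 masks in the decoding code will collide with bool comparison results.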


@Libincn-HNU Looking forward to your further feedback.

Hi, I changed the source code to:
is_end = is_end.bool() | (predict_tok == self.end_idx).bool().view(-1)
and it finally works, although this warning is still printed:

/pytorch/aten/src/ATen/native/IndexingUtils.h:20: UserWarning: indexing with dtype torch.uint8 is now deprecated, please use a dtype torch.bool instead.
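The patch works because `.bool()` puts both operands of `|` in the same dtype. An alternative sketch, assuming `is_end` is created as a uint8 tensor earlier in `sample_decoding` (that line is not shown in this thread): create the mask as `torch.bool` once and keep it that way. `track_finished` and its arguments are hypothetical, and the sketch needs PyTorch 1.2 or newer because `torch.bool` does not exist in 1.0.

```python
import torch


def track_finished(step_tokens, end_idx):
    """Hypothetical helper mirroring end-of-sequence bookkeeping in sample decoding.

    step_tokens: list of LongTensors of shape (batch_size, 1), one per decoding step.
    end_idx: integer id of the end-of-sequence token.
    Returns a torch.bool mask of shape (batch_size,) that is True for every
    sequence that has emitted end_idx at any step so far.
    """
    batch_size = step_tokens[0].size(0)
    # Create the mask as torch.bool up front so `|` never mixes uint8 and bool.
    is_end = torch.zeros(batch_size, dtype=torch.bool)
    for predict_tok in step_tokens:
        # `==` already returns torch.bool on PyTorch >= 1.2; .bool() is then a no-op.
        is_end = is_end | (predict_tok == end_idx).bool().view(-1)
    return is_end


steps = [torch.tensor([[5], [2], [7]]), torch.tensor([[2], [3], [7]])]
finished = track_finished(steps, end_idx=2)
print(finished.tolist())  # [True, True, False]
```

Indexing with a bool mask such as `finished` does not trigger the uint8 deprecation warning, although the warning in this log may come from other code in the repository.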

Training then runs and prints:

[ time:127s parleys:54 ] {'reward_var': 0.35737055998582107, 'reward': 0.0036568641662597656, 'num_selfplay_episode': 13, 'num_selfplay_turns': 78, 'total_reward': -0.01236748007627634}

I also lowered the batch_size. :)