1ytic / warp-rnnt

CUDA-Warp RNN-Transducer

operating with apex?

tongjinle123 opened this issue

I am trying to use this implementation with apex half-precision training, but it fails,
showing that it needs float rather than half:


File "/data/asr_v3/src/model/transformer_transducer/lightning_model.py", line 41, in training_step
joint_out, rnnt_loss = self.forward(feature, feature_length, target, target_length, cal_rnnt_loss=True)
File "/opt/conda/lib/python3.7/site-packages/apex/amp/_initialize.py", line 197, in new_fwd
**applier(kwargs, input_caster))
File "/data/asr_v3/src/model/transformer_transducer/lightning_model.py", line 36, in forward
joint_out, rnnt_loss = self.transducer.forward(feature, feature_length, target, target_length, cal_rnnt_loss)
File "/data/asr_v3/src/model/transformer_transducer/transformer_transducer.py", line 79, in forward
rnn_t_loss = self.cal_transducer_loss(joint, ori_token, feature_length, ori_token_length)
File "/data/asr_v3/src/model/transformer_transducer/transformer_transducer.py", line 108, in cal_transducer_loss
log_probs=log_prob, labels=target.int(), frames_lengths=frame_length.int(), labels_lengths=target_length.int(), reduction='mean')
File "/opt/conda/lib/python3.7/site-packages/warp_rnnt/init.py", line 80, in rnnt_loss
costs = RNNTLoss.apply(log_probs, labels, frames_lengths, labels_lengths, blank)
File "/opt/conda/lib/python3.7/site-packages/warp_rnnt/init.py", line 16, in forward
blank=blank,
RuntimeError: xs must be a Float tensor (rnnt_loss at binding.cpp:42)
frame #0: c10::Error::Error(c10::SourceLocation, std::string const&) + 0x47 (0x7fa72c18c687 in /opt/conda/lib/python3.7/site-packages/torch/lib/libc10.so)
frame #1: rnnt_loss(at::Tensor const&, at::Tensor const&, at::Tensor const&, at::Tensor const&, int) + 0xf79 (0x7fa707c87389 in /opt/conda/lib/python3.7/site-packages/warp_rnnt/_C.cpython-37m-x86_64-linux-gnu.so)
frame #2: + 0x22ea7 (0x7fa707c9aea7 in /opt/conda/lib/python3.7/site-packages/warp_rnnt/_C.cpython-37m-x86_64-linux-gnu.so)
frame #3: + 0x232ee (0x7fa707c9b2ee in /opt/conda/lib/python3.7/site-packages/warp_rnnt/_C.cpython-37m-x86_64-linux-gnu.so)
frame #4: + 0x1fd11 (0x7fa707c97d11 in /opt/conda/lib/python3.7/site-packages/warp_rnnt/_C.cpython-37m-x86_64-linux-gnu.so)

frame #10: THPFunction_apply(_object*, _object*) + 0x8d6 (0x7fa7601b9e96 in /opt/conda/lib/python3.7/site-packages/torch/lib/libtorch_python.so)
frame #63: __libc_start_main + 0xf0 (0x7fa76fc35830 in /lib/x86_64-linux-gnu/libc.so.6)

Yes, the loss function is implemented only for float values. I will have to generalize the implementation to other types. Currently, you can convert the logits to float explicitly.
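
A minimal sketch of that workaround, reusing the keyword arguments from the traceback above. The tensor names and shapes here are purely illustrative stand-ins for the reporter's variables in cal_transducer_loss, not code from the repository:

```python
import torch
from warp_rnnt import rnnt_loss

# Illustrative stand-ins: under apex O1/O2, the joint network output
# (log_prob) arrives as a Half tensor on the GPU.
N, T, U, V = 2, 10, 5, 20  # batch, frames, label length, vocabulary size
log_prob = torch.randn(N, T, U + 1, V, device="cuda", dtype=torch.half).log_softmax(dim=-1)
target = torch.randint(1, V, (N, U), device="cuda")
frame_length = torch.full((N,), T, device="cuda")
target_length = torch.full((N,), U, device="cuda")

# Cast the log-probabilities back to float32 before the loss call,
# since the CUDA kernel currently accepts only Float tensors.
loss = rnnt_loss(
    log_probs=log_prob.float(),
    labels=target.int(),
    frames_lengths=frame_length.int(),
    labels_lengths=target_length.int(),
    reduction='mean',
)
```

Only the loss computation runs in float32 here; the rest of the forward pass can stay in half precision.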