snakers4 / russian_stt_text_normalization

Russian text normalization pipeline for speech-to-text and other applications based on tagging s2s networks

RuntimeError in a notebook on a repeated call to norm_text

stllfe opened this issue

In a notebook, the model crashes with a RuntimeError when norm_text is called a second time on the same Normalizer instance.
torch==1.8.0

Steps to reproduce:

Python 3.8.0 (default, Jul 24 2020, 06:59:58)                                                                                                                  
Type 'copyright', 'credits' or 'license' for more information                                                                                                  
IPython 7.19.0 -- An enhanced Interactive Python. Type '?' for help. 

In [1]: from russian_stt_text_normalization.normalizer import Normalizer                                                                                       
                                                                                                                                                               
In [2]: norm = Normalizer(jit_model='../src/russian_stt_text_normalization/jit_s2s.pt')                                                                        
                                                                                                                                                               
In [3]: norm.norm_text('тестовый текст про 101 проблему')                                                                                                      
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:01<00:00,  1.95s/it]
Out[3]: 'тестовый текст про сто один проблему'                                                                                                                 
                                                                                                                                                               
In [4]: norm.norm_text('тестовый текст про 101 проблему')                                                                                                      
  0%|                                                                                                                                    | 0/1 [00:00<?, ?it/s]
---------------------------------------------------------------------------                                                                                    
RuntimeError                              Traceback (most recent call last)                                                                                    
<ipython-input-4-11fb846935c0> in <module>                                                                                                                     
----> 1 norm.norm_text('тестовый текст про 101 проблему')                                                                                                      
                                                                                                                                                               
~/projects/asr/src/russian_stt_text_normalization/normalizer.py in norm_text(self, text)                                                                       
     95                 weighted_len = sum(weighted_string)                                                                                                    
     96                 if sum(weighted_string) <= self.max_len:                                                                                               
---> 97                     norm_parts.append(self._norm_string(part))                                                                                         
     98                 else:                                                                                                                                  
     99                     spaces = [m.start() for m in re.finditer(' ', part)]                                                                               
                                                                                                                                                               
~/projects/asr/src/russian_stt_text_normalization/normalizer.py in _norm_string(self, string)                                                                  
     70                                                                                                                                                        
     71         src = torch.LongTensor(src).unsqueeze(0).to(self.device)                                                                                       
---> 72         out = self.model(src, src2tgt)                                                                                                                 
     73         pred_words = self.decode_words(out, unk_list)                                                                                                  
     74         if len(pred_words) > 199:                                                                                                                      
                                                                                                                                                               
~/projects/asr/venv/lib/python3.8/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)                                               
    887             result = self._slow_forward(*input, **kwargs)
    888         else:                                                          
--> 889             result = self.forward(*input, **kwargs)    
    890         for hook in itertools.chain(                                   
    891                 _global_forward_hooks.values(),
                                                                               

RuntimeError: The following operation failed in the TorchScript interpreter.
Traceback of TorchScript, serialized code (most recent call last):
  File "code/test_jit2.py", line 333, in forward
            _120 = torch.select(scores0, 0, b0)
            _121 = torch.select(torch.select(_120, 0, 0), 0, d)
            _122 = torch.copy_(_121, _119)
                   ~~~~~~~~~~~ <--- HERE
          else:
            pass

Traceback of TorchScript, original code (most recent call last):
  File "/home/keras/notebook/nvme/islanna/ruhe_mono/models/seq2seq/jit_model.py", line 128, in forward
            for d in range(scores.shape[2]):
                if int(mask[b, 0, d].item()) == 0:
                    scores[b, 0, d] = -float('inf')
                    ~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE

        # Turn scores to probabilities. 
RuntimeError: a view of a leaf Variable that requires grad is being used in an in-place operation.

The call is missing a with torch.no_grad(): ... wrapper:
https://discuss.pytorch.org/t/leaf-variable-was-used-in-an-inplace-operation/308
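
A minimal sketch of this class of error and of the suggested workaround, independent of the normalizer itself. The tensor `scores` and the helper `mask_scores` are hypothetical stand-ins for the masking step shown in the traceback, not the repository's code. The sketch only shows that an in-place write into a leaf tensor that requires grad fails with autograd enabled and succeeds under torch.no_grad(). It does not explain why the first call in the session above went through.

```python
import torch

# Hypothetical stand-in for a tensor inside the model: a leaf that tracks gradients.
scores = torch.zeros(1, 1, 4, requires_grad=True)


def mask_scores(t):
    # In-place write through an indexed view, similar in shape to
    # `scores[b, 0, d] = -float('inf')` from the traceback.
    t[0, 0, 1] = -float("inf")
    return t


# With autograd enabled, the in-place write on the leaf is rejected.
try:
    mask_scores(scores)
    failed = False
except RuntimeError as err:
    failed = True
    print("without no_grad:", err)

# Inside torch.no_grad() the same write is allowed, because no graph is recorded.
with torch.no_grad():
    masked = mask_scores(scores)
```

Applied to the session above, the same idea would be to wrap the inference call, i.e. call norm.norm_text(...) inside a with torch.no_grad(): block, assuming nothing in the pipeline needs gradients at inference time.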