tricktreat / piqn

Code for "Parallel Instance Query Network for Named Entity Recognition", accepted at ACL 2022.

Geek Repo:Geek Repo

Github PK Tool:Github PK Tool

【matrix contains invalid numeric entries】匈牙利算法优化过程

kk19990709 opened this issue · comments

作者你好,非常喜欢你们在locate and label和piqn的工作,在复现过程中遇到了以下问题。如果你们遇到过类似的问题,可以告知一下解决方案吗?
由于硬件限制,我们使用了最新版的torchtransformer。希望不会带来影响。
在复现的过程中,我们遇到了以下错误。在github,CSDN个Stack Overflow查询后发现是匈牙利算法的优化问题。
错误发生在非常后面的epoch里,这给我们的debug带来了很大的困难。

Process SpawnProcess-1:
Traceback (most recent call last):
  File "/home/kk/anaconda3/envs/piqn/lib/python3.8/multiprocessing/process.py", line 315, in _bootstrap
    self.run()
  File "/home/kk/anaconda3/envs/piqn/lib/python3.8/multiprocessing/process.py", line 108, in run
    self._target(*self._args, **self._kwargs)
  File "/home/kk/piqn/piqn.py", line 15, in __train
    trainer.train(train_path=run_args.train_path[:-5]+ '_' + run_args.index + '.json',
  File "/home/kk/piqn/piqn/piqn_trainer.py", line 188, in train
    self._train_epoch(model, compute_loss, optimizer, train_dataset, updates_epoch, epoch)
  File "/home/kk/piqn/piqn/piqn_trainer.py", line 285, in _train_epoch
    batch_loss = compute_loss.compute(entity_logits, p_left, p_right, output, gt_types=batch['gt_types'], gt_spans = batch['gt_spans'], entity_masks=batch['entity_masks'], epoch = epoch,  deeply_weight = args.deeply_weight, seq_logits = masked_seq_logits, gt_seq_labels=batch['gt_seq_labels'], batch = batch)
  File "/home/kk/piqn/piqn/loss.py", line 59, in compute
    loss_dict = self.criterion(outputs, targets, epoch)
  File "/home/kk/anaconda3/envs/piqn/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/kk/piqn/piqn/loss.py", line 261, in forward
    indices = self.matcher(outputs, targets)
  File "/home/kk/anaconda3/envs/piqn/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/kk/anaconda3/envs/piqn/lib/python3.8/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
    return func(*args, **kwargs)
  File "/home/kk/piqn/piqn/matcher.py", line 76, in forward
    indices = [linear_sum_assignment(c[i]) for i, c in enumerate(C.split(sizes, -1))]
  File "/home/kk/piqn/piqn/matcher.py", line 76, in <listcomp>
    indices = [linear_sum_assignment(c[i]) for i, c in enumerate(C.split(sizes, -1))]
  File "/home/kk/anaconda3/envs/piqn/lib/python3.8/site-packages/scipy/optimize/_lsap.py", line 93, in linear_sum_assignment
    raise ValueError("matrix contains invalid numeric entries")
ValueError: matrix contains invalid numeric entries

参考链接:
stack overflow
scipy发现了这个问题,并且似乎改进了它,但依然出错
scipy issue

piqn/matcher.pyline74进行如下修改

            if self.solver == "hungarian":
                C = C.cpu()
                indices = []
                for i, c in enumerate(C.split(sizes, -1)):
                    assert torch.isfinite(c[i]).all(), str(c[i])
                    indices.append(linear_sum_assignment(c[i]))
            if self.solver == "auction":
                indices = [auction_lap(c[i])[:2] for i, c in enumerate(C.split(sizes, -1))]

出现报错

AssertionError: tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        ...,
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]])

推断出,nan要不然不出现,要不然就出现一堆

你好,这个问题不在于匈牙利算法,而是cost矩阵出现了NaN。导致这个问题的原因有很多,你可以搜索一下“模型训练出现NaN的解决方案”来排除问题。

根据你给的情况,如果cost全部是NaN,可能是因为上一step的loss为NaN了,反向传播之后,模型参数也都为NaN,这时候计算出的cost也就是全部NaN了。

感谢解答,我去试一试!

感谢解答,我去试一试!

朋友最后解决了吗。是咋解决的。找不到bug啊。我是在genia上复现的。

想求教一下,这个问题怎么解决啊?

可能是cost在算predbox和targetbox的giou时候,由于predbox的xmin > xmax或ymin > ymax导致了iou出现了inf 或者nan

如果出现NaN问题,解除这两行注释试一下

piqn/piqn/loss.py

Lines 46 to 47 in 2f745b0

# if len(gt_types_wo_nil) == 0:
# return 0.1

thx,我看了下是我学习率太大了,所以网络输出时就已经时nan了,调小点就好了