JayYip / m3tl

BERT for Multitask Learning

Home Page:https://jayyip.github.io/m3tl/

Geek Repo:Geek Repo

Github PK Tool:Github PK Tool

你好,请教nan的问题

keaideii opened this issue · comments

你好,感谢分享代码!想请教下,我们有几个任务进行多任务学习,在用这套框架一起训练时候,会报nan错误,但是单独训练时候都没有问题。不知道你可否知道可能哪里出现了问题?(看代码里面有注释说 # WARNING: Potential nan created here! # TODO: Fix this.)谢谢!

请问单独训练的时候也是用这个框架训练的吗?我在混合训练seq2seq任务和其他任务的时候遇到了nan问题,你的任务类型也是包含seq2seq吗?

你好,单独训练也是用的这个框架。不包含seq2seq,是多个cls的任务。

我好像没有遇到过这个情况, 一般来说, 比较常见的nan产生原因有: 学习率过大, 通常会表现为损失上升然后变nan; 设置的类别数小于实际类别数, 那么遇到超过设置类别数的类别就会产生nan.

那个warning那里应该是在seq2seq任务中, 如果没有抽样到该seq2seq任务的话, 其损失为nan. 但是cls任务应该是不会的.

感谢!初步定位到top_utils.py,tf.reduce_mean(batch_loss*loss_multiplier),传入reduce_mean的tensor为空,导致了nan,暂时先把这里的tensor做个判断,为空时传[0]。因为临时帮同事处理nan,还没仔细去读您的代码,感谢分享这么好的项目,随后再学习下!

好的, 如果你发现了问题根源或者觉得这个一个好的解决方法, 欢迎提个pr!

单纯的检测到nan传0可能会掩盖学习率过大的问题, 但是如果同时检测到batch_loss第一个维度为0的话, 感觉这个解决方案是可以的.

嗯,我也是想在找下这个问题的根源,找到再与您交流!