你好,我觉得评价指标有问题
FreeRotate opened this issue · comments
在metrics_span类中,对于识别出小于1的标签标为0,其他的标为1。0作为非实体,1为实体。但是标为1的标签除了不是“O”还可以是其他类标签。虽然识别出为实体,但可能与正确实体不为同一类实体,不应该作为识别正确吧。
使用这个才应当是正确的结果
def cal_metrics(y_preds, y_trues):
"""
:param y_preds:
:param y_trues:
:return:
"""
y_preds_unique_labels = torch.unique(y_preds)
y_trues_unique_labels = torch.unique(y_trues)
all_labels = torch.cat((y_preds_unique_labels, y_trues_unique_labels)).unique(sorted=True)
# ignore 0
if 0 in all_labels:
all_labels = all_labels[1:]
y_preds_labels, y_preds_count = y_preds.unique(return_counts=True)
y_trues_labels, y_trues_count = y_trues.unique(return_counts=True)
corrects = torch.eq(y_preds, y_trues)
corrects_labels, corrects_count = torch.mul(corrects, y_trues).unique(return_counts=True)
y_preds_map = dict(zip(y_preds_labels.tolist(), y_preds_count.tolist()))
y_true_map = dict(zip(y_trues_labels.tolist(), y_trues_count.tolist()))
corrects_map = dict(zip(corrects_labels.tolist(), corrects_count.tolist()))
precision, recall = 0, 0
for label in all_labels.tolist():
precision += (corrects_map.get(label, 0) / (y_preds_map.get(label, 0) + 1e-8))
recall += (corrects_map.get(label, 0) / (y_true_map.get(label, 0) + 1e-8))
precision, recall = precision / len(all_labels), recall / len(all_labels)
f1 = 2 * precision * recall / (precision + recall + 1e-8)
return precision, recall, f1
使用这个才应当是正确的结果
def cal_metrics(y_preds, y_trues): """ :param y_preds: :param y_trues: :return: """ y_preds_unique_labels = torch.unique(y_preds) y_trues_unique_labels = torch.unique(y_trues) all_labels = torch.cat((y_preds_unique_labels, y_trues_unique_labels)).unique(sorted=True) # ignore 0 if 0 in all_labels: all_labels = all_labels[1:] y_preds_labels, y_preds_count = y_preds.unique(return_counts=True) y_trues_labels, y_trues_count = y_trues.unique(return_counts=True) corrects = torch.eq(y_preds, y_trues) corrects_labels, corrects_count = torch.mul(corrects, y_trues).unique(return_counts=True) y_preds_map = dict(zip(y_preds_labels.tolist(), y_preds_count.tolist())) y_true_map = dict(zip(y_trues_labels.tolist(), y_trues_count.tolist())) corrects_map = dict(zip(corrects_labels.tolist(), corrects_count.tolist())) precision, recall = 0, 0 for label in all_labels.tolist(): precision += (corrects_map.get(label, 0) / (y_preds_map.get(label, 0) + 1e-8)) recall += (corrects_map.get(label, 0) / (y_true_map.get(label, 0) + 1e-8)) precision, recall = precision / len(all_labels), recall / len(all_labels) f1 = 2 * precision * recall / (precision + recall + 1e-8) return precision, recall, f1
感谢,之前这代码整体流程过了一遍,觉得这个代码有问题,太简略了,我自己把这函数改成了sklearn的评价指标,效果比baseline差了很多。除了评价指标有问题,原论文说的代码也没有这么简单,有些部分直接省略了,比如fasttext的embedding都没有用上。而且看了很多有关这篇论文的讨论,效果并不理想,换篇文章看吧。
sorry, 有一个地方我感觉写错了,应该改成:
corrects_mask = torch.eq(y_preds, y_trues)
corrects_labels, corrects_count = y_trues[corrects_mask].unique(return_counts=True)
baseline的话跑出来后才知道,不知道是不是没有计算0后导致,另外没有用fasttext的embedding,应该是由于使用了transformer这种预训练模型,从而不需要吧。