suolyer / PyTorch_BERT_Biaffine_NER

论文复现《Named Entity Recognition as Dependency Parsing》

Geek Repo:Geek Repo

Github PK Tool:Github PK Tool

你好,我觉得评价指标有问题

FreeRotate opened this issue · comments

在metrics_span类中,对于识别出小于1的标签标为0,其他的标为1。0作为非实体,1为实体。但是标为1的标签除了不是“O”还可以是其他类标签。虽然识别出为实体,但可能与正确实体不为同一类实体,不应该作为识别正确吧。

使用这个才应当是正确的结果

def cal_metrics(y_preds, y_trues):
    """
        
    :param y_preds:
    :param y_trues:
    :return: 
    """
    y_preds_unique_labels = torch.unique(y_preds)
    y_trues_unique_labels = torch.unique(y_trues)

    all_labels = torch.cat((y_preds_unique_labels, y_trues_unique_labels)).unique(sorted=True)
    # ignore 0
    if 0 in all_labels:
        all_labels = all_labels[1:]

    y_preds_labels, y_preds_count = y_preds.unique(return_counts=True)
    y_trues_labels, y_trues_count = y_trues.unique(return_counts=True)

    corrects = torch.eq(y_preds, y_trues)
    corrects_labels, corrects_count = torch.mul(corrects, y_trues).unique(return_counts=True)

    y_preds_map = dict(zip(y_preds_labels.tolist(), y_preds_count.tolist()))
    y_true_map = dict(zip(y_trues_labels.tolist(), y_trues_count.tolist()))
    corrects_map = dict(zip(corrects_labels.tolist(), corrects_count.tolist()))
    precision, recall = 0, 0

    for label in all_labels.tolist():
        precision += (corrects_map.get(label, 0) / (y_preds_map.get(label, 0) + 1e-8))
        recall += (corrects_map.get(label, 0) / (y_true_map.get(label, 0) + 1e-8))

    precision, recall = precision / len(all_labels), recall / len(all_labels)
    f1 = 2 * precision * recall / (precision + recall + 1e-8)
    return precision, recall, f1

使用这个才应当是正确的结果

def cal_metrics(y_preds, y_trues):
    """
        
    :param y_preds:
    :param y_trues:
    :return: 
    """
    y_preds_unique_labels = torch.unique(y_preds)
    y_trues_unique_labels = torch.unique(y_trues)

    all_labels = torch.cat((y_preds_unique_labels, y_trues_unique_labels)).unique(sorted=True)
    # ignore 0
    if 0 in all_labels:
        all_labels = all_labels[1:]

    y_preds_labels, y_preds_count = y_preds.unique(return_counts=True)
    y_trues_labels, y_trues_count = y_trues.unique(return_counts=True)

    corrects = torch.eq(y_preds, y_trues)
    corrects_labels, corrects_count = torch.mul(corrects, y_trues).unique(return_counts=True)

    y_preds_map = dict(zip(y_preds_labels.tolist(), y_preds_count.tolist()))
    y_true_map = dict(zip(y_trues_labels.tolist(), y_trues_count.tolist()))
    corrects_map = dict(zip(corrects_labels.tolist(), corrects_count.tolist()))
    precision, recall = 0, 0

    for label in all_labels.tolist():
        precision += (corrects_map.get(label, 0) / (y_preds_map.get(label, 0) + 1e-8))
        recall += (corrects_map.get(label, 0) / (y_true_map.get(label, 0) + 1e-8))

    precision, recall = precision / len(all_labels), recall / len(all_labels)
    f1 = 2 * precision * recall / (precision + recall + 1e-8)
    return precision, recall, f1

感谢,之前这代码整体流程过了一遍,觉得这个代码有问题,太简略了,我自己把这函数改成了sklearn的评价指标,效果比baseline差了很多。除了评价指标有问题,原论文说的代码也没有这么简单,有些部分直接省略了,比如fasttext的embedding都没有用上。而且看了很多有关这篇论文的讨论,效果并不理想,换篇文章看吧。

sorry, 有一个地方我感觉写错了,应该改成:

corrects_mask = torch.eq(y_preds, y_trues)
corrects_labels, corrects_count = y_trues[corrects_mask].unique(return_counts=True)

baseline的话跑出来后才知道,不知道是不是没有计算0后导致,另外没有用fasttext的embedding,应该是由于使用了transformer这种预训练模型,从而不需要吧。