THUIR / T2Ranking

T2Ranking: A large-scale Chinese benchmark for passage ranking.

Home Page: https://huggingface.co/datasets/THUIR/T2Ranking

The computation in msmarco_eval

yyht opened this issue · comments

commented

Hi, I read the msmarco_eval in this project, but it seems quite different from both
https://github.com/PaddlePaddle/RocketQA/blob/main/research/RocketQAv2_EMNLP2021/metric/msmarco_eval.py
and
https://github.com/microsoft/MSMARCO-Document-Ranking/blob/main/ms_marco_eval.py:
MaxMRRRank = 10  # cutoff rank for MRR@10; defined as a module-level constant in the original script

def compute_metrics(qids_to_relevant_passageids, qids_to_ranked_candidate_passages):
    """Compute MRR metric
    Args:
    p_qids_to_relevant_passageids (dict): dictionary of query-passage mapping
        Dict as read in with load_reference or load_reference_from_stream
    p_qids_to_ranked_candidate_passages (dict): dictionary of query-passage candidates
    Returns:
        dict: dictionary of metrics {'MRR': <MRR Score>}
    """
    all_scores = {}
    MRR = 0
    qids_with_relevant_passages = 0
    ranking = []
    recall_q_top1 = set()
    recall_q_top50 = set()
    recall_q_all = set()

    for qid in qids_to_ranked_candidate_passages:
        if qid in qids_to_relevant_passageids:
            ranking.append(0)
            target_pid = qids_to_relevant_passageids[qid]
            candidate_pid = qids_to_ranked_candidate_passages[qid]
            # MRR@10: reciprocal rank of the first relevant passage within the top MaxMRRRank candidates
            for i in range(0, MaxMRRRank):
                if candidate_pid[i] in target_pid:
                    MRR += 1.0 / (i + 1)
                    ranking.pop()
                    ranking.append(i + 1)
                    break
            # Recall bookkeeping: mark whether this query has a relevant passage at rank 1, in the top 50, or anywhere
            for i, pid in enumerate(candidate_pid):
                if pid in target_pid:
                    recall_q_all.add(qid)
                    if i < 50:
                        recall_q_top50.add(qid)
                    if i == 0:
                        recall_q_top1.add(qid)
                    break
    if len(ranking) == 0:
        raise IOError("No matching QIDs found. Are you sure you are scoring the evaluation set?")

    MRR = MRR / len(qids_to_relevant_passageids)
    recall_top1 = len(recall_q_top1) * 1.0 / len(qids_to_relevant_passageids)
    recall_top50 = len(recall_q_top50) * 1.0 / len(qids_to_relevant_passageids)
    recall_all = len(recall_q_all) * 1.0 / len(qids_to_relevant_passageids)
    all_scores['MRR @10'] = MRR
    all_scores["recall@1"] = recall_top1
    all_scores["recall@50"] = recall_top50
    all_scores["recall@all"] = recall_all
    all_scores['QueriesRanked'] = len(qids_to_ranked_candidate_passages)
    return all_scores
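
For reference, here is a minimal toy invocation of compute_metrics; the qrels/run dictionaries below are made-up data, chosen only to show the expected input shapes (qid -> relevant pids, and qid -> ranked candidate pids with at least MaxMRRRank entries):

# Toy data (hypothetical): qrels maps qid -> relevant pids, run maps qid -> ranked candidates.
qrels = {
    1: [101, 102],                                  # query 1 has two relevant passages
    2: [201],                                       # query 2 has one relevant passage
}
run = {
    1: [103, 101, 104] + list(range(300, 307)),     # relevant pid 101 appears at rank 2
    2: [205, 206, 201] + list(range(400, 407)),     # relevant pid 201 appears at rank 3
}

print(compute_metrics(qrels, run))
# MRR @10 = (1/2 + 1/3) / 2 ≈ 0.417, recall@1 = 0.0, recall@50 = recall@all = 1.0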

In Baidu's RocketQA, the denominator of MRR is len(qids_to_relevant_passageids), whereas in this project it is len(qids_to_ranked_candidate_passages).
In Baidu's RocketQA, recall@50 and the other recall metrics are computed per qid, with len(qids_to_relevant_passageids) as the denominator; in this project they are computed per pid, with the total number of passages in the qrels as the denominator. Is there a particular reason for this?
The official MRR computation also uses len(qids_to_relevant_passageids) as the denominator (https://github.com/microsoft/MSMARCO-Document-Ranking/blob/main/ms_marco_eval.py).
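
To make the first difference concrete, here is a small sketch of the two MRR denominator conventions being compared; the function and variable names are illustrative and not taken from either repository:

def reciprocal_rank(ranked_pids, relevant_pids, k=10):
    # Reciprocal rank of the first relevant pid in the top k, or 0.0 if none is found.
    for i, pid in enumerate(ranked_pids[:k]):
        if pid in relevant_pids:
            return 1.0 / (i + 1)
    return 0.0

def mrr_variants(qrels, run, k=10):
    # Same numerator in both cases: sum of reciprocal ranks over the judged queries.
    rr_sum = sum(reciprocal_rank(run.get(qid, []), rel, k) for qid, rel in qrels.items())
    return {
        "mrr_over_qrels": rr_sum / len(qrels),   # RocketQA / official MS MARCO convention
        "mrr_over_run": rr_sum / len(run),       # convention described for this project
    }
# The two values coincide when the qrels and the run cover exactly the same queries.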

1. In Baidu's RocketQA, the denominator of MRR is len(qids_to_relevant_passageids), whereas in this project it is len(qids_to_ranked_candidate_passages):

Because the retrieval qrels in T2Ranking are generated from the reranking qrels, i.e., labels 0 and 1 are mapped to 0 and labels 2 and 3 are mapped to 1, some queries end up with no documents labeled 1. These queries without relevant documents are filtered out when computing MRR, so the metric is averaged only over the queries that do have relevant documents. A sketch of this conversion follows below.
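
A minimal sketch of that conversion, assuming the reranking qrels are stored as qid -> {pid: graded label}; the function name and data layout are illustrative:

def rerank_qrels_to_retrieval_qrels(rerank_qrels):
    # Binarize the 4-level labels: 0/1 -> non-relevant, 2/3 -> relevant.
    retrieval_qrels = {}
    for qid, pid_to_label in rerank_qrels.items():
        positives = {pid for pid, label in pid_to_label.items() if label >= 2}
        if positives:
            # Queries with no positive passage are dropped here, which is why
            # MRR is averaged only over the queries that survive this filter.
            retrieval_qrels[qid] = positives
    return retrieval_qrels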

2. In Baidu's RocketQA, recall@50 and the other recall metrics are computed per qid with len(qids_to_relevant_passageids) as the denominator, whereas in this project they are computed per pid with the total number of passages in the qrels as the denominator. Is there a particular reason for this:

In Baidu's and MS MARCO's datasets each query has on average only one relevant document, so recall can simply be divided by len(qids_to_relevant_passageids). In our dataset each query has many relevant documents, so recall should be divided by the total number of relevant documents instead.
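
A small sketch of the two recall definitions being contrasted, assuming qrels maps qid -> set of relevant pids and run maps qid -> ranked list of pids (illustrative code, not taken from either repository):

def recall_at_k_per_query(qrels, run, k=50):
    # MS MARCO / RocketQA style: fraction of queries with at least one relevant pid in the top k.
    hits = sum(1 for qid, rel in qrels.items()
               if any(pid in rel for pid in run.get(qid, [])[:k]))
    return hits / len(qrels)

def recall_at_k_per_passage(qrels, run, k=50):
    # Style described for T2Ranking: fraction of all relevant passages that appear in the top k,
    # which is the natural choice when each query has many relevant passages.
    found = sum(len(rel & set(run.get(qid, [])[:k])) for qid, rel in qrels.items())
    total = sum(len(rel) for rel in qrels.values())
    return found / total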