wangyuxinwhy / uniem

unified embedding model

Geek Repo:Geek Repo

Github PK Tool:Github PK Tool

自动获取数据集样本类型的代码可能会出现误判

zanghu opened this issue · comments

🐛 bug 说明

首先感谢大佬开源如此优秀的工具!

版本:0.2.2
函数:uniem.data_structures.get_record_type
描述:该函数的逻辑似乎会导致程序将 TripletRecord 类型识别为 PairRecord
实验:

from uniem.data_structures import RecordType, get_record_type
sample = {
    'text': 'hello', 
    'text_pos': 'hi', 
    'text_neg': 'no',
}
ret = get_record_type(sample)
print(ret.__class__)
print(ret)

返回:期望返回类型名称是 TripletRecord
<enum 'RecordType'>
RecordType.PAIR

Python Version

3.10

原因分析

  • RecordType类及包含元素

    • <RecordType.PAIR: 'pair'> ['text', 'text_pos']
    • <RecordType.TRIPLET: 'triplet'> ['text', 'text_pos', 'text_neg']
  • 分析原因:这两类RecordType都含有 ['text', 'text_pos'],并且先判定sample中元素在<RecordType.PAIR: 'pair'> 中导致的。

    • image
  • 有如下方案:

一、修改RecordType顺序,先判定TRIPLET即可

uniem\data_structures.py

record_type_cls_map: dict[RecordType, Any] = {
    RecordType.TRIPLET: TripletRecord,  # ['text', 'text_pos', 'text_neg']
    RecordType.PAIR: PairRecord,  # ['text', 'text_pos']
    RecordType.SCORED_PAIR: ScoredPairRecord,  # ['sentence1', 'sentence2', 'label']
}

二、修改get_record_type判断逻辑

uniem\data_structures.py

  • 这样的一个缺点是sample中不能含有其他多余的key
def get_record_type(record: dict) -> RecordType:
    record_type_field_names_map = {
        record_type: [field.name for field in fields(record_cls)] for record_type, record_cls in
        record_type_cls_map.items()
    }
    for record_type, field_names in record_type_field_names_map.items():
        # if all(field_name in record for field_name in field_names):
        if all(_record in field_names for _record in record): 
            return record_type
    raise ValueError(f'Unknown record type, record: {record}')

三、修改PAIR和TRIPLET的变量名

uniem\data_structures.py

@dataclass(slots=True)
class TripletRecord:
    text: str
    text_triplet_pos: str
    text_neg: str

非常感谢两位的讨论,目前的修复方案采用了 @vegaviazhang 的第一种方案。除此之外

  1. 添加了相应的测试用例以避免再一次出现
  2. 调整了函数的名称,改为 infer_record_typeinfer_ 为前缀来明确函数的语义
  3. 在 FineTuner 中,添加 record_type 可选参数,允许用户显式的指定 record_type

2051687319637_ pic

由于这个 bug 影响比较大,uniem 也还没有完善的 contribute guide ,所以我自作主张的使用了 @vegaviazhang 的方案进行了快速修复。但从更长远的角度来看,这个 PR 如果能通过社区来贡献才是最好的。我会加快补全 uniem 的 contribute guide ~ 再次谢谢两位的参与

非常感谢两位的讨论,目前的修复方案采用了 @vegaviazhang 的第一种方案。除此之外

  1. 添加了相应的测试用例以避免再一次出现
  2. 调整了函数的名称,改为 infer_record_typeinfer_ 为前缀来明确函数的语义
  3. 在 FineTuner 中,添加 record_type 可选参数,允许用户显式的指定 record_type

2051687319637_ pic

由于这个 bug 影响比较大,uniem 也还没有完善的 contribute guide ,所以我自作主张的使用了 @vegaviazhang 的方案进行了快速修复。但从更长远的角度来看,这个 PR 如果能通过社区来贡献才是最好的。我会加快补全 uniem 的 contribute guide ~ 再次谢谢两位的参与

收到,感谢支持!