自动获取数据集样本类型的代码可能会出现误判
zanghu opened this issue · comments
🐛 bug 说明
首先感谢大佬开源如此优秀的工具!
版本:0.2.2
函数:uniem.data_structures.get_record_type
描述:该函数的逻辑似乎会导致程序将 TripletRecord 类型识别为 PairRecord
实验:
from uniem.data_structures import RecordType, get_record_type
sample = {
'text': 'hello',
'text_pos': 'hi',
'text_neg': 'no',
}
ret = get_record_type(sample)
print(ret.__class__)
print(ret)
返回:期望返回类型名称是 TripletRecord
<enum 'RecordType'>
RecordType.PAIR
Python Version
3.10
原因分析
-
RecordType类及包含元素
- <RecordType.PAIR: 'pair'> ['text', 'text_pos']
- <RecordType.TRIPLET: 'triplet'> ['text', 'text_pos', 'text_neg']
-
分析原因:这两类RecordType都含有 ['text', 'text_pos'],并且先判定sample中元素在<RecordType.PAIR: 'pair'> 中导致的。
-
有如下方案:
一、修改RecordType顺序,先判定TRIPLET即可
uniem\data_structures.py
record_type_cls_map: dict[RecordType, Any] = {
RecordType.TRIPLET: TripletRecord, # ['text', 'text_pos', 'text_neg']
RecordType.PAIR: PairRecord, # ['text', 'text_pos']
RecordType.SCORED_PAIR: ScoredPairRecord, # ['sentence1', 'sentence2', 'label']
}
二、修改get_record_type判断逻辑
uniem\data_structures.py
- 这样的一个缺点是sample中不能含有其他多余的key
def get_record_type(record: dict) -> RecordType:
record_type_field_names_map = {
record_type: [field.name for field in fields(record_cls)] for record_type, record_cls in
record_type_cls_map.items()
}
for record_type, field_names in record_type_field_names_map.items():
# if all(field_name in record for field_name in field_names):
if all(_record in field_names for _record in record):
return record_type
raise ValueError(f'Unknown record type, record: {record}')
三、修改PAIR和TRIPLET的变量名
uniem\data_structures.py
@dataclass(slots=True)
class TripletRecord:
text: str
text_triplet_pos: str
text_neg: str
非常感谢两位的讨论,目前的修复方案采用了 @vegaviazhang 的第一种方案。除此之外
- 添加了相应的测试用例以避免再一次出现
- 调整了函数的名称,改为
infer_record_type
,infer_
为前缀来明确函数的语义 - 在 FineTuner 中,添加
record_type
可选参数,允许用户显式的指定 record_type
由于这个 bug 影响比较大,uniem
也还没有完善的 contribute guide ,所以我自作主张的使用了 @vegaviazhang 的方案进行了快速修复。但从更长远的角度来看,这个 PR 如果能通过社区来贡献才是最好的。我会加快补全 uniem
的 contribute guide ~ 再次谢谢两位的参与
非常感谢两位的讨论,目前的修复方案采用了 @vegaviazhang 的第一种方案。除此之外
- 添加了相应的测试用例以避免再一次出现
- 调整了函数的名称,改为
infer_record_type
,infer_
为前缀来明确函数的语义- 在 FineTuner 中,添加
record_type
可选参数,允许用户显式的指定 record_type由于这个 bug 影响比较大,
uniem
也还没有完善的 contribute guide ,所以我自作主张的使用了 @vegaviazhang 的方案进行了快速修复。但从更长远的角度来看,这个 PR 如果能通过社区来贡献才是最好的。我会加快补全uniem
的 contribute guide ~ 再次谢谢两位的参与
收到,感谢支持!