Is it possible to provide a demo code for bert-base-chinese-qa?
WuJiunShiung opened this issue · comments
Hi, I am new in this field. Is it possible to provide a demo code for bert-base-chinese-qa?
I tried the following code, following the book "Getting Started with Google BERT":
from transformers import BertTokenizerFast, BertForQuestionAnswering
Tokenizer = BertTokenizerFast.from_pretrained("ckiplab/bert-base-chinese")
model = BertForQuestionAnswering.from_pretrained("ckiplab/bert-base-chinese-qa")
paragraph = "李同 也 沒有 在意 , 大廈 中 , 几乎 每 天 都 有 人 搬進 搬出 , 原 不足為奇 。 \
可是 , 當 李同 走進 大廈 時 , 卻 看見 了 那 個 老者 , 那 老者 是 倒退 著 身子 走出來 的 , \
在 那 老者 的 面前 , 兩 個 搬運 工人 , 正 抬 著 一 只 箱子 。 那 是 一 只 木 箱子 , \
很 殘舊 了 , 箱子 并 不 大 , 但是 兩 個 搬運 工人 抬 著 , 看來 十分 吃力 。[SEP]".strip(" ")
question = "[CLS]老者怎麼走出來的?[SEP]"
question_tokens = tokenizer.tokenize(question)
paragraph_tokens = tokenizer.tokenize(paragraph)
tokens = question_tokens + paragraph_tokens
input_ids = tokenizer.convert_tokens_to_ids(tokens)
segment_ids = [0] * len(question_tokens)
segment_ids += [1] * len(paragraph_tokens)
input_ids = torch.tensor([input_ids])
segment_ids = torch.tensor([segment_ids])
# Getting the answer
res = model(input_ids, token_type_ids=segment_ids)
start_scores, end_scores = res['start_logits'], res['end_logits']
start_index = torch.argmax(start_scores)
end_index = torch.argmax(end_scores)
print(" ".join(tokens[start_index:end_index+1]))
But, I got [CLS]. Could you provide a sample code to how how this Chinese QA model can work properly?
Thank you!