ChatBot-GPT2
简介
一个使用Pytorch
和Huggingface Transofrmers
构建的 gpt2
多轮聊天机器人
项目结构
process_data.py
: 处理四个格式不同的数据集的一些方法
load_data.py
: 调用process_data.py 将四个不同的数据集合并保存为json形式
mydataset.py
: 定义数据集以及与数据处理相关的函数
main.py
: 主函数
trainer.py
: 定义模型训练与验证方法
predictor.py
: 定义模型预测与交互方法
evaluator.py
: 定义评估标准包括(Bleu,Rouge, Distinct)
settings.py
: 项目配置参数
utils.py
: 工具类
如何使用
- 安装依赖库
pip install -r requirements.txt
- 下载并处理数据集
python load_data.py
- 训练模型
你可以从初始状态训练一个模型
python main.py --mode="train"
你也可以从一个保存过后的checkpoint
处开始训练(例如文件名为best.ckpt
)
python main.py --mode="train" --ckpt_name="best"
- 模型评估
python main.py --mode="evaluate" --ckpt_name="best"
- 推理和交互
python main.py --mode="infer" --ckpt_name="best"