ZhuiyiTechnology / t5-pegasus

中文生成式预训练模型

Geek Repo:Geek Repo

Github PK Tool:Github PK Tool

model.load_weights('./best_model.weights')报错'str' object has no attribute 'decode'

nicolas-echo opened this issue · comments

基本信息

你使用的操作系统: windows
你使用的Python版本: python3.6.0
你使用的Tensorflow版本: 1.14.0
你使用的Keras版本: 2.3.1
你使用的bert4keras版本: 0.10.0
你使用纯keras还是tf.keras:
你加载的预训练模型: T5_pegasus

核心代码

Run finetune.py的时候正常,得到权重best_model.weights,使用的是CSL数据集进行微调,rouge比博客给出的低了5-6%。进行推理时,根据微调脚本写了预测代码,加载模型代码如下: t5 = build_transformer_model(
config_path=config_path,
checkpoint_path=None,
model='t5.1.1',
return_keras_model=False,
name='T5',
)

encoder = t5.encoder
decoder = t5.decoder
model = t5.model
model.summary()
model.load_weights(save_mode_path, by_name=False,
skip_mismatch=False, reshape=False)`

输出信息

Traceback (most recent call last): File "F:/pycharm_projrct/t5-pegasus-main/preditc_01.py", line 130, in <module> skip_mismatch=True, reshape=False) File "E:\Anaconda3\envs\tens1\lib\site-packages\keras\engine\saving.py", line 492, in load_wrapper return load_function(*args, **kwargs) File "E:\Anaconda3\envs\tens1\lib\site-packages\keras\engine\network.py", line 1227, in load_weights reshape=reshape) File "E:\Anaconda3\envs\tens1\lib\site-packages\keras\engine\saving.py", line 1262, in load_weights_from_hdf5_group_by_name original_keras_version = f.attrs['keras_version'].decode('utf8') AttributeError: 'str' object has no attribute 'decode'

自我尝试

  • 尝试过修改load_weights的参数 by_name=True, skip_mismatch=True,好像没用
  • 尝试将模型保存为ckpt,使用
    # model.save_weights('./best_model.weights') # 保存模型 t5.save_weights_as_checkpoint('./best_model.weights') # 保存模型
    可以正常加载模型了但是预测出来的是空字符串。
    请苏神帮忙看一下

t5.save_weights_as_checkpoint('./save/best_model.ckpt') # 保存模型

上面保存模型的代码写错了,修改一下

自我尝试(更新)

修改keras/engine/saving报错出代码,decode前先encode一次,不报错了能成功运行,但是由于修改了keras代码,不知道对后续有什么影响

    if 'keras_version' in f.attrs:
        original_keras_version = f.attrs['keras_version'].encode('utf8').decode('utf8')#AttributeError: 'str' object has no attribute 'decode'
    else:
        original_keras_version = '1'
    if 'backend' in f.attrs:
        original_backend = f.attrs['backend'].encode('utf8').decode('utf8')#AttributeError: 'str' object has no attribute 'decode'