PaddlePaddle / PARL

A high-performance distributed training framework for Reinforcement Learning

Home Page:https://parl.readthedocs.io/

Geek Repo:Geek Repo

Github PK Tool:Github PK Tool

在examples/PPO中,如何保存模型,用于推理?

wangyexiang opened this issue · comments

在跑例子中的PPO代码时,跑完后如何保存模型,用于推理?
参照例子中的DDPG代码,尝试了如下方式,报错AssertionError: model needs to implement forward method.
第一种:

save_inference_path = './inference_model'
input_shapes = [[None, obs_space.shape[0]]]
input_dtypes = ['float32']
agent.save_inference_model(save_inference_path, input_shapes, input_dtypes)

第二种:

save_inference_path = './inference_model'
input_shapes = [[None, obs_space.shape[0]]]
input_dtypes = ['float32']
agent.save_inference_model(save_inference_path, input_shapes, input_dtypes, model)

Hi,如报错提示所示,被保存的模型需要实现 forward 方法,因为PPO example是针对训练设计的,model仅提供了valuepolicy方法,不是针对评估推理设计的,没有forward函数,需要用户自定义选择推理流程。

因此,你需要在PPO的model中新增一个forward方法,比如你想保存policy的推理流程,可以增加如下代码

    # 新增 forward 方法,用于指定想要保存的推理过程
    def forward(self, obs):
        return self.policy(obs)

还有一个issue的问题和你的问题类似,可供参考:#1028

Hi,如报错提示所示,被保存的模型需要实现 forward 方法,因为PPO example是针对训练设计的,model仅提供了valuepolicy方法,不是针对评估推理设计的,没有forward函数,需要用户自定义选择推理流程。

因此,你需要在PPO的model中新增一个forward方法,比如你想保存policy的推理流程,可以增加如下代码

    # 新增 forward 方法,用于指定想要保存的推理过程
    def forward(self, obs):
        return self.policy(obs)

还有一个issue的问题和你的问题类似,可供参考:#1028

感谢您的解答,祝您工作顺利