junxiaosong / AlphaZero_Gomoku

An implementation of the AlphaZero algorithm for Gomoku (also called Gobang or Five in a Row)

state_dict in PyTorch isn't compatible with params trained with Theano

GeneZC opened this issue · comments

In PyTorch, state_dict is a dict, while the params trained with Theano are dumped as a list.
When you want to retrain a model that was trained with Theano, it seems that the model can't be loaded properly.
Is there any way to solve this?

Yes, the provided models in this repo were trained with Theano. If you want to load them in PyTorch, you can load the list and rewrite it as a dict that follows the state_dict format. The file policy_value_net_numpy.py may help you figure out how the params are originally stored in the list.

The params of the pretrained Theano models can be converted to the state_dict format of a PyTorch model with the following script. Note that Theano's conv2d flips the filters (rotates them by 180 degrees) before doing the calculation, so the conversion has to flip them as well.

import pickle
from collections import OrderedDict

import torch

# The Theano parameters are stored as a flat list of arrays.
with open('best_policy_6_6_4.model', 'rb') as f:
    param_theano = pickle.load(f)

# PyTorch parameter names, in the same order as the Theano list.
keys = ['conv1.weight', 'conv1.bias', 'conv2.weight', 'conv2.bias', 'conv3.weight', 'conv3.bias',
        'act_conv1.weight', 'act_conv1.bias', 'act_fc1.weight', 'act_fc1.bias',
        'val_conv1.weight', 'val_conv1.bias', 'val_fc1.weight', 'val_fc1.bias', 'val_fc2.weight', 'val_fc2.bias']

param_pytorch = OrderedDict()
for key, value in zip(keys, param_theano):
    if 'fc' in key and 'weight' in key:
        # Fully connected weights are transposed.
        param_pytorch[key] = torch.FloatTensor(value.T)
    elif 'conv' in key and 'weight' in key:
        # Flip both spatial axes to account for Theano's filter flip.
        param_pytorch[key] = torch.FloatTensor(value[:, :, ::-1, ::-1].copy())
    else:
        # Biases are copied unchanged.
        param_pytorch[key] = torch.FloatTensor(value)
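
To illustrate why the convolution filters are flipped, here is a minimal pure-Python sketch. It assumes Theano's default conv2d behaviour (a true convolution, which flips the filter) and PyTorch's Conv2d behaviour (a cross-correlation, which does not). The helper names are made up for this illustration and are not part of the repo, Theano, or PyTorch.

```python
def correlate2d_valid(image, kernel):
    """Cross-correlation in 'valid' mode: slide the kernel as-is."""
    kh, kw = len(kernel), len(kernel[0])
    out_h, out_w = len(image) - kh + 1, len(image[0]) - kw + 1
    return [[sum(image[i + a][j + b] * kernel[a][b]
                 for a in range(kh) for b in range(kw))
             for j in range(out_w)]
            for i in range(out_h)]


def convolve2d_valid(image, kernel):
    """True convolution in 'valid' mode: the image window is read in reverse."""
    kh, kw = len(kernel), len(kernel[0])
    out_h, out_w = len(image) - kh + 1, len(image[0]) - kw + 1
    return [[sum(image[i + kh - 1 - a][j + kw - 1 - b] * kernel[a][b]
                 for a in range(kh) for b in range(kw))
             for j in range(out_w)]
            for i in range(out_h)]


def flip180(kernel):
    """Rotate a 2D kernel by 180 degrees (reverse rows, then columns)."""
    return [row[::-1] for row in kernel[::-1]]


image = [[1, 2, 3],
         [4, 5, 6],
         [7, 8, 9]]
kernel = [[1, 2],
          [3, 4]]

theano_style = convolve2d_valid(image, kernel)                 # flips internally
pytorch_unflipped = correlate2d_valid(image, kernel)           # weights copied as-is
pytorch_flipped = correlate2d_valid(image, flip180(kernel))    # weights flipped first

print(theano_style)       # [[23, 33], [53, 63]]
print(pytorch_unflipped)  # [[37, 47], [67, 77]]
print(pytorch_flipped)    # [[23, 33], [53, 63]]
```

Copying the filters without the flip gives different outputs, while flipping them first reproduces the convolution result. That is the role of `value[:, :, ::-1, ::-1]` in the script above.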

Sorry for the late reply. I didn't know exactly how state_dict is composed, so I didn't know how to rewrite it. Anyway, thanks a lot for your code!

I also tried to load this model file with PyTorch by adding the code above before
best_policy = PolicyValueNet(width, height, model_file = param_pytorch)
but PyTorch still raises an error when loading the converted parameters:
File "/home/dofish/.local/lib/python3.7/site-packages/torch/serialization.py", line 189, in _check_seekable
raise_err_msg(["seek", "tell"], e)
File "/home/dofish/.local/lib/python3.7/site-packages/torch/serialization.py", line 182, in raise_err_msg
raise type(e)(msg)
AttributeError: 'collections.OrderedDict' object has no attribute 'seek'. You can only torch.load from a file that is seekable. Please pre-load the data into a buffer like io.BytesIO and try to load from it instead.
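
The traceback suggests that `model_file` is handed to `torch.load`, which expects a path or a file-like object rather than an in-memory OrderedDict. One possible workaround, sketched here under that assumption, is to write the converted dict to disk with `torch.save` and pass the resulting path instead. The small `param_pytorch` and the output file name below are stand-ins for illustration only.

```python
import os
import tempfile
from collections import OrderedDict

import torch

# Stand-in for the converted parameters; in practice this would be the
# full OrderedDict produced by the conversion script.
param_pytorch = OrderedDict([
    ("conv1.weight", torch.zeros(2, 1, 3, 3)),
    ("conv1.bias", torch.zeros(2)),
])

# Hypothetical output file name; any writable path works.
converted_path = os.path.join(tempfile.mkdtemp(), "best_policy_6_6_4_pytorch.model")

# torch.save serializes the dict so that torch.load can read it back from a path.
torch.save(param_pytorch, converted_path)

# Loading from the path works, unlike passing the OrderedDict itself.
reloaded = torch.load(converted_path)
print(list(reloaded.keys()))
```

With the real parameters saved this way, `PolicyValueNet(width, height, model_file=converted_path)` should be able to load them, provided the constructor does call `torch.load(model_file)` as the traceback indicates.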

When I tried to load a model trained with PyTorch, I came across the same error. I changed line 65 of human_play.py to the following:

policy_param = model_file

And it worked.