guillitte / pytorch-sentiment-neuron


Code to load weights

ahirner opened this issue

The code in models.py constructs the graph in a very sleek way. Is it possible to see how you transformed the weights into mlstm_ns.pt too?

This is the code I used to load the weights from the numpy files 👍

```python
import numpy as np
import torch

# `embed` and `rnn` are the embedding and mLSTM modules built by models.py.
# The linear weights are transposed because TF stores kernels as (in, out)
# while PyTorch's nn.Linear expects (out, in).
embed.weight.data = torch.from_numpy(np.load("embd.npy"))
rnn.h2o.weight.data = torch.from_numpy(np.load("w.npy")).t()
rnn.h2o.bias.data = torch.from_numpy(np.load("b.npy"))
rnn.layers[0].wx.weight.data = torch.from_numpy(np.load("wx.npy")).t()
rnn.layers[0].wh.weight.data = torch.from_numpy(np.load("wh.npy")).t()
rnn.layers[0].wh.bias.data = torch.from_numpy(np.load("b0.npy"))
rnn.layers[0].wmx.weight.data = torch.from_numpy(np.load("wmx.npy")).t()
rnn.layers[0].wmh.weight.data = torch.from_numpy(np.load("wmh.npy")).t()
```
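
After the assignments, the weights could be written out as mlstm_ns.pt; how exactly the repo's checkpoint is structured isn't stated in this thread, so this is just a sketch:

```python
# Hedged sketch: persist the freshly assigned weights as a checkpoint.
# Whether mlstm_ns.pt actually holds the modules or their state_dicts
# is an assumption, not something confirmed here.
torch.save((embed.state_dict(), rnn.state_dict()), "mlstm_ns.pt")
```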

Thx for reverse engineering and sharing!

I added the lm.py file, which allows retraining the model on new data. It was used to create the model and load the weights.

I tried to map the original TF variables to the original .npy files and then to your .npy files. Is this mostly correct? Also, I can't see how 14.npy and 15.npy would be used, if at all (b0?), or which file corresponds to gmh in the PyTorch version. My attempt is below; see the shape-check sketch after it.

```
# Embedding for ASCII one-hot
embd = 0.npy = embed.npy

# State
wh  = 1.npy            = wh.npy
wmx = concat(2:6.npy)  = wmx.npy
wmh = empty?           = wmh.npy

gx  = empty
gh  = empty
gmx = empty
gmh = 7.npy  = ?

wx  = 8.npy  = wx.npy
wh  = 9.npy  = wh.npy
wmx = 10.npy = wmx.npy
wmh = 11.npy = wmh.npy

# Fully connected
w = 12.npy = w.npy
```
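
One way to verify a mapping like this is to compare array shapes. A hypothetical check (file names taken from the table above):

```python
import numpy as np

# Hypothetical sanity check: print the shape of every exported array so the
# numbered files can be matched against the named ones by dimension alone.
numbered = [f"{i}.npy" for i in range(16)]
named = ["embd.npy", "w.npy", "b.npy", "wx.npy", "wh.npy",
         "b0.npy", "wmx.npy", "wmh.npy"]
for name in numbered + named:
    try:
        print(name, np.load(name).shape)
    except FileNotFoundError:
        print(name, "missing")
```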

Things are more complicated than this, because the TF model uses L2 regularization, which PyTorch handles differently. That is why I had to hack the TensorFlow model to produce the different .npy files.

Interesting, I assume you extracted the variables from a live TF graph then. I also found that L2 is usually added in PyTorch's optimizer (as weight decay) and suspect that is the difference you're talking about. Thanks!
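
For illustration, the usual PyTorch pattern looks like this (the parameter and hyperparameter values here are just placeholders):

```python
import torch

# In PyTorch, L2 regularization is typically the optimizer's weight_decay
# argument rather than an explicit penalty term in the graph.
param = torch.nn.Parameter(torch.randn(4096, 64))
optimizer = torch.optim.Adam([param], lr=5e-4, weight_decay=1e-5)
```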

Forgive me, it is not L2 regularization but weight normalization that is the problem. And yes, I extracted the variables with TF code.
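
That would explain the leftover gain variables: with weight normalization (Salimans & Kingma, 2016), each weight is stored as a direction v and a gain g, with effective matrix w = g * v / ||v||, so the gains presumably had to be folded into the matrices before export. A minimal sketch of that folding, assuming v is (in, out), g is (out,), and the norm is taken over the input axis:

```python
import numpy as np

# Minimal sketch: fold weight-normalization parameters into a plain matrix.
# Assumed shapes: v is (in, out), g is (out,); the norm is computed per
# output column, as in Salimans & Kingma, 2016.
def fold_weight_norm(v, g):
    return g * v / np.linalg.norm(v, axis=0, keepdims=True)
```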