How to load a pretrained PyTorch model
shabashaash opened this issue
Sorry if this is a basic question, but I can't figure out how to load a model trained with PyTorch. I have a .pth file and could rewrite model.py in Go (or save the model to .pt via TorchScript scripting), but how should this be done in general?
Hi @shabashaash
There are 2 ways to load a Python PyTorch model with gotch:
- Convert the Python model to NumPy (.npz); gotch then provides the `ts.ReadNpz` method to read the .npz file. See example/convert-model.
- Save the model to JIT (TorchScript); gotch then provides APIs for both inference and finetuning/training of the model. See example here and here.
Thanks for the answer!
Sorry again, I know this isn't Stack Overflow or something, but I would really appreciate your help.
I saved the model to .npz and then converted it to .gt.
My model is EraseNet (it's a GAN with VGG16 as a feature extractor).
When I try to run a forward pass, I get this error: https://imgur.com/a/D6PpgaG
This is my code:
imageNet := vision.NewImageNet()
device := gotch.CudaIfAvailable()
// Check each error right after the call so a failed load is not silently overwritten.
image, err := imageNet.LoadImage("path/to/img.jpg")
if err != nil {
	log.Fatal(err)
}
imageTs, err := vision.Resize(image, 512, 512)
if err != nil {
	log.Fatal(err)
}
// Add a batch dimension: [C, H, W] -> [1, C, H, W].
usimage := imageTs.MustUnsqueeze(int64(0), false)
Img := usimage.MustTo(device, true)
model, err := ts.ModuleLoad("path/to/model.gt")
if err != nil {
	log.Fatal(err)
}
output := Img.ApplyCModule(model)
fmt.Println("done")
fmt.Printf("%8.3f\n", output)
UPD:
OK, now I'm trying to trace it. It's going to take a while because of many conflicts, but if I manage to do it, I will write back.
UPD2:
OK, I think I managed to script the model successfully, buuuut now I'm stuck with this problem.
All weights are clearly on the GPU.
Here is the state_dict of the model:
state.txt
Here is the Go code:
imageNet := vision.NewImageNet()
device := gotch.CudaIfAvailable()
// Check each error right after the call so a failed load is not silently overwritten.
image, err := imageNet.LoadImage("path/to/image.jpg")
if err != nil {
	log.Fatal(err)
}
imageTs, err := vision.Resize(image, 512, 512)
if err != nil {
	log.Fatal(err)
}
// Add a batch dimension, cast to float32 and move the input to the target device.
usimage := imageTs.MustUnsqueeze(0, true)
Img := usimage.MustTotype(gotch.Float, true)
Img = Img.MustTo(device, true)
model, err := ts.ModuleLoad("path/to/scripted_model.pt")
if err != nil {
	log.Fatal(err)
}
output := Img.ApplyCModule(model)
fmt.Println("done")
fmt.Printf("%8.3f\n", output)
I don't understand where the problem is.
It looks like your model weights are loaded on the CPU. Can you first try forwarding the image on the CPU as well, to see how it goes?
@sugarme,
thanks for the reply, but I fixed this error (I don't know why, but model.cuda() did not move all modules to CUDA and I had to do it manually).
I think I managed to load the model successfully, buuuuut yet again something went wrong.
The output of the U-Net part of my model looks like this:
But the outputs of other layers look much more normal (believe me, they look MUUUCH more like the reference image).
Also, when I load the .pt in Python, I get even worse results:
So it's more likely a scripting problem than a library problem. I would appreciate the help, but I think I should ask somewhere else.
@sugarme I think we can close this issue :)