aunum / goro

A High-level Machine Learning Library for Go

Ways to export/import network weights

samuelvenzi opened this issue

Does Goro support any way of exporting/saving and then importing the trained model?

Gorgonia has a way of doing this: https://gorgonia.org/how-to/save-weights/. I haven't worked this into Goro as a first-class feature yet, but you should be able to call Learnables() to get the differentiable weights and use gob to save them.

If you wanted to write a cleaner method for this I would be happy to review the PR!

Saving the Learnables() was pretty straightforward, but is there a method to load them back into a model after reading them from the file?

The code at https://github.com/aunum/goro/blob/master/pkg/v1/model/model.go#L606 is probably the best example; it could be refactored into a LoadLearnables() or SetLearnables() method.

Hey, I just created a SetLearnables() method inspired by CloneLearnablesTo(). At first I tried to do the same thing outside Goro, but it needs access to some unexported fields, like trainBatchChain, so I forked the repo and opened PR #3. Could you review it, @pbarker?
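The idea behind a SetLearnables()-style method can be sketched as copying saved values back into a model's existing parameters. This is not the code merged in PR #3: the `Param` and `Weight` types are hypothetical stand-ins for Gorgonia nodes and a decoded snapshot, and pairing parameters with saved weights by position is an assumption of this sketch.

```go
package main

import "fmt"

// Param is a hypothetical stand-in for a learnable node in a model graph:
// a shape plus the flat backing slice the model reads during training.
type Param struct {
	Name  string
	Shape []int
	Data  []float32
}

// Weight is a hypothetical saved snapshot, e.g. as decoded from a gob file.
type Weight struct {
	Name  string
	Shape []int
	Data  []float32
}

// sameShape reports whether two shapes have identical dimensions.
func sameShape(a, b []int) bool {
	if len(a) != len(b) {
		return false
	}
	for i := range a {
		if a[i] != b[i] {
			return false
		}
	}
	return true
}

// SetLearnables copies saved weights into existing params, pairing them by
// position (an assumption). It validates every pair before copying, so a
// mismatch returns an error and leaves all params untouched.
func SetLearnables(params []*Param, ws []Weight) error {
	if len(params) != len(ws) {
		return fmt.Errorf("have %d params, got %d saved weights", len(params), len(ws))
	}
	for i, p := range params {
		if !sameShape(p.Shape, ws[i].Shape) || len(ws[i].Data) != len(p.Data) {
			return fmt.Errorf("param %d (%s): shape %v, saved shape %v", i, p.Name, p.Shape, ws[i].Shape)
		}
	}
	for i, p := range params {
		copy(p.Data, ws[i].Data)
	}
	return nil
}

func main() {
	w := &Param{Name: "fc1_w", Shape: []int{2}, Data: make([]float32, 2)}
	saved := []Weight{{Name: "fc1_w", Shape: []int{2}, Data: []float32{0.7, 0.8}}}
	if err := SetLearnables([]*Param{w}, saved); err != nil {
		panic(err)
	}
	fmt.Println(w.Data)
}
```

Validating every shape before copying anything means a mismatched file produces an error instead of a half-loaded model.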

Thanks for the contribution! Closing with the merge of #3.