openmm / openmm-torch

OpenMM plugin to define forces with neural networks

Simplify the creation of `TorchForce`

raimis opened this issue

In general, the creation of TorchForce goes like this:

```python
import torch as pt
from openmmtorch import TorchForce

# 1. create a PyTorch module
class Module(pt.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, positions):
        # the returned scalar is used as the potential energy
        return pt.sum(positions)

# 2. save the module to a file
module = Module()
pt.jit.script(module).save('module.pt')

# 3. create a force from the file
force = TorchForce('module.pt')
```

We could simplify the last two steps by allowing an instance of Module to be passed directly to TorchForce:

```python
force = TorchForce(module)
```

For this, we need to solve #65.

Depending on the details of your model, you might need to use torch.jit.trace() instead of torch.jit.script(); tracing also requires example inputs.

Yes, both options should be available:

```python
force = TorchForce('module.pt')
force = TorchForce(module)
```

In practice, I have never needed torch.jit.trace().

This ties in with #65 and #95. It was suggested to make serialization include the full model in the XML output. The problem is that a TorchForce object doesn't actually store a model. All it stores is a filename, which gets loaded each time you create a Context. So that would involve storing information in the XML that isn't actually present in the TorchForce object, leading to a variety of problems.
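To make the problem concrete, here is a minimal pure-Python sketch. FileOnlyForce and create_context are hypothetical stand-ins, not the openmm-torch API, and the "model" is just raw bytes. The object keeps only a path and re-reads the file whenever a context is created, so two contexts built from the same object can see different models if the file changes in between.

```python
import os
import tempfile


class FileOnlyForce:
    """Hypothetical stand-in for a force that stores only a file name."""

    def __init__(self, file):
        self._file = file

    def getFile(self):
        return self._file

    def create_context(self):
        # The model is (re)loaded from disk every time a context is created.
        with open(self._file, "rb") as f:
            return f.read()


with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "module.pt")
    with open(path, "wb") as f:
        f.write(b"model v1")
    force = FileOnlyForce(path)
    first = force.create_context()

    # Overwriting the file changes what the next context sees.
    with open(path, "wb") as f:
        f.write(b"model v2")
    second = force.create_context()

print(first, second)  # b'model v1' b'model v2'
```

Serializing such an object can only record the path, so putting the model itself into the XML would mean writing data the object never held.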

It could be useful to let TorchForce actually store the model. We could add a constructor that takes a torch::jit::script::Module in the C++ API, or a torch.jit.ScriptModule in the Python API. There would likewise be a getModule() method to retrieve it, and the module would be stored in the XML during serialization.

With that change, the force creation in the first example would become:

```python
force = TorchForce(torch.jit.script(module))
```

or, if you need to use tracing:

```python
force = TorchForce(torch.jit.trace(module, exampleInputs))
```

cc @RaulPPelaez
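Storing the module in the XML during serialization, as proposed above, needs some way to put binary data into a text format. One common approach is to base64-encode the bytes into an attribute. The sketch below assumes the model has already been serialized to bytes (for example with torch.jit.save into an io.BytesIO buffer); the element and attribute names are hypothetical and not necessarily what openmm-torch uses.

```python
import base64
import xml.etree.ElementTree as ET


def serialize_force(model_bytes: bytes) -> str:
    # Store the serialized model itself, not a file name, in the XML.
    elem = ET.Element("TorchForce", version="1")
    elem.set("model", base64.b64encode(model_bytes).decode("ascii"))
    return ET.tostring(elem, encoding="unicode")


def deserialize_force(xml_text: str) -> bytes:
    # Decode the attribute back into the original model bytes.
    elem = ET.fromstring(xml_text)
    return base64.b64decode(elem.get("model"))


# Placeholder bytes standing in for a serialized TorchScript module.
blob = b"\x00\x01serialized TorchScript module\xff"
xml_text = serialize_force(blob)
assert deserialize_force(xml_text) == blob  # the round trip is lossless
```

Because the XML then carries the model itself, deserializing it no longer depends on a file existing at the original path.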

What should be done about the file name in this case?
It seems redundant to keep it as a property of TorchForce if a getModule() member exists. An instance constructed directly from a module has no notion of a file name.

On the other hand, TorchForce(string) can be implemented by calling TorchForce(torch::jit::script::Module). The point is that getModule() makes sense in both cases, while getFile() more or less loses its meaning.

I propose the following options, keeping backward compatibility by not removing getFile(). Calling getFile() could:

  1. Return an empty string if the force is constructed from a module.
  2. If not already done, write the module to a temporary file and return its name.
  3. Write the constructor as TorchForce(torch::jit::script::Module, string = "defaultName"), and store that name.
  4. Change the meaning of getFile() to return some name assigned to the module.

To me, 2 and 3 sound over-engineered; 1 feels like the way to go.

> Return an empty string if the force is constructed from a module.

This is what I was assuming it would do.
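The agreed behavior (option 1) can be sketched in Python under a few assumptions: SketchTorchForce is a hypothetical class, the "module" is any Python object, and load_module is a hypothetical stand-in for torch.jit.load. Constructing from a file name loads the module and stores it, so getModule() works either way, while getFile() returns an empty string when the force was built directly from a module.

```python
import os
import tempfile
from typing import Any


def load_module(file: str) -> Any:
    # Hypothetical stand-in for torch.jit.load(file).
    with open(file, "rb") as f:
        return f.read()


class SketchTorchForce:
    def __init__(self, source: Any):
        if isinstance(source, (str, os.PathLike)):
            # File-name constructor: load the module once and keep the name.
            self._file = os.fspath(source)
            self._module = load_module(self._file)
        else:
            # Module constructor: there is no file name to report.
            self._file = ""
            self._module = source

    def getModule(self) -> Any:
        return self._module

    def getFile(self) -> str:
        # Option 1: empty string when constructed from a module.
        return self._file


with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "module.pt")
    with open(path, "wb") as f:
        f.write(b"saved module")
    from_file = SketchTorchForce(path)

from_module = SketchTorchForce(b"in-memory module")
print(from_module.getFile() == "")  # True
```

In the C++ API this would roughly correspond to the string constructor calling torch::jit::load and passing the result on to the module constructor, as suggested earlier in the thread.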

This could be closed now.