openmm / openmm-torch

OpenMM plugin to define forces with neural networks

Unit conversion in ForceModule

eva-not opened this issue

Hi, I have trained a TorchMD-Net graph-network model with coordinates (in A) and forces (in kcal/mol/A) as input. Since OpenMM-Torch requires positions in nm, energies in kJ/mol and forces in kJ/mol/nm, I would like to do the conversions inside ForceModule:

class ForceModule(torch.nn.Module):
    def __init__(self, z):
        super(ForceModule, self).__init__()
        self.model = torch.jit.load('model.pt')
        self.z = z.cuda()

    def forward(self, positions):
        positions_nm = positions*0.1
        y, neg_dy = self.model(self.z, positions_nm)

        return y, neg_dy

I can convert the positions from A to nm, and I can convert the energy after the model has computed it (e.g. y_kJ = y*4.184, then return y_kJ), but I'm not sure how to convert the forces. I tried multiplying neg_dy by 4.184 (assuming the A-to-nm distance conversion is not needed again, since the positions are already in nm), but when I try to add the force to the system:

module = torch.jit.script(ForceModule(z))
torch_force = TorchForce(module)
torch_force.setOutputsForces(True)
system.addForce(torch_force)

I get:
[image: screenshot of the error]

Isn't neg_dy supposed to be a torch.Tensor with shape (nparticles,3)?

PyTorch model file:
model.zip

commented

You do not have to worry about the forces if you let OpenMM-Torch compute them for you via backpropagation of the energies.
If you are using a TorchMD-Net model you can instruct it to not produce forces by passing derivative=False to load_model.

Having said that, you did not include the actual error message, so I cannot tell what is going on.
You should be able to write this:

class ForceModule(torch.nn.Module):
    def __init__(self, z):
        super(ForceModule, self).__init__()
        self.model = torch.jit.load('model.pt')
        self.z = z.cuda()

    def forward(self, positions):
        positions_nm = positions*0.1
        y, neg_dy = self.model(self.z, positions_nm)

        return y*4.184, neg_dy*4.184

Hi Raul, here is the complete code I'm using and the files needed to reproduce what's happening:

from openmm.app import *
from openmm import *
from openmm.unit import *
from sys import stdout
import numpy as np
import os
import torch

from openmmtorch import TorchForce

pdb = PDBFile('chignolin_CA_mod.pdb')
forcefield = ForceField('chignolin_CG.xml')

residues = list(pdb.topology.residues())
res_dict = dict((res, res.name) for res in residues)

system = forcefield.createSystem(pdb.topology, residueTemplates=res_dict)

from torchmdnet.models.model import load_model
model = load_model("epoch=35-val_loss=740.4694-test_loss=21.5844.ckpt")

torch.jit.script(model).save('model.pt')

class ForceModule(torch.nn.Module):
    def __init__(self, z):
        super(ForceModule, self).__init__()
        self.model = torch.jit.load('model.pt')
        self.z = z.cuda()

    def forward(self, positions):
        positions_nm = positions*0.1
        y, neg_dy = self.model(self.z, positions_nm)

        return y, neg_dy

z = [4, 4, 5, 8, 6, 13, 2, 13, 7, 4]
z = torch.tensor(z, dtype=torch.long)

module = torch.jit.script(ForceModule(z))
torch_force = TorchForce(module)
torch_force.setOutputsForces(True)
system.addForce(torch_force)

integrator = LangevinMiddleIntegrator(298.15*kelvin, 1/picosecond, 0.001*picoseconds)
platform = Platform.getPlatformByName('CUDA')
simulation = Simulation(pdb.topology, system, integrator, platform)
simulation.context.setPositions(pdb.positions)
simulation.minimizeEnergy()
simulation.reporters.append(DCDReporter('output_NNP.dcd', 100))
simulation.reporters.append(StateDataReporter(stdout, 100, step=True,
        potentialEnergy=True, temperature=True))
simulation.reporters.append(CheckpointReporter('checkpnt.chk', 5000))
simulation.step(300000)

If I do return y*4.184, neg_dy*4.184 as you suggested:

[image: screenshot of the error]

Files:
data.zip

What is the error? Your image shows where the error happened, but you cut off the error message.

Since your model returns forces in kcal/mol/A, you need to multiply by 41.84 to convert to kJ/mol/nm.

commented

The actual error is the following:

Traceback (most recent call last):
  File "/home/raul/Downloads/kk/script.py", line 39, in <module>
    module = torch.jit.script(module)
             ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/raul/miniforge3/envs/torchmdnet/lib/python3.11/site-packages/torch/jit/_script.py", line 1324, in script
    return torch.jit._recursive.create_script_module(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/raul/miniforge3/envs/torchmdnet/lib/python3.11/site-packages/torch/jit/_recursive.py", line 559, in create_script_module
    return create_script_module_impl(nn_module, concrete_type, stubs_fn)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/raul/miniforge3/envs/torchmdnet/lib/python3.11/site-packages/torch/jit/_recursive.py", line 636, in create_script_module_impl
    create_methods_and_properties_from_stubs(
  File "/home/raul/miniforge3/envs/torchmdnet/lib/python3.11/site-packages/torch/jit/_recursive.py", line 469, in create_methods_and_properties_from_stubs
    concrete_type._create_methods_and_properties(
RuntimeError: 
Arguments for call are not valid.
The following variants are available:
  
  aten::mul.Tensor(Tensor self, Tensor other) -> Tensor:
  Expected a value of type 'Tensor' for argument 'self' but instead found type 'Optional[Tensor]'.
  
  aten::mul.Scalar(Tensor self, Scalar other) -> Tensor:
  Expected a value of type 'Tensor' for argument 'self' but instead found type 'Optional[Tensor]'.
  
  aten::mul.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!):
  Expected a value of type 'Tensor' for argument 'self' but instead found type 'Optional[Tensor]'.
  
  aten::mul.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!):
  Expected a value of type 'Tensor' for argument 'self' but instead found type 'Optional[Tensor]'.
  
  aten::mul.left_t(t[] l, int n) -> t[]:
  Could not match type Optional[Tensor] to List[t] in argument 'l': Cannot match List[t] to Optional[Tensor].
  
  aten::mul.right_(int n, t[] l) -> t[]:
  Expected a value of type 'int' for argument 'n' but instead found type 'Optional[Tensor]'.
  
  aten::mul.int(int a, int b) -> int:
  Expected a value of type 'int' for argument 'a' but instead found type 'Optional[Tensor]'.
  
  aten::mul.complex(complex a, complex b) -> complex:
  Expected a value of type 'complex' for argument 'a' but instead found type 'Optional[Tensor]'.
  
  aten::mul.float(float a, float b) -> float:
  Expected a value of type 'float' for argument 'a' but instead found type 'Optional[Tensor]'.
  
  aten::mul.int_complex(int a, complex b) -> complex:
  Expected a value of type 'int' for argument 'a' but instead found type 'Optional[Tensor]'.
  
  aten::mul.complex_int(complex a, int b) -> complex:
  Expected a value of type 'complex' for argument 'a' but instead found type 'Optional[Tensor]'.
  
  aten::mul.float_complex(float a, complex b) -> complex:
  Expected a value of type 'float' for argument 'a' but instead found type 'Optional[Tensor]'.
  
  aten::mul.complex_float(complex a, float b) -> complex:
  Expected a value of type 'complex' for argument 'a' but instead found type 'Optional[Tensor]'.
  
  aten::mul.int_float(int a, float b) -> float:
  Expected a value of type 'int' for argument 'a' but instead found type 'Optional[Tensor]'.
  
  aten::mul.float_int(float a, int b) -> float:
  Expected a value of type 'float' for argument 'a' but instead found type 'Optional[Tensor]'.
  
  aten::mul(Scalar a, Scalar b) -> Scalar:
  Expected a value of type 'number' for argument 'a' but instead found type 'Optional[Tensor]'.
  
  mul(float a, Tensor b) -> Tensor:
  Expected a value of type 'float' for argument 'a' but instead found type 'Optional[Tensor]'.
  
  mul(int a, Tensor b) -> Tensor:
  Expected a value of type 'int' for argument 'a' but instead found type 'Optional[Tensor]'.
  
  mul(complex a, Tensor b) -> Tensor:
  Expected a value of type 'complex' for argument 'a' but instead found type 'Optional[Tensor]'.

The original call is:
  File "/home/raul/Downloads/kk/script.py", line 33
        positions_nm = positions.to(torch.float32)*0.1
        y, neg_dy = self.model(self.z, positions_nm)
        return y*4.814, neg_dy*4.184
                        ~~~~~~~~~~~~ <--- HERE

You stumbled onto something weird. As far as I can tell, this error looks like a bug in TorchScript. I am not sure whether it is triggered by something specific in your checkpoint. The checkpoint looks perfectly valid, but I do not see this error with the other checkpoints I have around.
The error itself makes no sense to me.

While we investigate, you can make your script work by letting openmm-torch compute the forces from the energies:

from openmm.app import *
from openmm import *
from openmm.unit import *
from sys import stdout
import numpy as np
import os
import torch

from openmmtorch import TorchForce

pdb = PDBFile('chignolin_CA_mod.pdb')
forcefield = ForceField('chignolin_CG.xml')

residues = list(pdb.topology.residues())
res_dict = dict((res, res.name) for res in residues)

system = forcefield.createSystem(pdb.topology, residueTemplates=res_dict)
from typing import Tuple
from torchmdnet.models.model import load_model
model = load_model("a.ckpt", derivative=False)

torch.jit.script(model).save('model.pt')

class ForceModule(torch.nn.Module):
    def __init__(self, z):
        super(ForceModule, self).__init__()
        self.model = torch.jit.load('model.pt')
        self.z = z.cpu()

    def forward(self, positions):
        positions_nm = positions.to(torch.float32)*0.1
        y = self.model(self.z, positions_nm)[0]
        return y*4.184

z = [4, 4, 5, 8, 6, 13, 2, 13, 7, 4]
z = torch.tensor(z, dtype=torch.long)

module = ForceModule(z)
module = torch.jit.script(module)
torch_force = TorchForce(module)
torch_force.setOutputsForces(False)
system.addForce(torch_force)

integrator = LangevinMiddleIntegrator(298.15*kelvin, 1/picosecond, 0.001*picoseconds)
platform = Platform.getPlatformByName('CPU')
simulation = Simulation(pdb.topology, system, integrator, platform)
simulation.context.setPositions(pdb.positions)
simulation.minimizeEnergy()

commented

You can also solve the error by convincing TorchScript that neg_dy really is a Tensor and not an Optional[Tensor], by asserting that it is not None:

    def forward(self, positions) -> Tuple[torch.Tensor, torch.Tensor]:
        positions_nm = positions.to(torch.float32)*0.1
        y, neg_dy = self.model(self.z, positions_nm)
        assert neg_dy is not None
        return y, 4.184*neg_dy

Still, I cannot explain why this is the first time we are seeing this.

EDIT: Frustratingly, adding this check internally inside the TorchMD_Net model is not enough.

> Since your model returns forces in kcal/mol/A, you need to multiply by 41.84 to convert to kJ/mol/nm.

Aren't the nm taken care of when the positions are converted to nm before y and neg_dy are computed from the model?

> While we investigate, you can make your script work by letting openmm-torch compute the forces from the energies:

Thanks, this works!

> Aren't the nm taken care of when the positions are converted to nm before y and neg_dy are computed from the model?

You need to convert both inputs and outputs. You also have the wrong conversion factor. OpenMM passes positions in nm. You need to multiply by 10 (not 0.1) to convert to A. 1 nm = 10 A. Then your model returns forces in kcal/mol/A. You need to multiply by 41.84 to convert them to kJ/mol/nm, which is what OpenMM expects.
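
As a rough illustration of the two conversions described above, here is a minimal plain-Python sketch. The helper names (positions_for_model, outputs_for_openmm) are hypothetical and not part of OpenMM-Torch or TorchMD-Net; inside a real ForceModule the same factors would simply be applied to the tensors.

```python
# Conversion factors between OpenMM's units (nm, kJ/mol) and the
# model's units (A, kcal/mol), as stated in the comment above.
NM_TO_ANGSTROM = 10.0   # 1 nm = 10 A
KCAL_TO_KJ = 4.184      # 1 kcal = 4.184 kJ
# A force is energy per length, so both factors apply:
# (kcal/mol/A) * 4.184 kJ/kcal * 10 A/nm = 41.84 kJ/mol/nm
FORCE_KCAL_A_TO_KJ_NM = KCAL_TO_KJ * NM_TO_ANGSTROM


def positions_for_model(positions_nm):
    """Hypothetical helper: convert OpenMM positions (nm) to model input (A)."""
    return [[c * NM_TO_ANGSTROM for c in atom] for atom in positions_nm]


def outputs_for_openmm(energy_kcal, forces_kcal_per_a):
    """Hypothetical helper: convert model outputs to kJ/mol and kJ/mol/nm."""
    energy_kj = energy_kcal * KCAL_TO_KJ
    forces_kj_per_nm = [
        [f * FORCE_KCAL_A_TO_KJ_NM for f in atom] for atom in forces_kcal_per_a
    ]
    return energy_kj, forces_kj_per_nm


# One particle at (0.1, 0.2, 0.3) nm; made-up model outputs for illustration.
positions_a = positions_for_model([[0.1, 0.2, 0.3]])
energy_kj, forces_kj_nm = outputs_for_openmm(1.0, [[1.0, 0.0, -2.0]])
```

The point of the sketch is only that the input is scaled by 10 and the force output by 41.84, not by 0.1 and 4.184.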

I see, thanks! I thought I had to convert the positions from A to nm in the input, which is why I was multiplying by 0.1.

When using a model that only returns energies, I only need to multiply the output (y) by 4.184 to get energies in kJ/mol, and then OpenMM will backpropagate the forces in kJ/mol/nm, is that right?

Right.
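
To see why converting only the energy is enough in that case, here is a small numerical check of the chain rule. It is a sketch under stated assumptions: the 1D harmonic energy_kcal function is a made-up stand-in for a model, and a central finite difference stands in for backpropagation.

```python
def energy_kcal(x_angstrom):
    """Toy 1D harmonic energy in kcal/mol, x in A (hypothetical stand-in for a model)."""
    k = 2.0    # force constant, kcal/mol/A^2
    x0 = 1.5   # equilibrium position, A
    return 0.5 * k * (x_angstrom - x0) ** 2


def force_kcal_per_a(x_angstrom):
    """Analytic force of the toy model in kcal/mol/A."""
    return -2.0 * (x_angstrom - 1.5)


def energy_kj(x_nm):
    """Energy as OpenMM would see it: input in nm, output in kJ/mol."""
    return 4.184 * energy_kcal(10.0 * x_nm)


def numerical_force(energy_fn, x, h=1e-6):
    """Central finite difference of -dE/dx, standing in for backpropagation."""
    return -(energy_fn(x + h) - energy_fn(x - h)) / (2.0 * h)


x_nm = 0.2
# Differentiating the converted energy with respect to nm ...
force_from_energy = numerical_force(energy_kj, x_nm)
# ... matches the model's own force scaled by 41.84.
force_from_model = 41.84 * force_kcal_per_a(10.0 * x_nm)
```

Because the derivative is taken with respect to the nm input, the factor of 10 from the position conversion enters automatically and the resulting force is already in kJ/mol/nm.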

commented

I reduced the code that triggers the bug down to this:

import torch
from torch import Tensor
from typing import Tuple, Optional

class BaseModel(torch.nn.Module):

    def forward(self) -> Tuple[Tensor, Optional[Tensor]]:
        y = torch.randn(10)
        return y, y

class ChildModel(torch.nn.Module):
    def __init__(self):
        super(ChildModel, self).__init__()
        self.model = BaseModel()

    def forward(self) -> Tuple[Tensor, Tensor]:
        y1, y2 = self.model()
#        assert y2 is not None
        return y1, 2*y2

module = torch.jit.script(ChildModel())

Uncommenting the line makes the code succeed.
This makes it a TorchScript bug/limitation.

@peastman, this is neither a TorchMD-Net nor an OpenMM-Torch bug. However, it is easy for a user such as the OP to run into it, because the TorchMD_Net model is typed as returning Tuple[Tensor, Optional[Tensor]]. As a consequence, one cannot jit.script code that operates on the second tensor returned by TMDNet.
The fix in user code is simple, but it is unfortunate because it requires the user to know about an obscure implementation detail.
On the TMDNet side, we could make the model return an empty tensor (torch.empty(0)) instead of None. That way we could type it as returning Tuple[Tensor, Tensor].
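
A minimal sketch of that idea, reusing the reduced example above. This is only an illustration of the proposal, not the change that was actually made in TorchMD-Net; the derivative flag on BaseModel is a hypothetical stand-in for the model's option to skip force computation.

```python
import torch
from torch import Tensor
from typing import Tuple


class BaseModel(torch.nn.Module):
    def __init__(self, derivative: bool = True):
        super().__init__()
        # Hypothetical flag: whether the model also returns forces.
        self.derivative = derivative

    def forward(self) -> Tuple[Tensor, Tensor]:
        y = torch.randn(10)
        if self.derivative:
            return y, y
        # Return an empty tensor instead of None, so the return type
        # does not need Optional[Tensor].
        return y, torch.empty(0)


class ChildModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.model = BaseModel()

    def forward(self) -> Tuple[Tensor, Tensor]:
        y1, y2 = self.model()
        # No assert needed: y2 is statically typed as Tensor.
        return y1, 2 * y2


module = torch.jit.script(ChildModel())
out1, out2 = module()
```

With this typing, a caller that needs to know whether forces were produced would have to check for an empty tensor (for example with numel() == 0) rather than for None.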

Given the current direction of PyTorch, I have zero confidence this will get fixed. So either we work around it, or we should put it in some kind of FAQ.

commented

I was able to fix this one in torchmd/torchmd-net#283.
It will be in the next release, which is coming soon.

commented

The release should be available in conda-forge in a few hours (0.15.2). Please feel free to reopen if you run into this again.