Don't modify torch default dtype
janosh opened this issue
This line sets PyTorch's global default dtype, which prevents running other models after MACE for relaxation in the same Python session: MACE recommends float64 for geometry optimization, while e.g. CHGNet and M3GNet use float32.
Line 145 in 88d49f9
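The failure mode can be reproduced with PyTorch alone. In this minimal sketch a plain `torch.nn.Linear` stands in for a float32 model such as CHGNet or M3GNet, and a direct `torch.set_default_dtype` call stands in for the global change MACE makes. Both are illustrative assumptions; neither library's code is used.

```python
import torch

# Stand-in for a float32 model (e.g. CHGNet/M3GNet), built while the default is float32.
torch.set_default_dtype(torch.float32)
model = torch.nn.Linear(3, 1)

# Simulate another library changing the process-wide default dtype.
torch.set_default_dtype(torch.float64)
x = torch.tensor([[0.0, 0.5, 1.0]])  # Python floats now become float64 tensors

try:
    model(x)
    error = None
except RuntimeError as exc:
    # A dtype-mismatch error that says nothing about who changed the default.
    error = exc

torch.set_default_dtype(torch.float32)  # restore for the rest of the session
print(x.dtype, type(error).__name__)
```

The mismatch surfaces deep inside the float32 model, far from the code that changed the default, which is why it is hard to trace back.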
The resulting error messages are not helpful, so users who hit this will likely spend time troubleshooting it. The only current workaround is to manually reset the default dtype to float32 with
`torch.set_default_dtype(torch.float32)`
after every MACE call.
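Until the root cause is fixed, the manual reset can be made harder to forget by wrapping MACE calls in a context manager that saves and restores the default dtype. This is a sketch; `preserve_default_dtype` is a hypothetical helper, not part of MACE or PyTorch.

```python
import contextlib
from collections.abc import Iterator

import torch


@contextlib.contextmanager
def preserve_default_dtype() -> Iterator[None]:
    """Hypothetical helper: restore torch's global default dtype on exit, even if the body raises."""
    saved = torch.get_default_dtype()
    try:
        yield
    finally:
        torch.set_default_dtype(saved)


# Simulate a call that changes the default dtype as a side effect.
torch.set_default_dtype(torch.float32)
with preserve_default_dtype():
    torch.set_default_dtype(torch.float64)
    inside = torch.get_default_dtype()

print(inside, torch.get_default_dtype())
```

In practice the `with` block would enclose both `mace_mp(...)` construction and each `get_potential_energy()` call, mirroring the "reset after every call" workaround. Whether MACE also relies on the default dtype during evaluation is not stated here, so keeping both inside the block is the conservative choice.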
Suggested fix
Only convert the model inputs to the model's dtype, instead of changing the global default dtype for every float tensor created anywhere in the session.
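A minimal sketch of the suggested fix, assuming the model's inputs arrive as a dict of tensors: read the dtype from the model's own parameters and cast only the floating-point inputs. The function name and the `positions`/`atomic_numbers` keys are hypothetical and do not reflect MACE's actual input format.

```python
import torch


def cast_inputs_to_model_dtype(
    model: torch.nn.Module, inputs: dict[str, torch.Tensor]
) -> dict[str, torch.Tensor]:
    """Cast floating-point inputs to the model's parameter dtype; leave integer tensors alone."""
    model_dtype = next(model.parameters()).dtype
    return {
        key: value.to(model_dtype) if value.is_floating_point() else value
        for key, value in inputs.items()
    }


before = torch.get_default_dtype()

# A float64 model created without touching the global default.
model = torch.nn.Linear(3, 1).to(torch.float64)
inputs = {
    "positions": torch.rand(4, 3, dtype=torch.float32),  # hypothetical key
    "atomic_numbers": torch.tensor([29, 29, 29, 29]),     # hypothetical key; stays int64
}
cast = cast_inputs_to_model_dtype(model, inputs)
energy = model(cast["positions"]).sum()

print(cast["positions"].dtype, cast["atomic_numbers"].dtype, energy.dtype)
```

Because the cast is local to the model call, a float32 model loaded later in the same session sees an unchanged default dtype.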
Minimal example
```python
import torch
from ase.build import bulk
from mace.calculators import mace_mp

orig_dtype = torch.get_default_dtype()
print(f"{orig_dtype=}")
# orig_dtype=torch.float32

atoms = bulk("Cu") * (2, 2, 2)
atoms.calc = mace_mp(default_dtype="float64")
atoms.get_potential_energy()

new_dtype = torch.get_default_dtype()
print(f"{new_dtype=}")
# new_dtype=torch.float64
```