Potential bug saving parametric ROM class.
eparish1 opened this issue · comments
Describe the bug
I am trying to save a fitted "InterpolatedContinuousOpInfROM" to file, but there seems to be an issue with the save function. When I try to save, I get an error saying that a LinearOperator object has no parameters attribute.
To Reproduce
import numpy as np
import opinf
N = 10
K = 5
n_snaps = 8
Phi = np.random.normal(size=(N,K))
states = np.random.normal(size=(N,n_snaps))
states_dot = np.random.normal(size=(N,n_snaps))
params = np.array([1])
rom = opinf.InterpolatedContinuousOpInfROM("A", InterpolatorClass='auto')
rom.fit(Phi,parameters=params, states=[states], ddts=[states_dot])
rom.save('my_model.h5',overwrite=True)
Expected behavior
From the documentation, it looks like the model should save and then be loadable.
Output
Traceback (most recent call last):
  File "test_save.py", line 13, in <module>
    rom.save('my_model.h5',overwrite=True)
  File "/Users/ejparis/.local/lib/python3.9/site-packages/opinf/roms/interpolate/base.py", line 382, in save
    hf.create_dataset(f"operators/{key}", data=op.matrices)
  File "/Users/ejparis/.local/lib/python3.9/site-packages/opinf/roms/interpolate/_base.py", line 381, in save
    hf.create_dataset("parameters", data=op.parameters)
AttributeError: 'LinearOperator' object has no attribute 'parameters'
Additional context
I have no issues in the non-parametric context.
Thanks for sharing @eparish1, I'll take a look at this.
Error occurs here:
# src/opinf/roms/interpolate/_base.py
class _InterpolatedOpInfROM(_BaseParametricROM):
    # ...
    def save(self, savefile, save_basis=True, overwrite=False):
        # ...
        with hdf5_savehandle(savefile, overwrite=overwrite) as hf:
            # ...
            for key, op in zip(self.modelform, self):
                if "parameters" not in hf:
>                   hf.create_dataset("parameters", data=op.parameters)
The for loop has a single iteration with key = "A" and op = self.A_. Here op is assumed to be an InterpolatedLinearOperator (which has a parameters attribute), but in this case it's a (non-parametric) LinearOperator because there is only one parameter value (i.e., no interpolation).
# src/opinf/roms/interpolate/_base.py
class _InterpolatedOpInfROM(_BaseParametricROM):
    # ...
    def _interpolate_roms(self, parameters, roms):
        # ...
        for key in self.modelform:
            attr = f"{key}_"
            ops = [getattr(rom, attr).entries for rom in roms]
>           if all(np.all(ops[0] == op) for op in ops):
>               # This operator does not depend on the parameters.
>               OperatorClass = operators.nonparametric_operators[key]
>               setattr(self, attr, OperatorClass(ops[0]))
            else:
                # This operator varies with the parameters (so interpolate).
                OperatorClass = operators.interpolated_operators[key]
                setattr(self, attr, OperatorClass(parameters, ops,
                                                  self.InterpolatorClass))
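To see why a single parameter value always takes the first branch, here is a small NumPy-only sketch of the same equality test. The helper name is_parameter_independent is made up for illustration and is not part of opinf.

```python
import numpy as np


def is_parameter_independent(ops):
    """Same test as in _interpolate_roms(): does every entry equal the first?

    Hypothetical helper name, used here only for illustration.
    """
    return all(np.all(ops[0] == op) for op in ops)


rng = np.random.default_rng(0)
A1 = rng.normal(size=(5, 5))
A2 = rng.normal(size=(5, 5))

# One parameter value: the list holds one operator, which equals itself.
single = is_parameter_independent([A1])

# Two parameter values with different operator entries.
varying = is_parameter_independent([A1, A2])

print(single, varying)  # True False
```

With only one training parameter the check cannot fail, so the operator is stored as a nonparametric one and has no parameters attribute.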
The fact that A_ is a LinearOperator is fine, but the save() method needs to be updated to handle nonparametric operators. The load() method may also have the same issue.
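As a rough sketch of one way save() could cope with both operator kinds (not necessarily the change that ends up in opinf): take the parameter values from the ROM rather than from an operator, and choose the data to write based on whether the operator has a parameters attribute. The two classes below are simplified stand-ins, collect_for_save and its rom_parameters argument are hypothetical names, and a plain dict stands in for the HDF5 file.

```python
class LinearOperator:
    """Hypothetical stand-in for a nonparametric operator (no `parameters`)."""

    def __init__(self, entries):
        self.entries = entries


class InterpolatedLinearOperator:
    """Hypothetical stand-in for an interpolated operator."""

    def __init__(self, parameters, matrices):
        self.parameters = parameters
        self.matrices = matrices


def collect_for_save(modelform, ops, rom_parameters):
    """Map dataset names to the data a save() routine would write.

    Assumption: the ROM keeps its own copy of the training parameters
    (`rom_parameters`), so no operator is asked for them.
    """
    datasets = {"parameters": rom_parameters}
    for key, op in zip(modelform, ops):
        if hasattr(op, "parameters"):
            # Interpolated operator: one matrix per training parameter.
            datasets[f"operators/{key}"] = op.matrices
        else:
            # Nonparametric operator: a single set of entries.
            datasets[f"operators/{key}"] = op.entries
    return datasets


# The case from this issue: modelform "A" with a nonparametric operator.
datasets = collect_for_save("A", [LinearOperator([[1.0]])], [1])
print(sorted(datasets))
```

A matching load() would need the same distinction, for example by recording which operator kind was written for each key.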
(Note to self: the operator-centric refactor should prevent issues like this)
Nice, thanks Shane!