ucl-bug / jwave

A JAX-based research framework for differentiable and parallelizable acoustic simulations, on CPU, GPUs and TPUs

simulate_wave_propagation solver error when using an accuracy other than the default of 8

whajjali opened this issue

Describe the bug
When testing the time-domain simulate_wave_propagation solver with FiniteDifferences parameters and/or initial conditions, the solver appears to convert them to accuracy=8 regardless of the accuracy specified when the FiniteDifferences objects are constructed. In particular, the smoothing step applied to the initial condition converts it to accuracy=8 even if a different accuracy was used to construct it. If smooth_initial=False is set instead, the following error is raised after the first time step:

"TypeError: Scanned function carry input and carry output must have the same pytree structure, but they differ: the input carry component fields[0] is a <class 'jaxdf.discretization.FiniteDifferences'> with pytree metadata ('params', 'domain'), ('accuracy',), (2,) but the corresponding component of the carry output is a <class 'jaxdf.discretization.FiniteDifferences'> with pytree metadata ('params', 'domain'), ('accuracy',), (8,), so the pytree node metadata does not match"

This indicates that mass_conservation_rhs and momentum_conservation_rhs return FiniteDifferences objects with accuracy=8 rather than the specified accuracy. On further investigation, I found that the replace_params method always falls back to the default accuracy of 8.
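The error is raised by jax.lax.scan, which requires the carry returned by the step function to have the same pytree structure, including static metadata such as accuracy, as the carry passed in. The sketch below reproduces the mismatch and one possible fix outside of jwave. It uses a hypothetical ToyField class, not jaxdf's actual FiniteDifferences, and assumes (based on the error message) that accuracy is stored as static pytree metadata and that replace_params rebuilds the object without forwarding it; domain is omitted for brevity.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class ToyField:
    """Hypothetical stand-in for jaxdf's FiniteDifferences (not its real code)."""

    def __init__(self, params, accuracy=8):
        self.params = params
        self.accuracy = accuracy

    def tree_flatten(self):
        # `params` is a traced leaf; `accuracy` is static aux data, so it is
        # part of the pytree structure that lax.scan compares.
        return (self.params,), (self.accuracy,)

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        (accuracy,) = aux_data
        return cls(children[0], accuracy=accuracy)

    def replace_params_default(self, new_params):
        # Mimics the reported behaviour: accuracy silently falls back to 8.
        return self.__class__(new_params)

    def replace_params(self, new_params):
        # Forwards the static metadata, so the pytree structure is preserved.
        return self.__class__(new_params, accuracy=self.accuracy)


def simulate(field, preserve_accuracy, steps=3):
    """Toy time loop: each step halves the field, carried through lax.scan."""

    def step(carry, _):
        new_params = carry.params * 0.5
        if preserve_accuracy:
            return carry.replace_params(new_params), None
        return carry.replace_params_default(new_params), None

    final, _ = jax.lax.scan(step, field, xs=None, length=steps)
    return final


field = ToyField(jnp.ones(4), accuracy=2)

try:
    simulate(field, preserve_accuracy=False)
except TypeError as err:
    # The carry goes in with aux data (2,) and comes out with (8,).
    print("scan rejected the carry:", type(err).__name__)

fixed = simulate(field, preserve_accuracy=True)
print(fixed.accuracy)  # 2
```

Forwarding every static attribute when rebuilding the object keeps the carry structure stable across time steps; whether the library's own fix should take exactly this form is an assumption.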

To Reproduce
See the attached PDF:
homogeneous_medium_FD_test.pdf