ucl-bug / jwave

A JAX-based research framework for differentiable and parallelizable acoustic simulations, on CPU, GPUs and TPUs

PSTD time-varying solver not called even when a FourierSeries sound speed is provided

whajjali opened this issue

Hi! I've been doing some tests with the wave equation solvers in time_varying.py and noticed that the simulate_wave_propagation function implementing the PSTD method is never called. In particular, here are lines 473 to 489 of time_varying.py:

```python
@operator(init_params=fourier_wave_prop_params)
def simulate_wave_propagation(
    medium: Union[MediumAllScalars, MediumOnGrid],
    time_axis: TimeAxis,
    *,
    sources=None,
    sensors=None,
    u0=None,
    p0=None,
    checkpoint: bool = True,
    max_unroll_checkpoint: int = 10,
    smooth_initial=True,
    params=None,
):
    r"""Simulates the wave propagation operator using the PSTD method. This
    implementation is equivalent to the `kspaceFirstOrderND` function in the
    k-Wave Toolbox.
```

The medium types used for dispatching are Union[MediumAllScalars, MediumOnGrid] rather than Union[MediumAllScalars, MediumFourierSeries], and I was not sure whether this was intended or a typo. In addition, even after changing the signature to Union[MediumAllScalars, MediumFourierSeries], the PSTD function is still not called, due to a dispatching bug in plum (possibly related to the issue raised here) that occurs when a signature (Union[MediumAllScalars, MediumFourierSeries]) is more specialized than one registered earlier (MediumOnGrid).
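To make the expected behavior concrete, here is a minimal sketch of specificity-based dispatch in plain Python. It does not use plum; OnGrid, FourierSeries and FiniteDifferences are hypothetical, simplified stand-ins for the jaxdf discretizations (assumed here to subclass OnGrid), float stands in for the all-scalars case, and Dispatcher and simulate are hypothetical names. A correct dispatcher should pick the most specific matching signature regardless of registration order; the behavior reported above looks like the first registered match winning instead.

```python
from typing import Callable, Dict, Tuple, Type

# Hypothetical stand-ins for the jaxdf discretizations (assumed hierarchy).
class OnGrid: ...
class FourierSeries(OnGrid): ...
class FiniteDifferences(OnGrid): ...


class Dispatcher:
    """Single-argument dispatcher that keeps methods in registration order."""

    def __init__(self) -> None:
        self.methods: Dict[Tuple[Type, ...], Callable] = {}

    def register(self, *types: Type) -> Callable:
        # `types` plays the role of a Union: matching any of them is enough.
        def decorator(fn: Callable) -> Callable:
            self.methods[types] = fn
            return fn
        return decorator

    def first_match(self, x):
        # Buggy policy: the first registered signature that matches wins.
        for types, fn in self.methods.items():
            if isinstance(x, types):
                return fn(x)
        raise TypeError(f"no method for {type(x).__name__}")

    def most_specific(self, x):
        # Correct policy: among matching signatures, pick the one whose
        # matching type appears earliest in the argument's MRO.
        mro = type(x).__mro__

        def rank(types: Tuple[Type, ...]) -> int:
            return min(mro.index(t) for t in types if isinstance(x, t))

        matches = [(rank(t), fn) for t, fn in self.methods.items() if isinstance(x, t)]
        if not matches:
            raise TypeError(f"no method for {type(x).__name__}")
        return min(matches, key=lambda m: m[0])[1](x)


simulate = Dispatcher()

@simulate.register(OnGrid)  # generic method, registered first
def _generic(field):
    return "generic OnGrid solver"

@simulate.register(float, FourierSeries)  # more specialized, registered later
def _pstd(field):
    return "PSTD solver"


print(simulate.first_match(FourierSeries()))        # generic OnGrid solver
print(simulate.most_specific(FourierSeries()))      # PSTD solver
print(simulate.most_specific(FiniteDifferences()))  # generic OnGrid solver
```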

Steps to reproduce the behavior
Run the differentdiscretizations.ipynb notebook from the documentation.

Desktop:

  • OS: Ubuntu
  • Version: 22.04

Additional context
This bug most likely does not affect the results, since the subfunctions momentum_conservation_rhs and mass_conservation_rhs are dispatched correctly for FiniteDifferences vs. FourierSeries sound speeds (ssp).
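Why the results are unaffected can be illustrated with Python's functools.singledispatch: even when the outer solver is the generic one, the inner right-hand-side helpers still dispatch on the concrete field type. The class hierarchy, momentum_rhs and generic_solver below are hypothetical simplifications, not jwave's actual momentum_conservation_rhs or solver.

```python
from functools import singledispatch

# Hypothetical stand-ins for the jaxdf discretizations (assumed hierarchy).
class OnGrid: ...
class FourierSeries(OnGrid): ...
class FiniteDifferences(OnGrid): ...


@singledispatch
def momentum_rhs(field):
    # Fallback for types with no registered implementation.
    raise NotImplementedError(f"no rhs for {type(field).__name__}")

@momentum_rhs.register
def _(field: FourierSeries):
    return "spectral gradient"

@momentum_rhs.register
def _(field: FiniteDifferences):
    return "finite-difference gradient"


def generic_solver(field: OnGrid) -> str:
    # The outer solver is not specialized, but each inner call
    # re-dispatches on the concrete type of `field`.
    return momentum_rhs(field)


print(generic_solver(FourierSeries()))      # spectral gradient
print(generic_solver(FiniteDifferences()))  # finite-difference gradient
```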

Thanks again for another great issue and, again, sorry for the delay in getting back to you.

The dispatch on Union[MediumAllScalars, MediumOnGrid] was actually intentional. In theory, the solver should not work with OnGrid values, because differential operators can't be defined for that kind of field, so the idea was to fall back to the FourierSeries methods in this case. The alternative would be to raise an error when OnGrid types are used. This is of course an arbitrary choice; do you see any problem with it?
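The two policies discussed here, routing any OnGrid medium to the FourierSeries-based PSTD path versus rejecting plain OnGrid fields, can be sketched as follows. The class names, solve_with_fallback, solve_strict and the exact-type check are hypothetical illustrations; jwave's real solver signature is the one quoted in the issue.

```python
# Hypothetical stand-ins for the jaxdf discretizations (assumed hierarchy).
class OnGrid: ...
class FourierSeries(OnGrid): ...


def solve_with_fallback(field: OnGrid) -> str:
    # Chosen policy: any OnGrid field, plain OnGrid included, is routed
    # to the FourierSeries-based PSTD path.
    if not isinstance(field, OnGrid):
        raise TypeError(f"expected an OnGrid field, got {type(field).__name__}")
    return "PSTD (FourierSeries methods)"


def solve_strict(field: OnGrid) -> str:
    # Alternative policy: a plain OnGrid field has no discretization, so no
    # differential operator is defined for it; reject it up front.
    if type(field) is OnGrid:
        raise TypeError("plain OnGrid fields do not define differential operators")
    return solve_with_fallback(field)


print(solve_with_fallback(OnGrid()))  # PSTD (FourierSeries methods)
print(solve_strict(FourierSeries()))  # PSTD (FourierSeries methods)
```

The fallback keeps the API permissive at the cost of silently reinterpreting plain OnGrid data; the strict variant surfaces the problem to the user immediately.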

In any case, the bug persists, and I'm wondering why it was not picked up by the tests. I will look into it!

This should now be fixed in the #221 branch. Hopefully I will merge it into main soon :)

Hello! Thank you for the fixes above. I just merged the #221 branch into my version, and it now dispatches correctly for the FourierSeries vs. FiniteDifferences cases. Just curious: why was Medium[FiniteDifferences] not used explicitly for type checking?

Regarding your question: the reason we don't explicitly dispatch on FiniteDifferences is that jwave currently does not have any specialized FiniteDifferences methods for wave propagation. In fact, most operators fall back on their OnGrid implementation for FiniteDifferences, while there are more specific ones for FourierSeries.

So at the moment, there's no need to explicitly dispatch to FiniteDifferences. As long as we are using a non-specialized implementation, it is easier for users and maintainers to modify or update the code as needed. If at some point somebody contributes a better version of the FiniteDifferences methods (which will be super welcome!!), then it will totally make sense to be more specific in the dispatch methods.
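The fallback described above can be sketched with functools.singledispatch, which resolves a call through the argument's class hierarchy. This is only an analogy for how jwave behaves with plum; the classes and the gradient function are hypothetical stand-ins. FiniteDifferences first falls back on the OnGrid method, and a contributed specialization later only requires one extra registration.

```python
from functools import singledispatch

# Hypothetical stand-ins for the jaxdf discretizations (assumed hierarchy).
class OnGrid: ...
class FourierSeries(OnGrid): ...
class FiniteDifferences(OnGrid): ...


@singledispatch
def gradient(field):
    raise NotImplementedError(type(field).__name__)

@gradient.register
def _(field: OnGrid):
    return "generic OnGrid gradient"

@gradient.register
def _(field: FourierSeries):
    return "spectral gradient"

# No FiniteDifferences method yet: the call falls back to the OnGrid one.
print(gradient(FiniteDifferences()))  # generic OnGrid gradient

# A contributed specialization only needs one more registration;
# existing callers pick it up without any signature changes.
@gradient.register
def _(field: FiniteDifferences):
    return "finite-difference gradient"

print(gradient(FiniteDifferences()))  # finite-difference gradient
```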

Does that make sense?

In any case, the fix should now be in the main branch and in the latest jwave release, so hopefully you won't need to keep a separate version anymore. I am closing this for now, but feel free to add comments if you have any other questions, or to reopen it if something doesn't work.