pymc-devs / nutpie

Python wrapper for nuts-rs

Worker process panic during sampling

fonnesbeck opened this issue

Sampling with nutpie via the following:

import nutpie

# `model` is the PyMC model defined earlier in the script
compiled_model = nutpie.compile_pymc_model(model)
trace_pymc = nutpie.sample(
    compiled_model,
    draws=1000,
    tune=2000,
    chains=4,
)

Results in a failure, apparently during the adaptation phase?

thread '<unnamed>' panicked at 'assertion failed: val.is_finite()', /Users/cfonnesbeck/.cargo/registry/src/github.com-1ecc6299db9ec823/nuts-rs-0.3.0/src/adapt_strategy.rs:259:25
note: run with `RUST_BACKTRACE=1` environment variable to display a backtrace
thread '<unnamed>' panicked at 'assertion failed: val.is_finite()', /Users/cfonnesbeck/.cargo/registry/src/github.com-1ecc6299db9ec823/nuts-rs-0.3.0/src/adapt_strategy.rs:259:25
thread '<unnamed>' panicked at 'assertion failed: val.is_finite()', /Users/cfonnesbeck/.cargo/registry/src/github.com-1ecc6299db9ec823/nuts-rs-0.3.0/src/adapt_strategy.rs:259:25
thread '<unnamed>' panicked at 'assertion failed: val.is_finite()', /Users/cfonnesbeck/.cargo/registry/src/github.com-1ecc6299db9ec823/nuts-rs-0.3.0/src/adapt_strategy.rs:259:25
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
/Users/cfonnesbeck/Downloads/level_conversions_temp.py in <cell line: 5>()
    437         # trace = sample_numpyro_nuts(
    438         #     draws=1000,
    439         #     tune=9000,
   (...)
    442         #     chain_method='vectorized',
    443         #     postprocessing_backend='cpu')
    444     compiled_model = nutpie.compile_pymc_model(model)
--> 445     trace_pymc = nutpie.sample(compiled_model,
    446         draws=1000,
    447         tune=2000,
    448         chains=4)
    449 else:
    450     model = instantiate_model(data_model, use_gaussian=False)

File ~/GitHub/nutpie/nutpie/sample.py:143, in sample(compiled_model, draws, tune, chains, seed, num_try_init, save_warmup, store_divergences, progress_bar, init_mean, store_unconstrained, **kwargs)
    140     return draws_data, infos
    142 try:
--> 143     draws_data, infos = do_sample()
    144 finally:
    145     try:

File ~/GitHub/nutpie/nutpie/sample.py:120, in sample.<locals>.do_sample()
    118 num_divs = 0
    119 chains_tuning = chains
--> 120 for draw, info in bar:
    121     info_dict = info.as_dict()
    122     if store_unconstrained:

File ~/miniforge3/envs/pymc_env/lib/python3.10/site-packages/fastprogress/fastprogress.py:47, in ProgressBar.__iter__(self)
     45 except Exception as e:
     46     self.on_interrupt()
---> 47     raise e

File ~/miniforge3/envs/pymc_env/lib/python3.10/site-packages/fastprogress/fastprogress.py:41, in ProgressBar.__iter__(self)
     39 if self.total != 0: self.update(0)
     40 try:
---> 41     for i,o in enumerate(self.gen):
     42         if i >= self.total: break
     43         yield o

ValueError: Worker process paniced.

Running on macOS, Python 3.10.4, PyMC 4.2.2, and nutpie built from the current main branch.

I thought the progress bar might be the culprit, but disabling it (progress_bar=False) produces the same error.

Thanks for the report!
It looks like, for some reason, we end up with a NaN in the mass matrix. This PR at least makes sure that we don't panic in those cases: pymc-devs/nuts-rs#1
But I'd still like to know where that NaN is coming from. Can you share the model and data for debugging, or would that be difficult?
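The "don't panic" idea can be sketched as follows, in Python rather than Rust. This is a minimal illustration, not the nuts-rs change: the helper name `update_diag_mass_matrix` is hypothetical, and it assumes a diagonal mass matrix whose entries must be finite and strictly positive. Instead of asserting (as the `val.is_finite()` check in the log does), it keeps the previous diagonal when the new estimate is unusable.

```python
import numpy as np


def update_diag_mass_matrix(previous: np.ndarray, proposed: np.ndarray) -> np.ndarray:
    """Return the proposed diagonal if usable, otherwise keep the previous one.

    Hypothetical helper: replaces a hard finiteness assertion with a
    recoverable check so a bad adaptation window cannot abort a chain.
    """
    # A usable diagonal mass matrix needs finite, strictly positive entries.
    if np.all(np.isfinite(proposed)) and np.all(proposed > 0):
        return proposed
    # Fall back to the last good estimate instead of panicking.
    return previous


previous = np.ones(3)
proposed = np.array([0.5, np.nan, 2.0])
print(update_diag_mass_matrix(previous, proposed))  # keeps the previous diagonal
```

Falling back silently hides the symptom, which is why the underlying source of the NaN is still worth tracking down.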

I can share it with you privately. Will send a link on Slack. Thanks!

Additional context: the same model samples fine with sample_numpyro_nuts on the GPU.

The NaNs seem to have come from the gradient variance being zero in one adaptation window: every gradient for one of the variables was exactly 1. This should now be handled correctly after #13.
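The sketch below shows how a zero gradient variance in one window can yield a non-finite mass matrix entry. It assumes a per-variable diagonal estimate of the form sqrt(Var[draws] / Var[gradients]). The thread does not spell out the exact estimator, and this does not reproduce the fix in #13. The `min_var` floor is a hypothetical regularization parameter, shown only to illustrate one way of keeping the estimate finite.

```python
import numpy as np


def diag_estimate(draws: np.ndarray, grads: np.ndarray, min_var: float = 0.0) -> np.ndarray:
    """Per-variable mass matrix diagonal from one adaptation window.

    draws, grads: arrays of shape (n_draws, n_vars).
    Assumed estimator: sqrt(Var[draws] / Var[grads]).
    min_var: hypothetical floor on the gradient variance (0.0 = no floor).
    """
    var_draws = draws.var(axis=0)
    var_grads = np.maximum(grads.var(axis=0), min_var)
    # With var_grads == 0 the division gives inf (or nan when var_draws is
    # also 0); errstate only silences numpy's divide/invalid warnings.
    with np.errstate(divide="ignore", invalid="ignore"):
        return np.sqrt(var_draws / var_grads)


rng = np.random.default_rng(0)
draws = rng.normal(size=(100, 2))
grads = rng.normal(size=(100, 2))
grads[:, 1] = 1.0  # every gradient of the second variable is exactly 1

print(diag_estimate(draws, grads))                # second entry is not finite
print(diag_estimate(draws, grads, min_var=1e-8))  # the floor keeps it finite
```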