patrick-kidger / equinox

Elegant easy-to-use neural networks + scientific computing in JAX. https://docs.kidger.site/equinox/

Right way to count cache size

AakashKumarNain opened this issue

I am trying to keep count of how many times the __call__ method of a pytree is compiled, and of the size of the cache. For a jitted function in pure JAX, we can simply check the cache size with f._cache_size(). Is there an equivalent way to do this for eqx.filter_jit?
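For reference, the pure-JAX check mentioned above looks roughly like this. Note that `_cache_size()` is a private, undocumented method on the object returned by `jax.jit`, so it may change or disappear between JAX versions; the function `double` is only an illustrative example.

```python
import jax
import jax.numpy as jnp


@jax.jit
def double(x):
    return 2 * x


double(jnp.ones(3))       # compiles an entry for float32[3]
double(jnp.ones(3))       # same shape and dtype: reuses the cached executable
double(jnp.ones((2, 3)))  # new shape: compiles a second entry

# Private API: number of entries in this jitted function's cache.
print(double._cache_size())
```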

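Take a look at eqx.debug.assert_max_traces, which counts how many times a function is traced and raises an error once that count exceeds max_traces. You can probably use a similar implementation to keep your own count.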