Right way to count cache size
AakashKumarNain opened this issue
Aakash Kumar Nain commented
I am trying to keep a count of how many times the __call__ method of a pytree is compiled, and of how large the cache is. For a jitted function in pure JAX, we can simply check the cache size with f._cache_size(). Is there an equivalent way to do this with eqx.filter_jit?
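For reference, the pure-JAX check from the question looks roughly like the sketch below. It is a minimal example, and _cache_size() is a private method of the jitted function, so its name and behaviour may change between JAX versions. Each new input signature (shape, dtype, and so on) adds one cache entry, while repeated calls with the same signature reuse the existing one.

```python
import jax
import jax.numpy as jnp


@jax.jit
def f(x):
    return x * 2.0


f(jnp.ones(3))  # first call with shape (3,): trace, compile, cache
f(jnp.ones(3))  # same signature: reuses the cached executable
f(jnp.ones(4))  # new shape (4,): a second cache entry

# Private JAX API (note the leading underscore); may change between versions.
print(f._cache_size())
```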
Patrick Kidger commented
Take a look at eqx.debug.assert_max_traces. You can probably use a similar implementation.
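A rough sketch of that idea in plain JAX, assuming the counting works through Python side effects that only run at trace time (this is an illustration, not Equinox's actual code). The hypothetical count_traces decorator increments a counter each time the wrapped function is traced. Cached calls skip the Python body entirely, so the count only grows on a retrace, which is what triggers a new compilation under jit.

```python
import functools

import jax
import jax.numpy as jnp


def count_traces(fn):
    """Hypothetical helper: count how many times `fn` is traced.

    The Python body of a jitted function only runs while JAX traces it;
    calls that hit the cache never execute it, so `num_traces` only
    increments when a new trace (and hence a new compilation) happens.
    """

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        wrapper.num_traces += 1  # executes at trace time only
        return fn(*args, **kwargs)

    wrapper.num_traces = 0
    return wrapper


def forward(x):
    return jnp.sin(x) * 2.0


counted = count_traces(forward)
jitted = jax.jit(counted)

jitted(jnp.ones(3))  # traced: shape (3,) seen for the first time
jitted(jnp.ones(3))  # cache hit: Python body not run
jitted(jnp.ones(4))  # traced again: new shape (4,)
print(counted.num_traces)
```

The same wrapper should also work under eqx.filter_jit, for example eqx.filter_jit(count_traces(lambda model, x: model(x))) to count traces of a Module's __call__, since filter_jit likewise only runs the wrapped Python function when it traces.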
Aakash Kumar Nain commented
Cool. Thanks @patrick-kidger