Is the last layer in the distributional critic optimized (JAX implementation)?
Jogima-cyber opened this issue
Hi! I was wondering whether the last layer of the distributional critic (used in DMPO, for example) is optimized in the JAX implementation, since it is instantiated in `__call__` and not in `__init__`?
https://github.com/deepmind/acme/blob/98c4204b9be7f327035bf1dbef26aa820cc4c1ec/acme/jax/networks/distributional.py#L373-L414
It should be. To check, you can inspect the params returned by init and confirm that the layer's weights are among them. Creating an hk.Module inside `__call__` is a common pattern; you can find other instances of it in the Haiku examples.