google-deepmind / acme

A library of reinforcement learning components and agents

Is the last layer in the distributional critic optimized (JAX implementation)?

Jogima-cyber opened this issue

Hi! I was wondering whether the last layer of the distributional critic (used in DMPO, for example) is optimized in the JAX implementation, since it is instantiated in the __call__ method rather than in __init__.
https://github.com/deepmind/acme/blob/98c4204b9be7f327035bf1dbef26aa820cc4c1ec/acme/jax/networks/distributional.py#L373-L414

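It should be. To check, you can look at the initialized params: the last layer's weights should be listed there. Creating an hk.Module inside __call__ is a common pattern, because Haiku collects a module's parameters by name when the transformed function is initialized, regardless of where in the forward pass the module object is constructed. You can find other examples in the Haiku examples.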