NoisyNets implementation issues
pseudo-rnd-thoughts opened this issue
I'm implementing my own RL framework in JAX to better understand RL algorithms, and I found your code very helpful.
Looking at the NoisyNets implementation on lines 316 and 317 of https://github.com/google/dopamine/blob/master/dopamine/jax/networks.py, the same rng_key is used every time noise is generated. This means no 'new' noise is generated each time an input is passed to the layer, so in effect I think the layer just applies a linear transform.
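A minimal sketch of the underlying JAX behaviour (plain jax.random, not Dopamine's code): sampling with the same key always returns identical values, so a layer that holds one fixed key sees the same noise on every call, whereas splitting the key first yields a fresh draw. The shapes are arbitrary.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)

# Reusing one key: both draws are bit-for-bit identical, so a layer that
# stores a single key gets the same "noise" on every forward pass.
eps_a = jax.random.normal(key, (10, 3))
eps_b = jax.random.normal(key, (10, 3))

# Splitting first: the subkey gives an independent draw.
key, subkey = jax.random.split(key)
eps_c = jax.random.normal(subkey, (10, 3))

print(jnp.array_equal(eps_a, eps_b))  # True
print(jnp.array_equal(eps_a, eps_c))  # False
```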
Here is a short test example:
```python
import jax
import numpy as np
from dopamine.jax.networks import NoisyNetwork

if __name__ == '__main__':
    rng = jax.random.PRNGKey(1)
    rng, rng_net_def, rng_net_param = jax.random.split(rng, num=3)
    net_def = NoisyNetwork(rng_key=rng_net_def, eval_mode=False)
    net_params = net_def.init(rng_net_param, x=np.zeros(10), features=3)
    state = np.random.random(10)
    # Both calls reuse the rng_key stored on the module, so the outputs match.
    print(net_def.apply(net_params, x=state, features=3))
    print(net_def.apply(net_params, x=state, features=3))
```
If this is an issue, here is the implementation I wrote for my framework:
```python
from typing import Optional, Sequence

import jax
import jax.numpy as jnp
import numpy as onp
from flax import linen as nn


class NoisyDense(nn.Module):
    features: int
    use_bias: bool = True

    @staticmethod
    @jax.jit
    def _f(x: jnp.ndarray) -> jnp.ndarray:
        # See (10) and (11) in Fortunato et al. (2018).
        return jnp.multiply(jnp.sign(x), jnp.power(jnp.abs(x), 0.5))

    @nn.compact
    def __call__(self, inputs: onp.ndarray, eval_mode: bool = True,
                 rng: Optional[jax.Array] = None) -> jnp.ndarray:
        if eval_mode:  # Turn off noise during evaluation.
            w_epsilon = jnp.zeros(shape=(inputs.shape[0], self.features), dtype=onp.float32)
            b_epsilon = jnp.zeros(shape=(self.features,), dtype=onp.float32)
        else:  # Factored Gaussian noise in (10) and (11) in Fortunato et al. (2018).
            p_key, q_key = jax.random.split(rng)
            p = jax.random.normal(p_key, [inputs.shape[0], 1])
            q = jax.random.normal(q_key, [1, self.features])
            f_p, f_q = self._f(p), self._f(q)
            # Outer product f(p) f(q)^T, shape (in_features, features).
            w_epsilon, b_epsilon = f_p * f_q, jnp.squeeze(f_q)

        def _mu_init(key: jax.Array, shape: Sequence[int]):
            # Initialization of mean noise parameters (Section 3.2).
            mean = 1 / jnp.power(inputs.shape[0], 0.5)
            return jax.random.uniform(key, minval=-mean, maxval=mean, shape=shape)

        def _sigma_init(_key: jax.Array, shape: Sequence[int], dtype=jnp.float32):
            # Initialization of sigma noise parameters (Section 3.2).
            return jnp.ones(shape, dtype) * (0.1 / onp.sqrt(inputs.shape[0]))

        # See (8) and (9) in Fortunato et al. (2018) for output computation.
        w_mu = self.param('kernel_mu', _mu_init, (inputs.shape[0], self.features))
        w_sigma = self.param('kernel_sigma', _sigma_init, (inputs.shape[0], self.features))
        out = jnp.matmul(inputs, w_mu + jnp.multiply(w_sigma, w_epsilon))
        if self.use_bias:
            b_mu = self.param('bias_mu', _mu_init, (self.features,))
            b_sigma = self.param('bias_sigma', _sigma_init, (self.features,))
            out = out + b_mu + jnp.multiply(b_sigma, b_epsilon)
        return out
```
Here is similar test code:
```python
if __name__ == '__main__':
    rng = jax.random.PRNGKey(1)
    rng, rng_net_def, rng_net_param = jax.random.split(rng, num=3)
    net_def = NoisyDense(features=2)
    net_params = net_def.init(rng_net_param, onp.zeros(10))
    state = onp.random.random(10)
    # Noise off (eval_mode defaults to True).
    print(net_def.apply(net_params, inputs=state))
    # Noise on: a different key gives a different output.
    print(net_def.apply(net_params, inputs=state, eval_mode=False, rng=rng_net_def))
    print(net_def.apply(net_params, inputs=state, eval_mode=False, rng=rng))
```
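For reference, a standalone sketch of the factorised Gaussian noise from (10) and (11) in Fortunato et al. (2018), as used in NoisyDense: the weight noise is the outer product f(p) f(q)ᵀ, so it varies with both the input and the output index, and the bias noise is f(q). The helper name factorised_noise is hypothetical. Note that the product has to combine f_p with f_q; multiplying f_p by itself would give noise of shape (in, 1) that is identical for every output unit.

```python
import jax
import jax.numpy as jnp


def f(x: jax.Array) -> jax.Array:
    # f(x) = sgn(x) * sqrt(|x|), see (10) and (11) in Fortunato et al. (2018).
    return jnp.sign(x) * jnp.sqrt(jnp.abs(x))


def factorised_noise(key: jax.Array, in_features: int, out_features: int):
    """Hypothetical helper returning (w_epsilon, b_epsilon)."""
    # Independent keys for p and q, so the two noise vectors are uncorrelated.
    p_key, q_key = jax.random.split(key)
    f_p = f(jax.random.normal(p_key, (in_features, 1)))
    f_q = f(jax.random.normal(q_key, (1, out_features)))
    w_epsilon = f_p * f_q            # broadcasts to (in_features, out_features)
    b_epsilon = jnp.squeeze(f_q, 0)  # shape (out_features,)
    return w_epsilon, b_epsilon


w_eps, b_eps = factorised_noise(jax.random.PRNGKey(0), in_features=10, out_features=3)
print(w_eps.shape, b_eps.shape)  # (10, 3) (3,)
```

Only in + out Gaussian samples are drawn per call instead of in × out, which is the motivation the paper gives for factorising.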
I would have submitted this as a pull request, but I noticed that the repository is not accepting pull requests.
I also realized that this might be an issue. If we want to resample noise, we should either explicitly pass in a new rng every time or use self.make_rng to ensure that RNGs are split correctly.
Flax linen module attributes cannot be updated after construction, so the only way to get "new" random noise is to pass the PRNG key in as an argument, as I did in my example code.
Edit: I misunderstood the original comment; it was pointing out the correlated noise on lines 316 and 317. It's unclear how much this affects performance, but we will fix it. Thanks for pointing it out! This should fix it:
```python
# Split the key so that p and q are sampled from independent keys.
rng_p, rng_q = jax.random.split(self.rng_key, num=2)
p = NoisyNetwork.sample_noise(rng_p, [x.shape[0], 1])
q = NoisyNetwork.sample_noise(rng_q, [1, features])
```
I am not sure this is a bug: as @young-geng mentioned, if we want to resample noise, we need to pass an explicit rng every time, as done in the FullRainbowNetwork here. That said, this does seem to be a documentation issue about how we expect NoisyNets to be used. @psc-g for further visibility.
Here's a simplified example to verify that explicitly passing an rng works:
```python
import time

import jax
import jax.numpy as jnp
from flax import linen as nn
from dopamine.jax.networks import NoisyNetwork


class DummyNetwork(nn.Module):
    """Dummy network for testing NoisyNets."""

    @nn.compact
    def __call__(self, x, eval_mode=False, key=None):
        if key is None:
            # Fall back to a time-based key when no key is supplied.
            key = jax.random.PRNGKey(int(time.time() * 1e6))
        return NoisyNetwork(rng_key=key, eval_mode=eval_mode)(x, features=2)


def create_noisy_net_and_eval(num_runs=5):
    network_def = DummyNetwork()
    x = jnp.ones(5)
    rng = jax.random.PRNGKey(0)
    rng1, rng = jax.random.split(rng, 2)
    params = network_def.init(rng1, x=x)
    for i in range(num_runs):
        # Split off a fresh key for every forward pass.
        rng1, rng = jax.random.split(rng)
        print(f'rng{i}', network_def.apply(params, x=x, key=rng1))
```

```
>>> create_noisy_net_and_eval()
rng0 [ 0.49825954 -0.5264382 ]
rng1 [ 0.3296632 -0.56998575]
rng2 [ 0.5706229 -0.42372862]
rng3 [ 0.5419281 -0.47531918]
rng4 [ 0.52439386 -0.46529555]
```
Hi, thanks for raising this! I agree with what Rishabh pointed out. Once the rngs used for p and q are uncorrelated, I believe it is working as expected (e.g. a new rng is not passed in every time).
@agarwl Thanks, I hadn't spotted that the FullRainbowNetwork implementation passes a new rng key to the noisy network each time, so you are correct. With the modification you propose, the noisy network works as expected.
However, since eval_mode and rng_key are attributes of the network, this is potentially misleading: they are really values that need to be passed to the call function every time. Conversely, features, use_bias and kernel_init should not be modified after initialization.
That is why, in my implementation, I moved these variables from the constructor to the call, and vice versa.
@psc-g I may be wrong, but I think a new rng should be passed every time (when eval_mode = False): if new noise is not sampled each time, the layer just applies a fixed linear transformation to its inputs.
In my view, that defeats the point of the noisy network heuristic, which is to both increase the stability/resilience of the network and increase the "observations" seen by the network.
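To make this argument concrete, here is a pure-JAX sketch (hypothetical parameter values and helper name, not Dopamine's code) of a noisy linear layer with factorised noise. With one fixed key, the noisy weights collapse into a single effective matrix, so the layer behaves as an ordinary linear map. Only fresh keys make the output vary, and because the noise has zero mean, averaging over many draws approaches x @ w_mu.

```python
import jax
import jax.numpy as jnp


def noisy_linear(w_mu, w_sigma, x, key):
    """Hypothetical factorised-noise linear layer (bias omitted)."""
    p_key, q_key = jax.random.split(key)
    f = lambda v: jnp.sign(v) * jnp.sqrt(jnp.abs(v))
    f_p = f(jax.random.normal(p_key, (x.shape[-1], 1)))
    f_q = f(jax.random.normal(q_key, (1, w_mu.shape[1])))
    return x @ (w_mu + w_sigma * (f_p * f_q))


n_in, n_out = 5, 2
w_mu = jax.random.uniform(jax.random.PRNGKey(0), (n_in, n_out), minval=-0.5, maxval=0.5)
w_sigma = jnp.full((n_in, n_out), 0.1 / jnp.sqrt(n_in))
x = jnp.ones(n_in)

# Fixed key: every call returns the same output, i.e. a fixed linear map.
fixed_key = jax.random.PRNGKey(1)
y_fixed_1 = noisy_linear(w_mu, w_sigma, x, fixed_key)
y_fixed_2 = noisy_linear(w_mu, w_sigma, x, fixed_key)

# Fresh keys: outputs vary from draw to draw, and their mean approaches x @ w_mu.
keys = jax.random.split(jax.random.PRNGKey(2), 10_000)
ys = jax.vmap(lambda k: noisy_linear(w_mu, w_sigma, x, k))(keys)
print(jnp.abs(ys.mean(axis=0) - x @ w_mu).max())
```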
Now I see that it passes in a new RNG key every time, so I believe I was wrong about the noise not being resampled, and the implementation should be correct. Sorry for the confusion.