Different results when running PPO with the same seed multiple times?
Chulabhaya opened this issue
Hey all! I'm trying to track down an apparent reproducibility issue I'm having with the PPO implementation after adding some simple WandB logging. I ran the same code 7 times; 4 of the runs produced identical results, but the other 3 differ:
Would anyone have any ideas as to why this might be happening?
Here's my slightly modified code:
from typing import NamedTuple, Sequence
import distrax
import flax.linen as nn
import gymnax
import jax
import jax.numpy as jnp
import numpy as np
import optax
from flax.linen.initializers import constant, orthogonal
from flax.training.train_state import TrainState
from purejaxrl.wrappers import FlattenObservationWrapper, LogWrapper
import wandb
from jax import config
#config.update("jax_disable_jit", True)
class ActorCritic(nn.Module):
action_dim: Sequence[int]
activation: str = "tanh"
@nn.compact
def __call__(self, x):
if self.activation == "relu":
activation = nn.relu
else:
activation = nn.tanh
actor_mean = nn.Dense(
64, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0)
)(x)
actor_mean = activation(actor_mean)
actor_mean = nn.Dense(
64, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0)
)(actor_mean)
actor_mean = activation(actor_mean)
actor_mean = nn.Dense(
self.action_dim, kernel_init=orthogonal(0.01), bias_init=constant(0.0)
)(actor_mean)
pi = distrax.Categorical(logits=actor_mean)
critic = nn.Dense(
64, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0)
)(x)
critic = activation(critic)
critic = nn.Dense(
64, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0)
)(critic)
critic = activation(critic)
critic = nn.Dense(1, kernel_init=orthogonal(1.0), bias_init=constant(0.0))(
critic
)
return pi, jnp.squeeze(critic, axis=-1)
class Transition(NamedTuple):
done: jnp.ndarray
action: jnp.ndarray
value: jnp.ndarray
reward: jnp.ndarray
log_prob: jnp.ndarray
obs: jnp.ndarray
info: jnp.ndarray
def make_train(config):
config["NUM_UPDATES"] = (
config["TOTAL_TIMESTEPS"] // config["NUM_STEPS"] // config["NUM_ENVS"]
)
config["MINIBATCH_SIZE"] = (
config["NUM_ENVS"] * config["NUM_STEPS"] // config["NUM_MINIBATCHES"]
)
env, env_params = gymnax.make(config["ENV_NAME"])
env = FlattenObservationWrapper(env)
env = LogWrapper(env)
def linear_schedule(count):
frac = (
1.0
- (count // (config["NUM_MINIBATCHES"] * config["UPDATE_EPOCHS"]))
/ config["NUM_UPDATES"]
)
return config["LR"] * frac
def train(rng):
# INIT NETWORK
network = ActorCritic(
env.action_space(env_params).n, activation=config["ACTIVATION"]
)
rng, _rng = jax.random.split(rng)
init_x = jnp.zeros(env.observation_space(env_params).shape)
network_params = network.init(_rng, init_x)
if config["ANNEAL_LR"]:
tx = optax.chain(
optax.clip_by_global_norm(config["MAX_GRAD_NORM"]),
optax.adam(learning_rate=linear_schedule, eps=1e-5),
)
else:
tx = optax.chain(
optax.clip_by_global_norm(config["MAX_GRAD_NORM"]),
optax.adam(config["LR"], eps=1e-5),
)
train_state = TrainState.create(
apply_fn=network.apply,
params=network_params,
tx=tx,
)
# INIT ENV
rng, _rng = jax.random.split(rng)
reset_rng = jax.random.split(_rng, config["NUM_ENVS"])
obsv, env_state = jax.vmap(env.reset, in_axes=(0, None))(reset_rng, env_params)
# TRAIN LOOP
def _update_step(runner_state, unused):
# COLLECT TRAJECTORIES
def _env_step(runner_state, unused):
train_state, time_state, env_state, last_obs, rng = runner_state
# SELECT ACTION
rng, _rng = jax.random.split(rng)
pi, value = network.apply(train_state.params, last_obs)
action = pi.sample(seed=_rng)
log_prob = pi.log_prob(action)
# STEP ENV
rng, _rng = jax.random.split(rng)
rng_step = jax.random.split(_rng, config["NUM_ENVS"])
obsv, env_state, reward, done, info = jax.vmap(
env.step, in_axes=(0, 0, 0, None)
)(rng_step, env_state, action, env_params)
transition = Transition(
done, action, value, reward, log_prob, last_obs, info
)
runner_state = (train_state, time_state, env_state, obsv, rng)
return runner_state, transition
runner_state, traj_batch = jax.lax.scan(
_env_step, runner_state, None, config["NUM_STEPS"]
)
# CALCULATE ADVANTAGE
train_state, time_state, env_state, last_obs, rng = runner_state
_, last_val = network.apply(train_state.params, last_obs)
def _calculate_gae(traj_batch, last_val):
def _get_advantages(gae_and_next_value, transition):
gae, next_value = gae_and_next_value
done, value, reward = (
transition.done,
transition.value,
transition.reward,
)
delta = reward + config["GAMMA"] * next_value * (1 - done) - value
gae = (
delta
+ config["GAMMA"] * config["GAE_LAMBDA"] * (1 - done) * gae
)
return (gae, value), gae
_, advantages = jax.lax.scan(
_get_advantages,
(jnp.zeros_like(last_val), last_val),
traj_batch,
reverse=True,
unroll=16,
)
return advantages, advantages + traj_batch.value
advantages, targets = _calculate_gae(traj_batch, last_val)
# UPDATE NETWORK
def _update_epoch(update_state, unused):
def _update_minbatch(train_state, batch_info):
traj_batch, advantages, targets = batch_info
def _loss_fn(params, traj_batch, gae, targets):
# RERUN NETWORK
pi, value = network.apply(params, traj_batch.obs)
log_prob = pi.log_prob(traj_batch.action)
# CALCULATE VALUE LOSS
value_pred_clipped = traj_batch.value + (
value - traj_batch.value
).clip(-config["CLIP_EPS"], config["CLIP_EPS"])
value_losses = jnp.square(value - targets)
value_losses_clipped = jnp.square(value_pred_clipped - targets)
value_loss = (
0.5 * jnp.maximum(value_losses, value_losses_clipped).mean()
)
# CALCULATE ACTOR LOSS
ratio = jnp.exp(log_prob - traj_batch.log_prob)
gae = (gae - gae.mean()) / (gae.std() + 1e-8)
loss_actor1 = ratio * gae
loss_actor2 = (
jnp.clip(
ratio,
1.0 - config["CLIP_EPS"],
1.0 + config["CLIP_EPS"],
)
* gae
)
loss_actor = -jnp.minimum(loss_actor1, loss_actor2)
loss_actor = loss_actor.mean()
entropy = pi.entropy().mean()
total_loss = (
loss_actor
+ config["VF_COEF"] * value_loss
- config["ENT_COEF"] * entropy
)
return total_loss, (value_loss, loss_actor, entropy)
grad_fn = jax.value_and_grad(_loss_fn, has_aux=True)
total_loss, grads = grad_fn(
train_state.params, traj_batch, advantages, targets
)
train_state = train_state.apply_gradients(grads=grads)
return train_state, total_loss
train_state, traj_batch, advantages, targets, rng = update_state
rng, _rng = jax.random.split(rng)
# Batching and Shuffling
batch_size = config["MINIBATCH_SIZE"] * config["NUM_MINIBATCHES"]
assert (
batch_size == config["NUM_STEPS"] * config["NUM_ENVS"]
), "batch size must be equal to number of steps * number of envs"
permutation = jax.random.permutation(_rng, batch_size)
batch = (traj_batch, advantages, targets)
batch = jax.tree_util.tree_map(
lambda x: x.reshape((batch_size,) + x.shape[2:]), batch
)
shuffled_batch = jax.tree_util.tree_map(
lambda x: jnp.take(x, permutation, axis=0), batch
)
# Mini-batch Updates
minibatches = jax.tree_util.tree_map(
lambda x: jnp.reshape(
x, [config["NUM_MINIBATCHES"], -1] + list(x.shape[1:])
),
shuffled_batch,
)
train_state, total_loss = jax.lax.scan(
_update_minbatch, train_state, minibatches
)
update_state = (train_state, traj_batch, advantages, targets, rng)
return update_state, total_loss
# Updating Training State and Metrics:
update_state = (train_state, traj_batch, advantages, targets, rng)
update_state, loss_info = jax.lax.scan(
_update_epoch, update_state, None, config["UPDATE_EPOCHS"]
)
train_state = update_state[0]
metric = traj_batch.info
rng = update_state[-1]
time_state["timesteps"] = time_state["timesteps"] + 512
metric["time"] = time_state["timesteps"]
metric["loss"] = loss_info
# Debugging mode
if config.get("DEBUG"):
def callback(info):
return_values = info["returned_episode_returns"][
info["returned_episode"]
]
timesteps = (
info["timestep"][info["returned_episode"]] * config["NUM_ENVS"]
)
for t in range(len(timesteps)):
print(
f"global step={timesteps[t]}, episodic return={return_values[t]}"
)
data_log = {
"losses/value_loss": info["loss"][1][0].mean().item(),
"losses/policy_loss": info["loss"][1][1].mean().item(),
"losses/entropy": info["loss"][1][2].mean().item(),
"losses/total_loss": info["loss"][0].mean().item(),
}
if return_values.size > 0:
data_log["misc/episodic_return"] = return_values.mean().item()
wandb.log(data_log, step=info["time"])
jax.debug.callback(callback, metric)
runner_state = (train_state, time_state, env_state, last_obs, rng)
return runner_state, metric
rng, _rng = jax.random.split(rng)
time_state = {"timesteps": jnp.array(0)}
runner_state = (train_state, time_state, env_state, obsv, _rng)
runner_state, metric = jax.lax.scan(
_update_step, runner_state, None, config["NUM_UPDATES"]
)
return {"runner_state": runner_state, "metrics": metric}
return train
if __name__ == "__main__":
config = {
"LR": 2.5e-4,
"NUM_ENVS": 4,
"NUM_STEPS": 128,
"TOTAL_TIMESTEPS": 100000,
"UPDATE_EPOCHS": 4,
"NUM_MINIBATCHES": 4,
"GAMMA": 0.99,
"GAE_LAMBDA": 0.95,
"CLIP_EPS": 0.2,
"ENT_COEF": 0.01,
"VF_COEF": 0.5,
"MAX_GRAD_NORM": 0.5,
"ACTIVATION": "tanh",
"ENV_NAME": "CartPole-v1",
"ANNEAL_LR": True,
"DEBUG": True,
}
run_name = f"test"
wandb_id = wandb.util.generate_id()
run_id = f"{run_name}_{wandb_id}"
wandb.init(
id=run_id,
project="test5",
name=run_name,
mode="online",
)
rng = jax.random.PRNGKey(30)
train_jit = jax.jit(make_train(config))
out = train_jit(rng)
Interesting! I can take a closer look next week. But first, can you check if the weights are significantly different? I want to rule out asynchronous wandb callbacks.
Hi @luchris429! I appreciate your prompt reply, and thanks again for all your awesome work with this repo. To answer your question, I looked at the weights of the first dense layer for another set of runs, and they are in fact different at the point I checked (~50k timesteps into CartPole-v1). The differences aren't drastic, but for the runs whose returns match, the weights are bit-for-bit identical, so it's clearly the diverging weights that lead to the diverging results.
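For reference, this is roughly how I compared them. It's a minimal sketch; out_a and out_b are placeholder names for the outputs of train_jit(rng) from two separate runs, and runner_state[0] is the TrainState holding the weights:

# Sketch: compare the final network parameters of two runs, leaf by leaf.
# out_a / out_b are hypothetical outputs of train_jit(rng) from two separate runs.
import jax
import jax.numpy as jnp

params_a = out_a["runner_state"][0].params
params_b = out_b["runner_state"][0].params

# True only if every weight matrix and bias is bit-for-bit identical.
identical = jax.tree_util.tree_all(
    jax.tree_util.tree_map(lambda a, b: jnp.array_equal(a, b), params_a, params_b)
)

# Largest absolute difference across all leaves, to see how far the runs drift apart.
max_diff = jax.tree_util.tree_reduce(
    jnp.maximum,
    jax.tree_util.tree_map(lambda a, b: jnp.abs(a - b).max(), params_a, params_b),
)
print(identical, max_diff)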
What is interesting, however, is that in both these plots and the ones I posted above, all runs are identical up to a certain number of timesteps, after which some diverge while others remain the same.
Here's a plot of the runs; you can see that two have identical returns while the third differs:
Now here are snapshots of part of the weights:
Weight snapshot for two runs that are the same:
Mystery solved! It turns out the issue was the non-determinism that exists by default when training on a GPU. I verified this by training on the CPU only for 10 runs, and the results are identical:
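In case anyone wants to reproduce the CPU check, this is roughly how the CPU runs can be forced (just one way to do it, set before any computation touches the GPU):

# Sketch: force JAX onto the CPU backend for the determinism check.
import jax
jax.config.update("jax_platform_name", "cpu")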
I then did GPU training with the following line, which sets this environment variable: os.environ["XLA_FLAGS"] = "--xla_gpu_deterministic_ops=true"
(ref: google/jax#10674)
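For anyone else hitting this: the flag has to be in place before JAX initializes its backend, so I set it at the very top of the script, before importing jax. A minimal sketch:

# Sketch: enable deterministic GPU ops; set the env var before importing jax
# so XLA picks it up when the backend is initialized.
import os
os.environ["XLA_FLAGS"] = "--xla_gpu_deterministic_ops=true"

import jax  # imported only after the flag is in place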
10 runs of this are also completely identical:
If I remove this flag, the GPU training becomes non-deterministic pretty quickly: