Issue with `float64` precision
bheijden opened this issue
Hi,
Great work on the toolkit!
I ran into a problem while using the CMA-ES algorithm. One parameter had a standard deviation of approximately `1e-6`, which corresponds to a variance of `1e-12`. This is beyond the precision that `float32` offers for the covariance matrix, so the matrix could not be represented accurately, and as a result the algorithm generated samples with zero variance. Rescaling the covariance matrix would have been one workaround, but as a first step I implemented the algorithm with `float64` precision instead, using:
```python
# this only works on startup!
from jax import config
config.update("jax_enable_x64", True)
```
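A minimal NumPy sketch of the precision problem, assuming the other entries of the covariance matrix are of order 1 (the issue does not state their scale). `float32` can store `1e-12` on its own, but with a machine epsilon of about `1.2e-7` it cannot resolve a `1e-12` contribution next to a unit-scale value, whereas `float64` (epsilon about `2.2e-16`) can. The helper `resolves_small_variance` is hypothetical and only for illustration.

```python
import numpy as np


def resolves_small_variance(dtype, variance=1e-12, scale=1.0):
    """Hypothetical helper: does `scale + variance` differ from `scale` in `dtype`?"""
    big = dtype(scale)
    return bool(big + dtype(variance) != big)


for dtype in (np.float32, np.float64):
    print(
        f"{dtype.__name__}: eps={np.finfo(dtype).eps:.2e}, "
        f"stores 1e-12 alone: {dtype(1e-12) > 0}, "
        f"resolves 1e-12 next to 1.0: {resolves_small_variance(dtype)}"
    )
```

The last column is `False` for `float32` and `True` for `float64`, which is the extra resolution that enabling 64-bit mode gives JAX arrays.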
However, this led to a dtype mismatch between `float32` and `float64`, because one value in the strategy state is set with a `float32` dtype, as the following traceback shows:
```
  File "/home/r2ci/rex/sysid/evo.py", line 118, in evo_step
    new_state = solver.strategy.tell(x, loss_nonan, state, solver.strategy_params)
  File "/home/r2ci/evosax/evosax/strategy.py", line 140, in tell
    best_member, best_fitness = get_best_fitness_member(
  File "/home/r2ci/evosax/evosax/utils/helpers.py", line 22, in get_best_fitness_member
    best_fitness = jax.lax.select(
TypeError: lax.select requires arguments to have the same dtypes, got float32, float64. (Tip: jnp.where is a similar function that does automatic type promotion on inputs).
```
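A minimal sketch reproducing the mismatch outside evosax, assuming the old default behaves like the NumPy `float32` scalar that `jnp.finfo(jnp.float32).max` returns. With x64 enabled, a `float64` fitness and a `float32` incumbent cannot be combined by `jax.lax.select`, while `jnp.where` promotes both to `float64`, as the error tip suggests. The variable names are illustrative only.

```python
import jax

# Enable 64-bit mode first, before any arrays are created.
jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp

# Mimics the old EvoState default: a NumPy float32 scalar keeps its dtype.
best_fitness = jnp.asarray(jnp.finfo(jnp.float32).max)
# A fitness value computed in float64, as happens once x64 is enabled.
fitness = jnp.asarray(1.0, dtype=jnp.float64)
replace = fitness < best_fitness

try:
    jax.lax.select(replace, fitness, best_fitness)
    select_failed = False
except TypeError as err:
    # lax.select does no type promotion, so float64 vs float32 is rejected.
    select_failed = True
    print("lax.select:", err)

# jnp.where applies type promotion, so float32 and float64 combine to float64.
promoted = jnp.where(replace, fitness, best_fitness)
print(promoted.dtype)
```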
Changing the default value of `best_fitness` in `EvoState`, as shown below, resolves the problem. Other strategies may be affected as well.
```python
from typing import Optional

import chex
import jax.numpy as jnp
from flax import struct


@struct.dataclass
class EvoState:
    p_sigma: chex.Array
    p_c: chex.Array
    C: chex.Array
    D: Optional[chex.Array]
    B: Optional[chex.Array]
    mean: chex.Array
    sigma: float
    weights: chex.Array
    weights_truncated: chex.Array
    best_member: chex.Array
    # Convert to a Python float and back to a jax.Array, so the dtype follows
    # JAX's default float type (float64 when jax_enable_x64 is enabled).
    best_fitness: float = jnp.array(float(jnp.finfo(jnp.float32).max))
    # best_fitness: float = jnp.finfo(jnp.float32).max  # old
    gen_counter: int = 0
```
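A minimal sketch of why this change works, assuming x64 is enabled before the default is evaluated. `float(...)` turns the NumPy `float32` scalar into a Python float, and `jnp.array` then uses JAX's default float dtype, which is `float64` in 64-bit mode and `float32` otherwise. Because a dataclass default is evaluated once, when the class body runs (i.e. when the module is imported), x64 has to be enabled before evosax is imported. The names `old_default` and `new_default` are illustrative only.

```python
import jax

# Enable 64-bit mode before any defaults are evaluated.
jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp

F32_MAX = jnp.finfo(jnp.float32).max  # a NumPy float32 scalar

old_default = jnp.asarray(F32_MAX)       # keeps float32 even in x64 mode
new_default = jnp.array(float(F32_MAX))  # Python float -> default float dtype

# A float64 fitness now matches the incumbent's dtype, so lax.select succeeds.
fitness = jnp.asarray(1.0, dtype=jnp.float64)
best = jax.lax.select(fitness < new_default, fitness, new_default)

print(old_default.dtype, new_default.dtype, best.dtype)
```

Without x64 the same expression yields a `float32` default, so the change adapts to whichever precision is active rather than hard-coding one.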
MWE:
```python
import jax
import jax.numpy as jnp

# this only works on startup!
from jax import config
config.update("jax_enable_x64", True)

from evosax import CMA_ES
from evosax.problems import BBOBFitness

# Instantiate the evolution strategy instance
strategy = CMA_ES(num_dims=2, popsize=10)

# Get default hyperparameters (e.g. lrate, etc.)
es_params = strategy.default_params
es_params = es_params.replace(init_min=-3, init_max=3)

# Initialize the strategy
rng = jax.random.PRNGKey(0)
state = strategy.initialize(rng, es_params)

# Instantiate helper class for classic evolution strategies benchmarks
evaluator = BBOBFitness("RosenbrockOriginal", num_dims=2)

# Ask for a set of candidate solutions to evaluate
x, state = strategy.ask(rng, state, es_params)

# Evaluate the population members
fitness = evaluator.rollout(rng, x)

# Update the evolution strategy (with x64 enabled, this step raises the dtype mismatch)
state = strategy.tell(x, fitness, state, es_params)
print(state)
```
Closing because duplicate of #45.
On second thought, I'm reopening this issue, because the fix I propose could be a permanent solution that requires no extra work from the user.
Feel free to close if you think it's not worth the effort!