RobertTLange / evosax

Evolution Strategies in JAX 🦎

Top-k fitness in ESLog does not give the correct answer

newtonkwan opened this issue

Wonderful library! Looking forward to using it more.

Summary
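ESLog computes the top-k fitness values incorrectly: with maximize=True it repeats the single best fitness value instead of returning the k best values in descending order.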

Expectation

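    import jax.numpy as jnp
    from evosax.utils import ESLog

    es_logging = ESLog(num_dims=2, num_generations=10, top_k=3, maximize=True)
    log = es_logging.initialize()
    x = jnp.array([[1, 2], [2, 4], [4, 6], [6, 7]])
    fitness = jnp.array([1, 2, 3, 4])
    # Record this generation's candidates and their fitness in the log
    log = es_logging.update(log, x, fitness)
    # The three best fitness values should come back in descending order
    assert jnp.array_equal(log["top_fitness"], jnp.array([4, 3, 2]))
    # Pass: [4. 3. 2.]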

Reality

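    import jax.numpy as jnp
    from evosax.utils import ESLog

    es_logging = ESLog(num_dims=2, num_generations=10, top_k=3, maximize=True)
    log = es_logging.initialize()
    x = jnp.array([[1, 2], [2, 4], [4, 6], [6, 7]])
    fitness = jnp.array([1, 2, 3, 4])
    # Same inputs as above; only the outcome differs
    log = es_logging.update(log, x, fitness)
    assert jnp.array_equal(log["top_fitness"], jnp.array([4, 3, 2]))
    # AssertionError: [4. 4. 4.]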

Fix in PR #23

Line 58 of es_logger.py: a parenthesis around the second argsort() in the top_idx computation is misplaced by one position.
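The report does not quote line 58, so the following is only a hypothetical reconstruction of how a misplaced parenthesis around a second argsort() could yield [4. 4. 4.]. It assumes that top_idx picks descending or ascending order without branching, by scaling two argsort results with the 0/1 maximize flag and summing them, and that the archive starts with three sentinel entries of -1e10 before the new fitness values are appended. The function top_k_indices, its buggy switch, and the -1e10 sentinel are illustrative assumptions, not evosax code. NumPy with kind="stable" is used so that ties are broken deterministically.

```python
import numpy as np


def top_k_indices(vals, top_k, maximize, buggy=False):
    """Hypothetical branch-free top-k selection (not the actual evosax code).

    Exactly one of the two argsort terms is meant to contribute; the other is
    zeroed out by the 0/1 factor derived from `maximize`.
    """
    m = int(maximize)
    # Descending order when maximizing: argsort of the negated values.
    first = m * np.argsort(-vals, kind="stable")
    if buggy:
        # Misplaced parenthesis: the 0/1 factor scales the *values* before
        # sorting, so argsort sees an all-zero array and returns 0, 1, 2, ...
        # instead of a vector of zeros.
        second = np.argsort((1 - m) * vals, kind="stable")
    else:
        # Intended placement: the factor masks the sorted *indices*.
        second = (1 - m) * np.argsort(vals, kind="stable")
    return (first + second)[:top_k]


# Assumed archive state: three sentinel "worst" entries (-1e10 is illustrative)
# followed by the four fitness values from the report.
vals = np.array([-1e10, -1e10, -1e10, 1.0, 2.0, 3.0, 4.0])

fixed = top_k_indices(vals, top_k=3, maximize=True)
buggy = top_k_indices(vals, top_k=3, maximize=True, buggy=True)
print(vals[fixed])  # [4. 3. 2.]
print(vals[buggy])  # [4. 4. 4.]
```

Under these assumptions the correct first term is [6, 5, 4, 3, 0, 1, 2]. In the buggy variant the second term adds [0, 1, 2, 3, 4, 5, 6] instead of zeros, so the first three summed indices all become 6, which is the position of the fitness value 4.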

Good catch! Thank you so much for the kind words and the fix.