FAILED_PRECONDITION: TPU platform already registered for platform hardware version
gobbleturk opened this issue
With various minor or trivial modifications to train.py, I quickly hit the error
FAILED_PRECONDITION: TPU platform already registered for platform hardware version
This happens either on the first random.split call if checkpointing is disabled, or inside the checkpointer if it is enabled.
It looks like calling jax.devices() very early on (in the imports) solves this issue.
Minimal Repro: See https://github.com/google/maxtext/tree/import-fun
The difference between an error and working code is a single print("hello") statement.
Ideally #21 solves this issue by calling jax.device_count() at the start, which helps initialize the environment. We will keep an eye on this.
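A rough sketch of the same idea placed at the program's entry point rather than among the imports, assuming a train.py-style main(). The helper initialize_environment is hypothetical and is not the code from #21; it just calls jax.device_count() before any other JAX work.

```python
import jax


def initialize_environment() -> int:
    # Hypothetical helper: calling jax.device_count() first forces JAX to
    # initialize its backend before any random or checkpointing code runs.
    return jax.device_count()


def main() -> int:
    n = initialize_environment()
    print(f"JAX sees {n} device(s) on backend '{jax.default_backend()}'")
    key = jax.random.PRNGKey(0)
    key, subkey = jax.random.split(key)  # runs after initialization
    return n


if __name__ == "__main__":
    main()
```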