AI-Hypercomputer / maxtext

A simple, performant and scalable Jax LLM!


FAILED_PRECONDITION: TPU platform already registered for platform hardware version

gobbleturk opened this issue

After making various minor, trivial modifications to train.py, I quickly hit the error:

FAILED_PRECONDITION: TPU platform already registered for platform hardware version

This happens either at the first random.split call if checkpointing is disabled, or inside the checkpointer if it is enabled.

Calling jax.devices() very early (among the imports) appears to avoid the error.
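A minimal sketch of that workaround, assuming the fix is simply to force JAX backend initialization before any other module gets a chance to touch the TPU runtime. The helper name `init_devices` and the `DEVICES` global are hypothetical; `jax.devices()` and `jax.random.split` are standard JAX APIs. This only reproduces the import ordering reported as working here; it does not explain the root cause.

```python
# Hypothetical top of train.py: initialize the JAX backend before any other
# project import runs, so the platform is set up once, up front.
import jax


def init_devices():
    """Force backend initialization and return the visible devices (hypothetical helper)."""
    devices = jax.devices()  # the first call initializes the default backend
    print(f"Initialized {len(devices)} device(s) on platform '{devices[0].platform}'")
    return devices


DEVICES = init_devices()

# Only after initialization: the remaining project imports and the first RNG use.
key = jax.random.PRNGKey(0)
key, subkey = jax.random.split(key)  # the issue reports this as a failure point
```

The same sketch runs on CPU or GPU; on a TPU VM, `devices[0].platform` would be `tpu`.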

Minimal Repro: See https://github.com/google/maxtext/tree/import-fun
The difference between an error and working code is a single print("hello") statement.

Ideally, #21 solves this issue by calling jax.device_count at startup, which helps initialize the environment. We will keep an eye on this.