kingoflolz / mesh-transformer-jax

Model parallel transformers in JAX and Haiku

training stuck at validation step 1

Selimonder opened this issue · comments

commented

Hello,

First of all, thank you for the great finetune guide. I followed the guide all the way through and attempted to fine-tune GPT-J on a small dataset of 30-40 MB.

However, I am stuck at the device_train.py step (step 12).

Compilation of the train step, eval step, and network passes, and the first weights are written to the bucket.

It seems the code freezes at the `out = network.eval(inputs)` line inside the `eval_step` function. During compilation, data passes through `eval_step` fine, but when the actual training starts it freezes.

Did anyone stumble upon a similar issue?
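For anyone puzzled by "it works during compilation but hangs during training": this is a minimal JAX sketch (not code from this repo) of why the two phases can behave differently. On the first call, `jax.jit` traces the function with abstract values and compiles it; only later calls actually run the compiled computation on the device, which is where a broken TPU runtime can hang.

```python
import jax
import jax.numpy as jnp

@jax.jit
def eval_step(x):
    # Tracing/compilation happens on the first call with this shape;
    # the hang described in this issue occurs when the compiled
    # computation executes on the TPU, not during tracing itself.
    return jnp.mean(x * x)

x = jnp.ones((4, 8))
out = eval_step(x)          # first call: trace + compile + execute
out = eval_step(x)          # later calls: execute only
print(float(out))
```

So a freeze at `network.eval(inputs)` after successful compilation points at the device runtime rather than the model code itself.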

Hi! I've trained GPT-J models in the past, but for some reason I'm now seeing this too. Did you manage to solve it?

Hey @versae, @Selimonder! Did you manage to resolve this issue somehow? I'm facing the same problem right now.

@rinapch for me the key was to select the alpha software version when creating the TPU. Stable releases seem to break the implementation; I'm not sure why.
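For reference, one way to request the alpha software version at TPU creation time is the `--version` flag of `gcloud`. This is only a sketch: the TPU name, zone, and accelerator type below are placeholders, so substitute your own.

```shell
# Placeholder name/zone/accelerator -- adjust to your project.
# --version=v2-alpha selects the alpha software version mentioned above;
# stable versions reportedly cause the hang at eval_step.
gcloud compute tpus tpu-vm create my-gptj-tpu \
  --zone=us-east1-d \
  --accelerator-type=v3-8 \
  --version=v2-alpha
```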

Thanks a lot, @versae! It really did help 🥳