google / maxtext

A simple, performant and scalable Jax LLM!

Reproducing pure computation TFLOPs

prrathi opened this issue · comments

Hi, I'm trying to measure attained TFLOP/s when training on v5-tpu-8. I believe this is done in MaxText/train.py, from lines 491 to 534, and with the current main branch as is I get numbers around 200-300 TFLOP/s. But the current calculation sets last_step_completion = new_time on line 505, so the extra overhead of the calls to record_scalar_metrics and write_metrics is included in the step time (assuming no checkpointing). To remove this overhead, I moved line 505 to below line 531 and changed it to last_step_completion = datetime.datetime.now(), to isolate just the forward and backward pass. I also made the following changes to MaxText/configs/base.yml:

enable_checkpointing: False
async_checkpointing: False
dataset_name: ''
eval_dataset_name: ''
dataset_type: synthetic
max_corpus_chars: 10_000
steps: 10
enable_dropout: False
enable_data_shuffling: False

Then when I run train.py with this new base.yml, I get very small step times and very high throughput numbers, around 0.001 seconds and 100,000-200,000 TFLOP/s/device.

Is there something wrong with the small changes I made that's resulting in these measurements? Is there some other suggested way to remove any overhead and only capture the computation TFLOP/s?

Note: I also tried increasing steps to 10000 but observed similar FLOP/s numbers throughout all the steps.

Interesting question! JAX dispatches computations asynchronously by default, so if nothing forces the output values to be materialized on the host, the Python program keeps running ahead (and enqueuing more computations!).

So in this context, your Python lines for steps 2, 3, and so on are executing before the device has finished step 1.

Emitting telemetry isn't taking extra time; it is serving as a gate for the Python execution!

Hope that helps!
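To illustrate the effect outside MaxText, here is a minimal, self-contained sketch (not MaxText code) of the timing pitfall: timing a jitted step without synchronizing measures only the dispatch cost, while adding block_until_ready() measures the actual device computation.

```python
import time
import jax
import jax.numpy as jnp

@jax.jit
def step(x):
    # Stand-in for a training step: a large matmul.
    return x @ x

x = jnp.ones((2000, 2000))
step(x).block_until_ready()  # warm up: trigger compilation first

# Without synchronization: this only measures enqueueing the work.
t0 = time.perf_counter()
y = step(x)                      # returns immediately; work is still running
enqueue_time = time.perf_counter() - t0

# With synchronization: this measures the actual computation.
t0 = time.perf_counter()
y = step(x)
y.block_until_ready()            # wait for the device to finish
real_time = time.perf_counter() - t0
```

The un-synchronized timing is typically orders of magnitude smaller, which matches the ~0.001 s step times reported above.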

@rwitten Got it that makes sense thanks! Just curious, is there a specific part or line of code in the telemetry logic that gates the execution?

Writing to TensorBoard at some point requires the output values on the host, which serves as a gate.