google / flax

Flax is a neural network library for JAX that is designed for flexibility.

Home Page: https://flax.readthedocs.io

`flax.training.checkpoints.save_checkpoint` is blocking

mjsML opened this issue

flax.training.checkpoints.save_checkpoint is blocking by default. This becomes a problem as we scale up, especially if we save checkpoints to GCS or any other high-latency storage. By napkin math, on a v3-2048 checkpointing can account for up to 40% of total wall time (my reference point is ~25 min on a v3-128).

EDIT: The read is from a US bucket to a TPUv3-8 in eu4, and the write bucket is in eu4 as well.

Consider the code below, which simulates the ImageNet MLPerf run on a v3-2048:

Steps to reproduce:

from flax.training import checkpoints
from threading import Thread
import time
from tqdm import tqdm

state=None
save_times=80 # save every ~ 10 epochs
compute=5 # ~10 epochs of compute on a v3-2048 would be ~4.6875 s, but 5 s is close enough to size the problem.
readable_bucket_path="gs://flax_public/examples/imagenet/tpu_v3_32"

writeable_bucket="YOUR_BUCKET"
writable_path=f"gs://{writeable_bucket}/perf"

print("Reading parallel GCS","="*30)
state=checkpoints.restore_checkpoint(readable_bucket_path,state,parallel=True)


print("Starting blocking writes to bucket")

total_write=0
# Writing in blocking mode
tic_total = time.time()

for step in tqdm(range(save_times)):
    tic_write = time.time()
    checkpoints.save_checkpoint(writable_path,state,step=step,keep=3)
    toc_write = time.time()
    total_write += (toc_write-tic_write)
    # Simulate a compute load of 5 s; otherwise the writes can get caught in a race condition.
    # Sometimes they do anyway, because of the non-deterministic GCS access time (which is itself a tell of how slow GCS can be!)
    time.sleep(compute)
toc_total = time.time()
blocking_total = toc_total - tic_total
print("Total time on blocked GCS write:",blocking_total)
print("Total write time on blocked GCS write:",total_write)

# =======================================

print("Starting non-blocking writes to bucket")

blocked_write=total_write


total_write=0
ts=[]

# Writing in non-blocking mode

tic_total = time.time()
for step in tqdm(range(save_times,save_times*2)):
    tic_write = time.time()
    thread = Thread(target = checkpoints.save_checkpoint, kwargs= {'ckpt_dir':writable_path,'target':state,'step':step,'keep':3})
    thread.start()
    toc_write = time.time()
    total_write += (toc_write-tic_write)
    # Simulate a compute load of 5 s, same as above.
    time.sleep(compute)
    ts.append(thread)

print("Threaded on GCS before join:",total_write)

for t in tqdm(ts):
    t.join()

toc_total= time.time()
non_blocking_total = toc_total - tic_total
non_blocked_write=total_write
print("Time on non-blocked GCS write",non_blocking_total)
print(f"Total speedup ={blocking_total/non_blocking_total}X")
print(f"Total write speedup ={blocked_write/non_bocked_write}X")

yields

Reading parallel GCS ==============================
Starting blocking writes to bucket
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 80/80 [09:31<00:00,  7.14s/it]
Total time on blocked GCS write: 571.1965103149414
Total write time on blocked GCS write: 170.71105670928955
Starting non-blocking writes to bucket
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 80/80 [06:40<00:00,  5.01s/it]
Threaded on GCS before join: 0.09925699234008789
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 80/80 [00:00<00:00, 49947.06it/s]
Time on non-blocked GCS write 400.5715479850769
Total speedup =1.4259537732725367X
Total write speedup =1719.8894776538862X
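
To put the run in perspective, here is a rough back-of-the-envelope check using only the values printed above (the 5 s sleep stands in for the compute between saves):

# Numbers copied from the output above.
blocking_total = 571.20    # wall time with blocking saves (s)
blocking_write = 170.71    # time spent inside save_checkpoint (s)
threaded_total = 400.57    # wall time with threaded saves (s)

print(blocking_write / blocking_total)  # ~0.30: ~30% of wall time spent blocked on GCS writes
print(blocking_total - threaded_total)  # ~170 s of wall time recovered by moving the writes off-thread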

I'll try to dig into this tonight or tomorrow evening to unblock flax.training.checkpoints.save_checkpoint.
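
For reference, one possible shape of a workaround (a sketch only, not the eventual fix: async_save_checkpoint is a hypothetical helper, and the single-worker ThreadPoolExecutor is just one way to serialize the writes so they can't race each other the way ad-hoc Threads can):

from concurrent.futures import ThreadPoolExecutor

from flax.training import checkpoints

# One background worker: saves are queued and written in order, so two saves
# never touch the checkpoint directory at the same time, while the training
# loop only pays the cost of submitting the job.
_executor = ThreadPoolExecutor(max_workers=1)

def async_save_checkpoint(ckpt_dir, target, step, keep=3):
    # Hypothetical helper; calling jax.device_get(target) first would also
    # avoid holding device buffers while the write sits in the queue.
    return _executor.submit(
        checkpoints.save_checkpoint, ckpt_dir, target, step=step, keep=keep)

# In the loop above you would keep the returned future and call .result()
# (or _executor.shutdown(wait=True)) before exiting, so the last checkpoint
# actually lands.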