`flax.training.checkpoints.save_checkpoint` is blocking
mjsML opened this issue · comments
flax.training.checkpoints.save_checkpoint is blocking by default. This becomes a problem as we scale up, especially when checkpoints are saved to GCS or any other high-latency storage. By napkin math, on a v3-2048 checkpointing can account for up to 40% of the total wall time (my reference point is ~25 min on a v3-128).
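The napkin math above boils down to: compute time shrinks roughly linearly with chip count, while the GCS write time per checkpoint stays roughly constant, so the checkpoint fraction of wall time grows. A sketch of that arithmetic, using the ~25 min v3-128 reference from above and an assumed (illustrative, not measured) fixed write time:

```python
# Napkin-math sketch: why blocking checkpoint writes dominate at scale.
# ref_compute_min comes from the issue (~25 min on a v3-128);
# write_min is an ASSUMED fixed total checkpoint-write time for illustration.
ref_chips, big_chips = 128, 2048
ref_compute_min = 25.0
write_min = 1.0  # assumed, roughly constant regardless of chip count

# Compute scales down ~linearly with chip count; writes do not.
big_compute_min = ref_compute_min * ref_chips / big_chips  # ~1.56 min
frac = write_min / (big_compute_min + write_min)
print(f"checkpoint fraction of wall time: {frac:.0%}")
```

With these assumed numbers the blocking writes eat roughly 40% of the wall time, which is the order of magnitude the issue is pointing at.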
EDIT: The read is from a US bucket to a TPUv3-8 in eu4, and the write bucket is in eu4 as well.
Consider the code below, which simulates the ImageNet MLPerf run on a v3-2048:
Steps to reproduce:
from flax.training import checkpoints
from threading import Thread
import time
from tqdm import tqdm

state = None
save_times = 80  # save every ~10 epochs
compute = 5  # ~10 epochs of compute on a v3-2048; it should be "4.6875" but ~5 is good enough to size the problem.
readable_bucket_path = "gs://flax_public/examples/imagenet/tpu_v3_32"
writeable_bucket = "YOUR_BUCKET"
writable_path = f"gs://{writeable_bucket}/perf"

print("Reading parallel GCS", "=" * 30)
state = checkpoints.restore_checkpoint(readable_bucket_path, state, parallel=True)

print("Starting blocking writes to bucket")
total_write = 0
# Writing in blocking mode
tic_total = time.time()
for step in tqdm(range(save_times)):
    tic_write = time.time()
    checkpoints.save_checkpoint(writable_path, state, step=step, keep=3)
    toc_write = time.time()
    total_write += (toc_write - tic_write)
    # Simulate a 5 s compute load, because otherwise the writes get caught in a
    # race condition. Sometimes they do anyway, due to the non-deterministic
    # GCS access time (which is itself a tell on how slow GCS can be!).
    time.sleep(compute)
toc_total = time.time()
blocking_total = toc_total - tic_total
print("Total time on blocked GCS write:", blocking_total)
print("Total write time on blocked GCS write:", total_write)

# =======================================
print("Starting non-blocking writes to bucket")
blocked_write = total_write
total_write = 0
ts = []
# Writing in non-blocking mode
tic_total = time.time()
for step in tqdm(range(save_times, save_times * 2)):
    tic_write = time.time()
    thread = Thread(
        target=checkpoints.save_checkpoint,
        kwargs={'ckpt_dir': writable_path, 'target': state, 'step': step, 'keep': 3})
    thread.start()
    toc_write = time.time()
    total_write += (toc_write - tic_write)
    # Simulate the same 5 s compute load as above.
    time.sleep(compute)
    ts.append(thread)
print("Threaded on GCS before join:", total_write)
for t in tqdm(ts):
    t.join()
toc_total = time.time()
non_blocking_total = toc_total - tic_total
non_blocked_write = total_write
print("Time on non-blocked GCS write", non_blocking_total)
print(f"Total speedup ={blocking_total/non_blocking_total}X")
print(f"Total write speedup ={blocked_write/non_blocked_write}X")
This yields:
Reading parallel GCS ==============================
Starting blocking writes to bucket
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 80/80 [09:31<00:00, 7.14s/it]
Total time on blocked GCS write: 571.1965103149414
Total write time on blocked GCS write: 170.71105670928955
Starting non-blocking writes to bucket
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 80/80 [06:40<00:00, 5.01s/it]
Threaded on GCS before join: 0.09925699234008789
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 80/80 [00:00<00:00, 49947.06it/s]
Time on non-blocked GCS write 400.5715479850769
Total speedup =1.4259537732725367X
Total write speedup =1719.8894776538862X
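One way to get the non-blocking behavior without the race condition noted in the repro is to keep a single pending writer thread and join it before kicking off the next save, so two checkpoint writes never overlap. A minimal sketch, with a hypothetical `slow_save` function standing in for `checkpoints.save_checkpoint` so the example runs without GCS access:

```python
import time
from threading import Thread

class AsyncSaver:
    """Run saves on a background thread, joining the previous save
    first so two checkpoint writes never race with each other."""

    def __init__(self, save_fn):
        self._save_fn = save_fn
        self._pending = None

    def save(self, **kwargs):
        if self._pending is not None:
            self._pending.join()  # wait for the previous write to finish
        self._pending = Thread(target=self._save_fn, kwargs=kwargs)
        self._pending.start()

    def wait(self):
        """Block until the last save has completed (call before exit)."""
        if self._pending is not None:
            self._pending.join()
            self._pending = None

# Hypothetical stand-in for checkpoints.save_checkpoint:
saved = []
def slow_save(step):
    time.sleep(0.05)  # simulate GCS write latency
    saved.append(step)

saver = AsyncSaver(slow_save)
for step in range(3):
    saver.save(step=step)  # returns almost immediately
saver.wait()
print(saved)  # -> [0, 1, 2]: writes completed in order, none overlapping
```

As long as the compute between saves exceeds the write latency, `save` returns almost immediately, matching the "before join" timing in the output above; when it doesn't, the join simply absorbs the overlap instead of corrupting the checkpoint directory.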
I'll try to pull on this tonight or tomorrow evening to unblock flax.training.checkpoints.save_checkpoint