pytorch / xla

Enabling PyTorch on XLA Devices (e.g. Google TPU)

Home Page:https://pytorch.org/xla


[Dynamo] Make `_run_cached_graph` non-blocking for dynamo

JackCaoG opened this issue · comments

🚀 Feature

`_run_cached_graph` is the API we use to execute the cached graph in our dynamo bridge. Currently this call is blocking, which means it waits for the actual device execution to finish. That is not necessary: we can create data placeholders for the XLATensors we return and let the execution run in a separate thread. When the execution finishes, we replace the placeholders with the real data. This is safe because the execution thread holds the device lock, which prevents others from accessing the data.
This is also what we do in LTC to overlap tracing and execution.
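
Below is a minimal, self-contained Python sketch of the placeholder-plus-background-execution idea described above. It is not the actual torch_xla implementation; `Placeholder`, `execute_cached_graph`, and the lock name are illustrative stand-ins.

```python
import threading
import time

device_lock = threading.Lock()

class Placeholder:
    """Stands in for an XLATensor data handle until real results arrive."""
    def __init__(self, shape):
        self.shape = shape
        self.data = None  # filled in by the execution thread

def execute_cached_graph(graph_hash, inputs):
    # Pretend device execution; in reality this would dispatch the cached
    # XLA computation identified by graph_hash.
    time.sleep(0.1)
    return [sum(inputs)]

def run_cached_graph_nonblocking(graph_hash, inputs, output_shapes):
    # Create placeholders eagerly from the known output shapes so the caller
    # gets tensors back immediately and tracing can continue.
    placeholders = [Placeholder(shape) for shape in output_shapes]

    def _worker():
        # Hold the device lock so nothing else can touch device data until
        # the placeholders are backed by real results.
        with device_lock:
            results = execute_cached_graph(graph_hash, inputs)
            for ph, result in zip(placeholders, results):
                ph.data = result

    threading.Thread(target=_worker).start()
    return placeholders  # returned before device execution finishes
```

Under this scheme, any consumer that actually needs the value would have to acquire the device lock first, and therefore would naturally wait until the placeholders have been backed by real data.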

Motivation

When I profiled training with the existing `aot_xla_trace_once` bridge, I found a significant gap between graph executions. It seems that aot_autograd needs to perform some computation between the forward and backward graphs. That work does not need to block on device execution, since it is unlikely to require the actual values of the tensors. This change should make our bridge run a bit faster.

Pitch

In `_run_cached_graph`, make device execution happen in a separate thread. We also need to know the output shapes of the graph for a given hash in order to create the data placeholders.
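
To build the placeholders before execution returns, the bridge needs a way to look up output shapes from the graph hash. A hypothetical shape cache (names are illustrative, not existing torch_xla APIs) could be as simple as:

```python
# Hypothetical cache populated when the graph is first compiled, so that the
# non-blocking execution path can create placeholders from the hash alone.
output_shape_cache = {}

def record_output_shapes(graph_hash, output_shapes):
    output_shape_cache[graph_hash] = list(output_shapes)

def lookup_output_shapes(graph_hash):
    # Raises KeyError if the graph was never compiled/recorded, which would
    # indicate a bug in the bridge.
    return output_shape_cache[graph_hash]
```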

@alanwaketan @shunting314 @will-cromar