[Dynamo] Make `_run_cached_graph` not blocking for dynamo
JackCaoG opened this issue · comments
🚀 Feature
`_run_cached_graph` is the API we use to execute the cached graph for our dynamo bridge. Currently this call is blocking, which means it waits for the actual device execution to finish. That is not necessary: we can create data placeholders for the XLATensors we return and let the execution run in a separate thread. When the execution finishes, we replace the placeholders with the real data. This is safe because the execution thread holds the device lock, which prevents others from accessing the data.
This is also what we do in LTC to overlap tracing and execution.
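The placeholder-and-background-thread pattern described above can be sketched in plain Python. Everything here is illustrative: `TensorPlaceholder`, `run_cached_graph_async`, `execute_fn`, and `device_lock` are hypothetical names, not real torch_xla APIs.

```python
import threading

class TensorPlaceholder:
    """Hypothetical stand-in for an XLATensor whose device data is not ready."""
    def __init__(self, shape):
        self.shape = shape          # known up front from the cached graph
        self._data = None
        self._ready = threading.Event()

    def set_data(self, data):
        self._data = data
        self._ready.set()

    def value(self):
        # Block only at the point where the actual value is required.
        self._ready.wait()
        return self._data

def run_cached_graph_async(execute_fn, output_shapes, device_lock):
    """Return placeholders immediately; run the graph on a worker thread."""
    placeholders = [TensorPlaceholder(s) for s in output_shapes]

    def worker():
        # Holding the device lock while executing prevents other work
        # from touching device data mid-flight, as the issue describes.
        with device_lock:
            results = execute_fn()
            for ph, result in zip(placeholders, results):
                ph.set_data(result)

    threading.Thread(target=worker, daemon=True).start()
    return placeholders
```

The caller gets the placeholders back right away and can keep tracing; only a read of `.value()` forces a wait on the device.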
Motivation
When I profiled training with the existing `aot_xla_trace_once` bridge, I found a significant gap between graph executions. It seems aot_autograd needs to perform some computation between the forward and backward graphs. That computation does not need to block on device execution, since it is unlikely to require the actual values of the tensors. Making this call non-blocking should make our bridge run a bit faster.
Pitch
In `_run_cached_graph`, make device execution happen in a separate thread. We also need to know the output shapes of the graph given its hash, in order to create the data placeholders.
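The shape lookup could be a simple cache keyed by the graph hash, populated at compile time so placeholders can be built without executing. This is a sketch under that assumption; `GraphShapeCache` is a hypothetical name, not an existing torch_xla structure.

```python
class GraphShapeCache:
    """Hypothetical cache mapping graph hash -> output shapes.

    Recorded once when a graph is compiled, so a later call with the
    same hash can create data placeholders before execution finishes.
    """
    def __init__(self):
        self._shapes = {}

    def record(self, graph_hash, shapes):
        self._shapes[graph_hash] = list(shapes)

    def lookup(self, graph_hash):
        # Raises KeyError if the graph was never compiled/recorded.
        return self._shapes[graph_hash]
```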