google / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more

Home Page: http://jax.readthedocs.io/


Integrating TVM into JAX

chaoming0625 opened this issue

Please:

  • Check for duplicate requests.
  • Describe your goal, and if possible provide a code snippet with a motivating example.

TVM can generate highly optimized operators. Is it possible to integrate TVM's optimized operators into JAX's JIT system?

Do you have a particular use case in mind?

In general: sure. We could either use TVM in place of XLA, or use TVM to generate individual kernels inside an otherwise XLA-compiled program. For the first approach, JAX has some early support for plugging in alternate compilers and runtimes; for the second, it's possible to mix XLA-compiled code with other code via mechanisms like DLPack and XLA CustomCall operators. But it would probably be a reasonably large project.
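
To make the second approach concrete, here is a minimal sketch (not from this thread) of embedding an external kernel in an otherwise jitted JAX program via `jax.pure_callback`. The function `tvm_kernel_host` is a hypothetical stand-in for a TVM-compiled kernel; a real integration would hand the buffer to TVM (e.g. via DLPack) rather than compute in NumPy:

```python
import jax
import jax.numpy as jnp
import numpy as np

def tvm_kernel_host(x):
    # Hypothetical stand-in for a TVM-compiled kernel. A real integration
    # would convert the buffer (e.g. via DLPack) and invoke the compiled
    # TVM module here.
    return np.asarray(x) * 2.0

def tvm_kernel(x):
    # jax.pure_callback lets XLA-compiled code call back into Python, so
    # the external kernel can sit inside a jit-traced function. The output
    # shape and dtype must be declared up front.
    return jax.pure_callback(
        tvm_kernel_host,
        jax.ShapeDtypeStruct(x.shape, x.dtype),
        x,
    )

@jax.jit
def f(x):
    y = jnp.sin(x)      # compiled by XLA
    z = tvm_kernel(y)   # dispatched to the external kernel
    return jnp.sum(z)

print(f(jnp.arange(4.0)))
```

Note that a callback like this round-trips through the host on every call, which would negate TVM's performance benefit; the XLA CustomCall route mentioned above is what keeps the kernel on-device.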

I suspect this is probably in the "contributions welcome" category, but I'd be interested to know if there is a particular program or use case that motivates the question.

My understanding is that TVM supports more hardware accelerators than XLA does. It would be interesting to run compiled JAX programs on something like an FPGA.