iree-org / iree-jax

Dynamic_api_test in jax failing when returning dynamic shape

oliverdutton opened this issue

commented

I am trying to use the (very convenient) option of IREE as a JAX backend. When I run the tests, they fail whenever the output shape is dynamic. I'm guessing this test actually works and I'm missing something. The issue is presumably JAX forcibly turning the result into a JAX array.

The command below runs just one of the tests. What am I doing wrong?

!cd ${JAX_REPO_PATH} && \
JAX_ARRAY=0 JAX_PLATFORMS=iree \
pytest -r a --verbosity 1 -s tests/dynamic_api_test.py -k transpose
============================= test session starts ==============================
platform linux -- Python 3.8.13, pytest-6.2.5, py-1.11.0, pluggy-1.0.0 -- /opt/conda/bin/python3.8
cachedir: .pytest_cache
hypothesis profile 'default' -> database=DirectoryBasedExampleDatabase('/raid/app/oliver/repos/jax/.hypothesis/examples')
rootdir: /raid/app/oliver/repos/jax, configfile: pytest.ini
plugins: anyio-3.6.2, pythonpath-0.7.4, cov-3.0.0, hypothesis-4.50.8
collected 76 items / 75 deselected / 1 selected                                

tests/dynamic_api_test.py::DynamicShapeTest::test_transpose FAILED

=================================== FAILURES ===================================
_______________________ DynamicShapeTest.test_transpose ________________________

self = <dynamic_api_test.DynamicShapeTest testMethod=test_transpose>

    @unittest.skipIf(jtu.device_under_test() != 'iree', 'iree test')
    def test_transpose(self):
      @partial(jax.jit, abstracted_axes=({0: 'h', 1: 'w'},))
      def f(x):  # f32[h, w] -> f32[w, h]
        return x.T
    
>     f(np.ones((3, 5), dtype=np.float32))

tests/dynamic_api_test.py:673: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
jax/_src/api.py:532: in f_jitted
    out_flat = xla.xla_call(
jax/_src/core.py:2132: in bind
    outs = top_trace.process_call(self, fun_, tracers, params)
jax/_src/core.py:794: in process_call
    return primitive.impl(f, *tracers, **params)
jax/_src/dispatch.py:258: in _xla_call_impl
    return compiled_fun(*args)
jax/_src/dispatch.py:916: in _execute_compiled
    return result_handler(env, out_bufs)
jax/_src/dispatch.py:787: in result_handler
    results.append(handler((input_env, results), *bufs))
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

sticky_device = None, aval = f32[InDBIdx(val=1),InDBIdx(val=0)]
env = ((3, 5, None), [])
buf = IreeBuffer([[1. 1. 1.]
 [1. 1. 1.]
 [1. 1. 1.]
 [1. 1. 1.]
 [1. 1. 1.]])

    def _dynamic_array_result_handler(sticky_device, aval, env, buf):
      in_env, out_env = env or (None, None)
      shape = [in_env[d.val] if type(d) is core.InDBIdx else
               out_env[d.val] if type(d) is core.OutDBIdx else d
               for d in aval.shape]
      if all(type(d) is int for d in shape) and type(aval.dtype) is not core.bint:
>       aval = core.ShapedArray(tuple(shape), buf.dtype)
E       AttributeError: 'IreeBuffer' object has no attribute 'dtype'

jax/_src/dispatch.py:846: AttributeError
=========================== short test summary info ============================
FAILED tests/dynamic_api_test.py::DynamicShapeTest::test_transpose - Attribut...
======================= 1 failed, 75 deselected in 1.15s =======================

The following reproduces the failure in Colab (with jax v0.4.4):

!pip install git+https://github.com/iree-org/iree-jax

import os
os.environ['JAX_JIT_PJIT_API_MERGE'] = '0'

import jax
jax.config.update("jax_dynamic_shapes", True)
jax.config.update("jax_array", False)
jnp = jax.numpy
from functools import partial

@partial(jax.jit, abstracted_axes=({0: 'h', 1: 'w'},), backend='iree')
def f(x):  # f32[h, w] -> f32[w, h]
  return x.T

f(jnp.ones((3, 5), dtype=float))

There seems to be an expectation in the JAX Python code here, with respect to buffer result types, that we don't match. I'll also ask Matt what is expected.
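For readers following the traceback, here is a minimal pure-Python sketch of the mismatch it shows: the result handler resolves the symbolic output shape from the input environment and then reads `buf.dtype`, which the returned buffer does not provide. Every class and function name below (`InDBIdx` as a local stand-in, `BufferWithoutDtype`, `BufferWithDtype`, `resolve_result`) is hypothetical and only mimics the shape of the problem; none of it is the real JAX or IREE code, and it is not a claim about how the eventual fix works.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class InDBIdx:
    """Hypothetical stand-in for core.InDBIdx: an index into the input environment."""
    val: int


class BufferWithoutDtype:
    """Hypothetical buffer that exposes data only (mimics the failing case)."""

    def __init__(self, data):
        self.data = data


class BufferWithDtype(BufferWithoutDtype):
    """Same buffer plus the `dtype` attribute that the handler reads."""

    def __init__(self, data, dtype):
        super().__init__(data)
        self.dtype = dtype


def resolve_result(aval_shape, in_env, buf):
    # Replace each symbolic dimension with its concrete size from in_env.
    shape = tuple(in_env[d.val] if isinstance(d, InDBIdx) else d
                  for d in aval_shape)
    # The handler in the traceback reads `buf.dtype` at this point.
    return shape, buf.dtype


aval_shape = (InDBIdx(1), InDBIdx(0))  # f32[w, h] for an input of f32[h, w]
in_env = (3, 5, None)                  # same values as `env[0]` in the traceback
data = [[1.0] * 3 for _ in range(5)]   # 5 x 3 result, as printed for `buf`

try:
    resolve_result(aval_shape, in_env, BufferWithoutDtype(data))
    failed = False
except AttributeError:
    failed = True

shape, dtype = resolve_result(aval_shape, in_env, BufferWithDtype(data, "float32"))
print(failed, shape, dtype)  # True (5, 3) float32
```

In this toy version the shape resolution succeeds in both cases; only the attribute lookup on the buffer differs, which matches the `AttributeError` reported above.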

Thanks for raising this. I think there should be an easy fix; I can take it.

(But the JAX bits being used here are super unpolished and feature coverage is still minimal!)

I think google/jax#14986 should fix this. But to be honest I've paged out a lot of context on this! So if something else breaks just let us know.

commented

Perfect, thank you

By the way, it's near the top of my todo list to update dynamic shapes to be compatible with both JAX_JIT_PJIT_API_MERGE=1 and JAX_ARRAY=1. It shouldn't be "hard", but it's nontrivial just because it'll require a big context switch and some time. It never rises to "urgent" like other things because we don't have any dynamic shapes users (or at least I thought we didn't have any until this issue was opened!).

commented

I completely understand, dynamic shape is very experimental. I've been poking around.

I was looking at use cases popping foldcomp on GPU by combining with nerfax, but compilation times kill it from being competitive due to slightly different length arrays everywhere.
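As a rough illustration of the compile-time problem (a sketch, not the actual foldcomp/nerfax workload): with ordinary static shapes, `jax.jit` traces and compiles a function again for every new input shape. The function name `total` and the counter `trace_count` are made up for this example; the counter works because a Python side effect inside a jitted function runs only while JAX is tracing.

```python
import jax
import jax.numpy as jnp

trace_count = 0


@jax.jit
def total(x):
    global trace_count
    trace_count += 1  # Python side effect: executes only during tracing
    return jnp.sum(x * 2.0)


# Arrays of slightly different lengths; the last length repeats the first.
for n in (3, 4, 5, 3):
    total(jnp.ones((n,), dtype=jnp.float32))

print(trace_count)  # 3: one trace per distinct shape, the repeat hits the cache
```

Each distinct length pays the tracing and compilation cost once, so many arrays of slightly different lengths means many compilations.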

And maybe jax-md can benefit from it in the neighbor list update.

But these are primarily interest projects, hopefully I'll find a meaty business application eventually

Thanks for all the magic of jax

Thanks for the explanation, and the kind words!

> I was looking at use cases popping foldcomp on GPU by combining with nerfax, but compilation times kill it from being competitive due to slightly different length arrays everywhere.

Wow, this is very interesting. Any chance you could share some representative programs or toy examples, showing what you want to do, or where the compile times are killing you? Maybe we can help!

(This would also be interesting on the IREE side, as we have a WIP pass to dedupe some kernels to dynamic dim variants to reduce compilation times. We've been focused a bit more on the AOT case, but still.)

commented

Cool, I'll generate a discussion separately and tag you in it in the next few days with a clear set of code that I'm working on compiling.

Looks like the merge fixes those tests, so closing the issue.