"RuntimeError: Unknown backend iree"
WoongQ opened this issue
Hello, I'm trying to install iree-jax to test GPT-2 on IREE. After running `python -m pip install -e '.[test,xla,cpu]' -f https://openxla.github.io/iree/pip-release-links.html`, I built jaxlib from source. However, when I run `lit -v tests/`, I get a `RuntimeError` with the message "Unknown backend iree". The same error occurs when running `models/gpt2/test_jax.py`. Did I miss something during the setup process? Your help would be greatly appreciated. I have attached the error log below.
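The traceback shows the error is raised inside JAX itself: while iree-jax lowers a function (`self.jit_f.lower(...)`), JAX calls `xb.get_backend("iree")` and finds no backend registered under that name. One way to confirm this outside the test suite is to ask JAX for that backend directly. The helper below is a minimal sketch; the name `backend_available` is hypothetical, and it relies only on the public `jax.devices(backend)` call, which raises `RuntimeError` for a backend name JAX does not know. The device lookup is injectable so the logic can be exercised without JAX installed.

```python
from typing import Callable, Optional, Sequence


def backend_available(
    name: str,
    devices_fn: Optional[Callable[[str], Sequence[object]]] = None,
) -> bool:
    """Return True if JAX reports at least one device for backend `name`.

    By default `devices_fn` is `jax.devices`, which raises RuntimeError for an
    unknown backend (the same error class seen in the lit log). Passing a
    different callable (hypothetical, for testing) avoids importing JAX.
    """
    if devices_fn is None:
        import jax  # imported lazily so this helper loads even without JAX

        devices_fn = jax.devices
    try:
        return len(devices_fn(name)) > 0
    except RuntimeError:
        # e.g. "Unknown backend iree": JAX has no backend registered under `name`.
        return False
```

On the setup in the log below, one would expect `backend_available("cpu")` to return `True` and `backend_available("iree")` to return `False`, which would point at the JAX/jaxlib installation rather than at the iree-jax tests themselves.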
Using pure python filecheck: /home/woongq/jax/bin/filecheck
-- Testing: 5 tests, 5 workers --
FAIL: IREE_JAX :: program/trivial_kernel.py (1 of 5)
******************** TEST 'IREE_JAX :: program/trivial_kernel.py' FAILED ********************
Script:
--
: 'RUN: at line 15'; /home/woongq/jax/bin/python /home/woongq/iree-jax/tests/program/trivial_kernel.py | /home/woongq/jax/bin/filecheck /home/woongq/iree-jax/tests/program/trivial_kernel.py
--
Exit Code: 2
Command Output (stdout):
--
$ ":" "RUN: at line 15"
$ "/home/woongq/jax/bin/python" "/home/woongq/iree-jax/tests/program/trivial_kernel.py"
# command stderr:
WARNING:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.0002429485321044922 sec
DEBUG:jax._src.xla_bridge:Initializing backend 'cpu'
DEBUG:jax._src.xla_bridge:Backend 'cpu' initialized
DEBUG:jax._src.xla_bridge:Initializing backend 'cuda'
INFO:jax._src.xla_bridge:Unable to initialize backend 'cuda': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
DEBUG:jax._src.xla_bridge:Initializing backend 'rocm'
INFO:jax._src.xla_bridge:Unable to initialize backend 'rocm': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
DEBUG:jax._src.xla_bridge:Initializing backend 'tpu'
INFO:jax._src.xla_bridge:Unable to initialize backend 'tpu': module 'jaxlib.xla_extension' has no attribute 'get_tpu_client'
WARNING:jax._src.xla_bridge:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
WARNING:jax._src.dispatch:Finished tracing + transforming jit(broadcast_in_dim) in 0.0002300739288330078 sec
DEBUG:jax._src.interpreters.pxla:Compiling broadcast_in_dim for with global shapes and types [ShapedArray(float32[])]. Argument mapping: (GSPMDSharding({replicated}),).
WARNING:jax._src.dispatch:Finished jaxpr to MLIR module conversion jit(broadcast_in_dim) in 0.001964092254638672 sec
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
WARNING:jax._src.dispatch:Finished XLA compilation of jit(broadcast_in_dim) in 0.012798309326171875 sec
WARNING:jax._src.dispatch:Finished tracing + transforming fn for pjit in 0.0004911422729492188 sec
DEBUG:jax._src.interpreters.pxla:Compiling fn for with global shapes and types [ShapedArray(float32[3,4]), ShapedArray(float32[], weak_type=True)]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
WARNING:jax._src.dispatch:Finished jaxpr to MLIR module conversion jit(fn) in 0.0016129016876220703 sec
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
WARNING:jax._src.dispatch:Finished XLA compilation of jit(fn) in 0.0081329345703125 sec
DEBUG:iree_jax:Create new Program subclass: trivial_kernel
DEBUG:root:DEFINE PY_ONLY: _linear = <Exportable Pure Func: <function TrivialKernel._linear at 0x7f91ee93ce50>>
DEBUG:iree_jax:def_global_tree: array _params$0=(3, 4):dtype('float32')
DEBUG:iree_jax:def_global_tree: array _params$1=(3, 4):dtype('float32')
DEBUG:iree_jax:def_global_tree: new tree=Params(x=ConcreteArray(ExportedGlobalArray(@_params$0 : tensor<3x4xf32>), dtype=float32), b=ConcreteArray(ExportedGlobalArray(@_params$1 : tensor<3x4xf32>), dtype=float32))
DEBUG:iree_jax:def_global_tree: array _x$0=(3, 4):dtype('float32')
DEBUG:iree_jax:def_global_tree: new tree=ExportedGlobalArray(@_params$0 : tensor<3x4xf32>)
Traceback (most recent call last):
File "/home/woongq/iree-jax/tests/program/trivial_kernel.py", line 61, in <module>
m = TrivialKernel()
File "/home/woongq/iree-jax/iree/jax/program_api.py", line 459, in __new__
export_function()
File "/home/woongq/iree-jax/iree/jax/program_api.py", line 454, in export_function
info.export_module.def_func(invoke_with_self,
File "/home/woongq/iree-jax/iree/jax/exporter.py", line 206, in def_func
return_py_value = f(*argument_py_tree)
File "/home/woongq/iree-jax/iree/jax/program_api.py", line 452, in invoke_with_self
return func_def.callable(self, *args, **kwargs)
File "/home/woongq/iree-jax/tests/program/trivial_kernel.py", line 48, in run
result = self._linear(multiplier, self._params.x, self._params.b)
File "/home/woongq/iree-jax/iree/jax/tracing.py", line 54, in __call__
return current_ir_trace().handle_call(self, args, kwargs)
File "/home/woongq/iree-jax/iree/jax/tracing.py", line 114, in handle_call
return target.resolve_call(self, *args, **kwargs)
File "/home/woongq/iree-jax/iree/jax/builtins.py", line 60, in resolve_call
lowered = self.jit_f.lower(*abstract_args)
File "/home/woongq/jax/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 166, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/home/woongq/jax/lib/python3.10/site-packages/jax/_src/pjit.py", line 338, in lower
donate_argnums) = infer_params_fn(*args, **kwargs)
File "/home/woongq/jax/lib/python3.10/site-packages/jax/_src/api.py", line 324, in infer_params
return pjit.common_infer_params(pjit_info_args, *args, **kwargs)
File "/home/woongq/jax/lib/python3.10/site-packages/jax/_src/pjit.py", line 452, in common_infer_params
in_shardings = out_shardings = _create_sharding_with_device_backend(
File "/home/woongq/jax/lib/python3.10/site-packages/jax/_src/pjit.py", line 826, in _create_sharding_with_device_backend
xb.get_backend(backend).get_default_device_assignment(1)[0])
File "/home/woongq/jax/lib/python3.10/site-packages/jax/_src/xla_bridge.py", line 710, in get_backend
return _get_backend_uncached(platform)
File "/home/woongq/jax/lib/python3.10/site-packages/jax/_src/xla_bridge.py", line 699, in _get_backend_uncached
raise RuntimeError(f"Unknown backend {platform}")
jax._src.traceback_util.UnfilteredStackTrace: RuntimeError: Unknown backend iree
The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.
--------------------
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/home/woongq/iree-jax/tests/program/trivial_kernel.py", line 61, in <module>
m = TrivialKernel()
File "/home/woongq/iree-jax/iree/jax/program_api.py", line 459, in __new__
export_function()
File "/home/woongq/iree-jax/iree/jax/program_api.py", line 454, in export_function
info.export_module.def_func(invoke_with_self,
File "/home/woongq/iree-jax/iree/jax/exporter.py", line 206, in def_func
return_py_value = f(*argument_py_tree)
File "/home/woongq/iree-jax/iree/jax/program_api.py", line 452, in invoke_with_self
return func_def.callable(self, *args, **kwargs)
File "/home/woongq/iree-jax/tests/program/trivial_kernel.py", line 48, in run
result = self._linear(multiplier, self._params.x, self._params.b)
File "/home/woongq/iree-jax/iree/jax/tracing.py", line 54, in __call__
return current_ir_trace().handle_call(self, args, kwargs)
File "/home/woongq/iree-jax/iree/jax/tracing.py", line 114, in handle_call
return target.resolve_call(self, *args, **kwargs)
File "/home/woongq/iree-jax/iree/jax/builtins.py", line 60, in resolve_call
lowered = self.jit_f.lower(*abstract_args)
RuntimeError: Unknown backend iree
nanobind: leaked 66 instances!
nanobind: leaked 16 types!
- leaked type "iree._runtime.VmVariantList"
- leaked type "iree._runtime.HalBufferView"
- leaked type "iree._runtime.BufferUsage"
- leaked type "iree._runtime.VmContext"
- leaked type "iree._runtime.MappedMemory"
- leaked type "iree._runtime.ArgumentPacker"
- leaked type "iree._runtime.HalElementType"
- leaked type "iree._runtime.VmRef"
- leaked type "iree._runtime.VmModule"
- leaked type "iree._runtime.HalDevice"
- leaked type "iree._runtime._InvokeStatics"
- ... skipped remainder
nanobind: leaked 78 functions!
- leaked function ""
- leaked function "lookup_function"
- leaked function "__eq__"
- leaked function ""
- leaked function "__iree_vm_type__"
- leaked function "__or__"
- leaked function "__init__"
- leaked function "create_device_by_uri"
- leaked function ""
- leaked function "invoke"
- leaked function "__init__"
- ... skipped remainder
nanobind: this is likely caused by a reference counting issue in the binding code.
error: command failed with exit status: 1
$ "/home/woongq/jax/bin/filecheck" "/home/woongq/iree-jax/tests/program/trivial_kernel.py"
# command output:
CHECK: FileCheck error: '-' is empty.
FileCheck command line: /home/woongq/iree-jax/tests/program/trivial_kernel.py
error: command failed with exit status: 2
--
********************
FAIL: IREE_JAX :: program/fft.py (2 of 5)
******************** TEST 'IREE_JAX :: program/fft.py' FAILED ********************
Script:
--
: 'RUN: at line 15'; /home/woongq/jax/bin/python /home/woongq/iree-jax/tests/program/fft.py | /home/woongq/jax/bin/filecheck /home/woongq/iree-jax/tests/program/fft.py
--
Exit Code: 2
Command Output (stdout):
--
$ ":" "RUN: at line 15"
$ "/home/woongq/jax/bin/python" "/home/woongq/iree-jax/tests/program/fft.py"
# command stderr:
DEBUG:iree_jax:Create new Program subclass: f_f_t
DEBUG:root:DEFINE PY_ONLY: _fft = <Exportable Pure Func: <function FFT._fft at 0x7f92544a2290>>
DEBUG:jax._src.xla_bridge:Initializing backend 'cpu'
DEBUG:jax._src.xla_bridge:Backend 'cpu' initialized
DEBUG:jax._src.xla_bridge:Initializing backend 'cuda'
INFO:jax._src.xla_bridge:Unable to initialize backend 'cuda': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
DEBUG:jax._src.xla_bridge:Initializing backend 'rocm'
INFO:jax._src.xla_bridge:Unable to initialize backend 'rocm': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
DEBUG:jax._src.xla_bridge:Initializing backend 'tpu'
INFO:jax._src.xla_bridge:Unable to initialize backend 'tpu': module 'jaxlib.xla_extension' has no attribute 'get_tpu_client'
WARNING:jax._src.xla_bridge:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
Traceback (most recent call last):
File "/home/woongq/iree-jax/tests/program/fft.py", line 41, in <module>
m = FFT()
File "/home/woongq/iree-jax/iree/jax/program_api.py", line 459, in __new__
export_function()
File "/home/woongq/iree-jax/iree/jax/program_api.py", line 454, in export_function
info.export_module.def_func(invoke_with_self,
File "/home/woongq/iree-jax/iree/jax/exporter.py", line 206, in def_func
return_py_value = f(*argument_py_tree)
File "/home/woongq/iree-jax/iree/jax/program_api.py", line 452, in invoke_with_self
return func_def.callable(self, *args, **kwargs)
File "/home/woongq/iree-jax/tests/program/fft.py", line 33, in fft
return self._fft(x)
File "/home/woongq/iree-jax/iree/jax/tracing.py", line 54, in __call__
return current_ir_trace().handle_call(self, args, kwargs)
File "/home/woongq/iree-jax/iree/jax/tracing.py", line 114, in handle_call
return target.resolve_call(self, *args, **kwargs)
File "/home/woongq/iree-jax/iree/jax/builtins.py", line 60, in resolve_call
lowered = self.jit_f.lower(*abstract_args)
File "/home/woongq/jax/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 166, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/home/woongq/jax/lib/python3.10/site-packages/jax/_src/pjit.py", line 338, in lower
donate_argnums) = infer_params_fn(*args, **kwargs)
File "/home/woongq/jax/lib/python3.10/site-packages/jax/_src/api.py", line 324, in infer_params
return pjit.common_infer_params(pjit_info_args, *args, **kwargs)
File "/home/woongq/jax/lib/python3.10/site-packages/jax/_src/pjit.py", line 452, in common_infer_params
in_shardings = out_shardings = _create_sharding_with_device_backend(
File "/home/woongq/jax/lib/python3.10/site-packages/jax/_src/pjit.py", line 826, in _create_sharding_with_device_backend
xb.get_backend(backend).get_default_device_assignment(1)[0])
File "/home/woongq/jax/lib/python3.10/site-packages/jax/_src/xla_bridge.py", line 710, in get_backend
return _get_backend_uncached(platform)
File "/home/woongq/jax/lib/python3.10/site-packages/jax/_src/xla_bridge.py", line 699, in _get_backend_uncached
raise RuntimeError(f"Unknown backend {platform}")
jax._src.traceback_util.UnfilteredStackTrace: RuntimeError: Unknown backend iree
The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.
--------------------
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/home/woongq/iree-jax/tests/program/fft.py", line 41, in <module>
m = FFT()
File "/home/woongq/iree-jax/iree/jax/program_api.py", line 459, in __new__
export_function()
File "/home/woongq/iree-jax/iree/jax/program_api.py", line 454, in export_function
info.export_module.def_func(invoke_with_self,
File "/home/woongq/iree-jax/iree/jax/exporter.py", line 206, in def_func
return_py_value = f(*argument_py_tree)
File "/home/woongq/iree-jax/iree/jax/program_api.py", line 452, in invoke_with_self
return func_def.callable(self, *args, **kwargs)
File "/home/woongq/iree-jax/tests/program/fft.py", line 33, in fft
return self._fft(x)
File "/home/woongq/iree-jax/iree/jax/tracing.py", line 54, in __call__
return current_ir_trace().handle_call(self, args, kwargs)
File "/home/woongq/iree-jax/iree/jax/tracing.py", line 114, in handle_call
return target.resolve_call(self, *args, **kwargs)
File "/home/woongq/iree-jax/iree/jax/builtins.py", line 60, in resolve_call
lowered = self.jit_f.lower(*abstract_args)
RuntimeError: Unknown backend iree
error: command failed with exit status: 1
$ "/home/woongq/jax/bin/filecheck" "/home/woongq/iree-jax/tests/program/fft.py"
# command output:
CHECK: FileCheck error: '-' is empty.
FileCheck command line: /home/woongq/iree-jax/tests/program/fft.py
error: command failed with exit status: 2
--
********************
PASS: IREE_JAX :: program/trivial_globals.py (3 of 5)
FAIL: IREE_JAX :: program/duplicate_helper.py (4 of 5)
******************** TEST 'IREE_JAX :: program/duplicate_helper.py' FAILED ********************
Script:
--
: 'RUN: at line 1'; /home/woongq/jax/bin/python /home/woongq/iree-jax/tests/program/duplicate_helper.py
--
Exit Code: 1
Command Output (stdout):
--
$ ":" "RUN: at line 1"
$ "/home/woongq/jax/bin/python" "/home/woongq/iree-jax/tests/program/duplicate_helper.py"
# command stderr:
No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
Traceback (most recent call last):
File "/home/woongq/iree-jax/tests/program/duplicate_helper.py", line 67, in <module>
print(str(Program.get_mlir_module(module)))
File "/home/woongq/iree-jax/iree/jax/program_api.py", line 377, in get_mlir_module
info = Program.get_info(Program._get_instance(m))
File "/home/woongq/iree-jax/iree/jax/program_api.py", line 372, in _get_instance
m = m()
File "/home/woongq/iree-jax/iree/jax/program_api.py", line 459, in __new__
export_function()
File "/home/woongq/iree-jax/iree/jax/program_api.py", line 454, in export_function
info.export_module.def_func(invoke_with_self,
File "/home/woongq/iree-jax/iree/jax/exporter.py", line 206, in def_func
return_py_value = f(*argument_py_tree)
File "/home/woongq/iree-jax/iree/jax/program_api.py", line 452, in invoke_with_self
return func_def.callable(self, *args, **kwargs)
File "/home/woongq/iree-jax/tests/program/duplicate_helper.py", line 50, in encode
return mdl._encode(x, y)
File "/home/woongq/iree-jax/iree/jax/tracing.py", line 54, in __call__
return current_ir_trace().handle_call(self, args, kwargs)
File "/home/woongq/iree-jax/iree/jax/tracing.py", line 114, in handle_call
return target.resolve_call(self, *args, **kwargs)
File "/home/woongq/iree-jax/iree/jax/builtins.py", line 60, in resolve_call
lowered = self.jit_f.lower(*abstract_args)
File "/home/woongq/jax/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 166, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/home/woongq/jax/lib/python3.10/site-packages/jax/_src/pjit.py", line 338, in lower
donate_argnums) = infer_params_fn(*args, **kwargs)
File "/home/woongq/jax/lib/python3.10/site-packages/jax/_src/api.py", line 324, in infer_params
return pjit.common_infer_params(pjit_info_args, *args, **kwargs)
File "/home/woongq/jax/lib/python3.10/site-packages/jax/_src/pjit.py", line 452, in common_infer_params
in_shardings = out_shardings = _create_sharding_with_device_backend(
File "/home/woongq/jax/lib/python3.10/site-packages/jax/_src/pjit.py", line 826, in _create_sharding_with_device_backend
xb.get_backend(backend).get_default_device_assignment(1)[0])
File "/home/woongq/jax/lib/python3.10/site-packages/jax/_src/xla_bridge.py", line 710, in get_backend
return _get_backend_uncached(platform)
File "/home/woongq/jax/lib/python3.10/site-packages/jax/_src/xla_bridge.py", line 699, in _get_backend_uncached
raise RuntimeError(f"Unknown backend {platform}")
jax._src.traceback_util.UnfilteredStackTrace: RuntimeError: Unknown backend iree
The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.
--------------------
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/home/woongq/iree-jax/tests/program/duplicate_helper.py", line 67, in <module>
print(str(Program.get_mlir_module(module)))
File "/home/woongq/iree-jax/iree/jax/program_api.py", line 377, in get_mlir_module
info = Program.get_info(Program._get_instance(m))
File "/home/woongq/iree-jax/iree/jax/program_api.py", line 372, in _get_instance
m = m()
File "/home/woongq/iree-jax/iree/jax/program_api.py", line 459, in __new__
export_function()
File "/home/woongq/iree-jax/iree/jax/program_api.py", line 454, in export_function
info.export_module.def_func(invoke_with_self,
File "/home/woongq/iree-jax/iree/jax/exporter.py", line 206, in def_func
return_py_value = f(*argument_py_tree)
File "/home/woongq/iree-jax/iree/jax/program_api.py", line 452, in invoke_with_self
return func_def.callable(self, *args, **kwargs)
File "/home/woongq/iree-jax/tests/program/duplicate_helper.py", line 50, in encode
return mdl._encode(x, y)
File "/home/woongq/iree-jax/iree/jax/tracing.py", line 54, in __call__
return current_ir_trace().handle_call(self, args, kwargs)
File "/home/woongq/iree-jax/iree/jax/tracing.py", line 114, in handle_call
return target.resolve_call(self, *args, **kwargs)
File "/home/woongq/iree-jax/iree/jax/builtins.py", line 60, in resolve_call
lowered = self.jit_f.lower(*abstract_args)
RuntimeError: Unknown backend iree
error: command failed with exit status: 1
--
********************
FAIL: IREE_JAX :: program/program_api_test.py (5 of 5)
******************** TEST 'IREE_JAX :: program/program_api_test.py' FAILED ********************
Script:
--
: 'RUN: at line 1'; /home/woongq/jax/bin/python /home/woongq/iree-jax/tests/program/program_api_test.py
--
Exit Code: 1
Command Output (stdout):
--
$ ":" "RUN: at line 1"
$ "/home/woongq/jax/bin/python" "/home/woongq/iree-jax/tests/program/program_api_test.py"
# command stderr:
.DEBUG:iree_jax:Create new Program subclass: hidden
.DEBUG:iree_jax:Create new Program subclass: nullary
DEBUG:iree_jax:Create new Program subclass: unary
.DEBUG:iree_jax:Create new Program subclass: Foobar
.DEBUG:iree_jax:Create new Program subclass: error
.DEBUG:iree_jax:Create new Program subclass: error
.DEBUG:iree_jax:Create new Program subclass: error
.DEBUG:iree_jax:Create new Program subclass: error
.DEBUG:iree_jax:Create new Program subclass: global
.DEBUG:iree_jax:Create new Program subclass: my_subclass
./home/woongq/jax/lib/python3.10/site-packages/jax/_src/pjit.py:288: DeprecationWarning: backend and device argument on jit is deprecated. You can use a `jax.sharding.Mesh` context manager or device_put the arguments before passing them to `jit`. Please see https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html for more information.
warnings.warn(
DEBUG:iree_jax:Create new Program subclass: iree_jax
DEBUG:root:DEFINE PY_ONLY: _f = <Exportable Pure Func: <function ProgramApiTest.test_value_tracing_with_flax_frozen_dict.<locals>.IreeJaxProgram._f at 0x7f673b4e7760>>
DEBUG:jax._src.xla_bridge:Initializing backend 'cpu'
DEBUG:jax._src.xla_bridge:Backend 'cpu' initialized
DEBUG:jax._src.xla_bridge:Initializing backend 'cuda'
INFO:jax._src.xla_bridge:Unable to initialize backend 'cuda': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
DEBUG:jax._src.xla_bridge:Initializing backend 'rocm'
INFO:jax._src.xla_bridge:Unable to initialize backend 'rocm': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
DEBUG:jax._src.xla_bridge:Initializing backend 'tpu'
INFO:jax._src.xla_bridge:Unable to initialize backend 'tpu': module 'jaxlib.xla_extension' has no attribute 'get_tpu_client'
WARNING:jax._src.xla_bridge:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
EDEBUG:iree_jax:Create new Program subclass: iree_jax
DEBUG:root:DEFINE PY_ONLY: _f = <Exportable Pure Func: <function ProgramApiTest.test_value_tracing_with_list.<locals>.IreeJaxProgram._f at 0x7f673b5384c0>>
E
======================================================================
ERROR: test_value_tracing_with_flax_frozen_dict (__main__.ProgramApiTest)
----------------------------------------------------------------------
Traceback (most recent call last):
File "/home/woongq/iree-jax/tests/program/program_api_test.py", line 163, in <module>
unittest.main()
File "/usr/lib/python3.10/unittest/main.py", line 101, in __init__
self.runTests()
File "/usr/lib/python3.10/unittest/main.py", line 271, in runTests
self.result = testRunner.run(self.test)
File "/usr/lib/python3.10/unittest/runner.py", line 184, in run
test(result)
File "/usr/lib/python3.10/unittest/suite.py", line 84, in __call__
return self.run(*args, **kwds)
File "/usr/lib/python3.10/unittest/suite.py", line 122, in run
test(result)
File "/usr/lib/python3.10/unittest/suite.py", line 84, in __call__
return self.run(*args, **kwds)
File "/usr/lib/python3.10/unittest/suite.py", line 122, in run
test(result)
File "/usr/lib/python3.10/unittest/case.py", line 650, in __call__
return self.run(*args, **kwds)
File "/usr/lib/python3.10/unittest/case.py", line 591, in run
self._callTestMethod(testMethod)
File "/usr/lib/python3.10/unittest/case.py", line 549, in _callTestMethod
method()
File "/home/woongq/iree-jax/tests/program/program_api_test.py", line 145, in test_value_tracing_with_flax_frozen_dict
IreeJaxProgram()
File "/home/woongq/iree-jax/iree/jax/program_api.py", line 459, in __new__
export_function()
File "/home/woongq/iree-jax/iree/jax/program_api.py", line 454, in export_function
info.export_module.def_func(invoke_with_self,
File "/home/woongq/iree-jax/iree/jax/exporter.py", line 206, in def_func
return_py_value = f(*argument_py_tree)
File "/home/woongq/iree-jax/iree/jax/program_api.py", line 452, in invoke_with_self
return func_def.callable(self, *args, **kwargs)
File "/home/woongq/iree-jax/tests/program/program_api_test.py", line 139, in f
return self._f(x)
File "/home/woongq/iree-jax/iree/jax/tracing.py", line 54, in __call__
return current_ir_trace().handle_call(self, args, kwargs)
File "/home/woongq/iree-jax/iree/jax/tracing.py", line 114, in handle_call
return target.resolve_call(self, *args, **kwargs)
File "/home/woongq/iree-jax/iree/jax/builtins.py", line 60, in resolve_call
lowered = self.jit_f.lower(*abstract_args)
File "/home/woongq/jax/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 166, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/home/woongq/jax/lib/python3.10/site-packages/jax/_src/pjit.py", line 338, in lower
donate_argnums) = infer_params_fn(*args, **kwargs)
File "/home/woongq/jax/lib/python3.10/site-packages/jax/_src/api.py", line 324, in infer_params
return pjit.common_infer_params(pjit_info_args, *args, **kwargs)
File "/home/woongq/jax/lib/python3.10/site-packages/jax/_src/pjit.py", line 452, in common_infer_params
in_shardings = out_shardings = _create_sharding_with_device_backend(
File "/home/woongq/jax/lib/python3.10/site-packages/jax/_src/pjit.py", line 826, in _create_sharding_with_device_backend
xb.get_backend(backend).get_default_device_assignment(1)[0])
File "/home/woongq/jax/lib/python3.10/site-packages/jax/_src/xla_bridge.py", line 710, in get_backend
return _get_backend_uncached(platform)
File "/home/woongq/jax/lib/python3.10/site-packages/jax/_src/xla_bridge.py", line 699, in _get_backend_uncached
raise RuntimeError(f"Unknown backend {platform}")
jax._src.traceback_util.UnfilteredStackTrace: RuntimeError: Unknown backend iree
The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.
--------------------
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/home/woongq/iree-jax/tests/program/program_api_test.py", line 145, in test_value_tracing_with_flax_frozen_dict
IreeJaxProgram()
File "/home/woongq/iree-jax/iree/jax/program_api.py", line 459, in __new__
export_function()
File "/home/woongq/iree-jax/iree/jax/program_api.py", line 454, in export_function
info.export_module.def_func(invoke_with_self,
File "/home/woongq/iree-jax/iree/jax/exporter.py", line 206, in def_func
return_py_value = f(*argument_py_tree)
File "/home/woongq/iree-jax/iree/jax/program_api.py", line 452, in invoke_with_self
return func_def.callable(self, *args, **kwargs)
File "/home/woongq/iree-jax/tests/program/program_api_test.py", line 139, in f
return self._f(x)
File "/home/woongq/iree-jax/iree/jax/tracing.py", line 54, in __call__
return current_ir_trace().handle_call(self, args, kwargs)
File "/home/woongq/iree-jax/iree/jax/tracing.py", line 114, in handle_call
return target.resolve_call(self, *args, **kwargs)
File "/home/woongq/iree-jax/iree/jax/builtins.py", line 60, in resolve_call
lowered = self.jit_f.lower(*abstract_args)
RuntimeError: Unknown backend iree
======================================================================
ERROR: test_value_tracing_with_list (__main__.ProgramApiTest)
----------------------------------------------------------------------
Traceback (most recent call last):
File "/home/woongq/iree-jax/tests/program/program_api_test.py", line 163, in <module>
unittest.main()
File "/usr/lib/python3.10/unittest/main.py", line 101, in __init__
self.runTests()
File "/usr/lib/python3.10/unittest/main.py", line 271, in runTests
self.result = testRunner.run(self.test)
File "/usr/lib/python3.10/unittest/runner.py", line 184, in run
test(result)
File "/usr/lib/python3.10/unittest/suite.py", line 84, in __call__
return self.run(*args, **kwds)
File "/usr/lib/python3.10/unittest/suite.py", line 122, in run
test(result)
File "/usr/lib/python3.10/unittest/suite.py", line 84, in __call__
return self.run(*args, **kwds)
File "/usr/lib/python3.10/unittest/suite.py", line 122, in run
test(result)
File "/usr/lib/python3.10/unittest/case.py", line 650, in __call__
return self.run(*args, **kwds)
File "/usr/lib/python3.10/unittest/case.py", line 591, in run
self._callTestMethod(testMethod)
File "/usr/lib/python3.10/unittest/case.py", line 549, in _callTestMethod
method()
File "/home/woongq/iree-jax/tests/program/program_api_test.py", line 159, in test_value_tracing_with_list
IreeJaxProgram()
File "/home/woongq/iree-jax/iree/jax/program_api.py", line 459, in __new__
export_function()
File "/home/woongq/iree-jax/iree/jax/program_api.py", line 454, in export_function
info.export_module.def_func(invoke_with_self,
File "/home/woongq/iree-jax/iree/jax/exporter.py", line 206, in def_func
return_py_value = f(*argument_py_tree)
File "/home/woongq/iree-jax/iree/jax/program_api.py", line 452, in invoke_with_self
return func_def.callable(self, *args, **kwargs)
File "/home/woongq/iree-jax/tests/program/program_api_test.py", line 153, in f
return self._f(x)
File "/home/woongq/iree-jax/iree/jax/tracing.py", line 54, in __call__
return current_ir_trace().handle_call(self, args, kwargs)
File "/home/woongq/iree-jax/iree/jax/tracing.py", line 114, in handle_call
return target.resolve_call(self, *args, **kwargs)
File "/home/woongq/iree-jax/iree/jax/builtins.py", line 60, in resolve_call
lowered = self.jit_f.lower(*abstract_args)
File "/home/woongq/jax/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 166, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/home/woongq/jax/lib/python3.10/site-packages/jax/_src/pjit.py", line 338, in lower
donate_argnums) = infer_params_fn(*args, **kwargs)
File "/home/woongq/jax/lib/python3.10/site-packages/jax/_src/api.py", line 324, in infer_params
return pjit.common_infer_params(pjit_info_args, *args, **kwargs)
File "/home/woongq/jax/lib/python3.10/site-packages/jax/_src/pjit.py", line 452, in common_infer_params
in_shardings = out_shardings = _create_sharding_with_device_backend(
File "/home/woongq/jax/lib/python3.10/site-packages/jax/_src/pjit.py", line 826, in _create_sharding_with_device_backend
xb.get_backend(backend).get_default_device_assignment(1)[0])
File "/home/woongq/jax/lib/python3.10/site-packages/jax/_src/xla_bridge.py", line 710, in get_backend
return _get_backend_uncached(platform)
File "/home/woongq/jax/lib/python3.10/site-packages/jax/_src/xla_bridge.py", line 699, in _get_backend_uncached
raise RuntimeError(f"Unknown backend {platform}")
jax._src.traceback_util.UnfilteredStackTrace: RuntimeError: Unknown backend iree
The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.
--------------------
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/home/woongq/iree-jax/tests/program/program_api_test.py", line 159, in test_value_tracing_with_list
IreeJaxProgram()
File "/home/woongq/iree-jax/iree/jax/program_api.py", line 459, in __new__
export_function()
File "/home/woongq/iree-jax/iree/jax/program_api.py", line 454, in export_function
info.export_module.def_func(invoke_with_self,
File "/home/woongq/iree-jax/iree/jax/exporter.py", line 206, in def_func
return_py_value = f(*argument_py_tree)
File "/home/woongq/iree-jax/iree/jax/program_api.py", line 452, in invoke_with_self
return func_def.callable(self, *args, **kwargs)
File "/home/woongq/iree-jax/tests/program/program_api_test.py", line 153, in f
return self._f(x)
File "/home/woongq/iree-jax/iree/jax/tracing.py", line 54, in __call__
return current_ir_trace().handle_call(self, args, kwargs)
File "/home/woongq/iree-jax/iree/jax/tracing.py", line 114, in handle_call
return target.resolve_call(self, *args, **kwargs)
File "/home/woongq/iree-jax/iree/jax/builtins.py", line 60, in resolve_call
lowered = self.jit_f.lower(*abstract_args)
RuntimeError: Unknown backend iree
----------------------------------------------------------------------
Ran 12 tests in 0.035s
FAILED (errors=2)
error: command failed with exit status: 1
--
********************
********************
Failed Tests (4):
IREE_JAX :: program/duplicate_helper.py
IREE_JAX :: program/fft.py
IREE_JAX :: program/program_api_test.py
IREE_JAX :: program/trivial_kernel.py
Testing Time: 0.73s
Passed: 1
Failed: 4
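The log above is long, but it reduces to three facts: which `lit` tests passed or failed, which JAX backends failed to initialize (here `cuda`, `rocm` and `tpu`, all with missing-attribute errors from `jaxlib.xla_extension`), and how many failures end in `RuntimeError: Unknown backend iree`. A minimal, standard-library-only sketch for pulling those out of a saved log follows. The function and pattern names are hypothetical, and the regular expressions are written against the line formats visible in this particular log, so other `lit` or JAX versions may need adjusted patterns.

```python
import re
from collections import Counter
from typing import Dict, Tuple

# lit result lines, e.g. "FAIL: IREE_JAX :: program/fft.py (2 of 5)".
LIT_RESULT = re.compile(
    r"^(PASS|FAIL|XFAIL|XPASS|UNSUPPORTED): (\S+) :: (\S+) \(\d+ of \d+\)$"
)
# jax._src.xla_bridge lines reporting a backend that could not be initialized.
BACKEND_FAIL = re.compile(r"Unable to initialize backend '([^']+)': (.*)$")
# Final line of each failing traceback (the UnfilteredStackTrace line does not
# start with "RuntimeError", so each failure is counted once).
UNKNOWN_BACKEND = re.compile(r"^RuntimeError: Unknown backend (\S+)$")


def summarize_lit_log(text: str) -> Tuple[Dict[str, str], Dict[str, str], Counter]:
    """Return (test -> status, backend -> init error, unknown-backend counts)."""
    results: Dict[str, str] = {}
    backend_errors: Dict[str, str] = {}
    unknown: Counter = Counter()
    for raw in text.splitlines():
        line = raw.strip()
        m = LIT_RESULT.match(line)
        if m:
            results[m.group(3)] = m.group(1)
            continue
        m = BACKEND_FAIL.search(line)
        if m:
            backend_errors[m.group(1)] = m.group(2)
            continue
        m = UNKNOWN_BACKEND.match(line)
        if m:
            unknown[m.group(1)] += 1
    return results, backend_errors, unknown
```

Run over the full log above, this should report one PASS and four FAILs, initialization errors for `cuda`, `rocm` and `tpu`, and five `Unknown backend iree` errors (one per failing program, two from `program_api_test.py`).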
https://github.com/openxla/openxla-pjrt-plugin is the right way to use JAX+IREE.
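With a PJRT plugin, the platform is normally chosen through JAX's `JAX_PLATFORMS` environment variable rather than the `backend=` argument to `jit` (which the log above already flags as deprecated). A minimal sketch of running a script in a child process pinned to one platform is shown below. The platform name `iree_cpu` and the script name `my_model.py` are assumptions for illustration only; check the plugin's own README for the platform names it actually registers.

```python
import os
import subprocess
import sys
from typing import Dict, Mapping, Optional


def env_for_platform(platform: str,
                     base: Optional[Mapping[str, str]] = None) -> Dict[str, str]:
    """Copy `base` (default: the current environment) with JAX_PLATFORMS set.

    The copy is passed to a child process, so the parent's environment and any
    JAX state already initialized in it are left untouched.
    """
    env = dict(os.environ if base is None else base)
    env["JAX_PLATFORMS"] = platform
    return env


def run_script_on_platform(script: str, platform: str) -> subprocess.CompletedProcess:
    """Run `script` with the current interpreter, restricted to `platform`."""
    return subprocess.run([sys.executable, script],
                          env=env_for_platform(platform), check=False)
```

For example, `run_script_on_platform("my_model.py", "iree_cpu")` (both names hypothetical) would fail fast at backend initialization if the plugin is not installed, instead of silently falling back to the CPU backend.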
The PJRT plugin is one way to use JAX+IREE, mostly for JIT scenarios from Python. This repository is another way, with a focus on AOT scenarios outside of Python. See https://openxla.github.io/iree/guides/ml-frameworks/jax/
> Did I miss something during the setup process?
Possibly. Compare your setup with what https://github.com/iree-org/iree-jax/blob/main/.github/workflows/test_gpt2_model.yaml does; that workflow runs nightly, and its results are at https://github.com/iree-org/iree-jax/actions.
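Comparing a local environment against the nightly workflow is easier with exact package versions in hand, especially since jaxlib was built from source here. A minimal sketch using the standard library's `importlib.metadata` is below. The list of distribution names (including `iree-jax` for the editable install) is an assumption and may need adjusting for your environment.

```python
from importlib import metadata
from typing import Dict, Iterable, Optional

# Distribution names assumed relevant to this setup; adjust as needed.
DEFAULT_PACKAGES = ("jax", "jaxlib", "iree-compiler", "iree-runtime", "iree-jax")


def installed_versions(
    packages: Iterable[str] = DEFAULT_PACKAGES,
) -> Dict[str, Optional[str]]:
    """Map each distribution name to its installed version, or None if absent."""
    versions: Dict[str, Optional[str]] = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


if __name__ == "__main__":
    for name, version in installed_versions().items():
        print(f"{name}: {version or 'not installed'}")
```

The printed versions can then be checked by hand against the ones the workflow installs; a jax/jaxlib pair that does not match what CI uses is a reasonable first thing to rule out.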