The minimalist Program-derived class
wangkuiyi opened this issue
Yi Wang commented
I am curious how I could derive a Python class from `iree.jax.Program` that compiles into the `simple_mul` example. I tried the following:
```python
from iree.jax import Program


class TrivialKernel(Program):
    def mul(self, xx, yy):
        return xx * yy


m = TrivialKernel()
# print(Program.get_mlir_module(m))
```
Unfortunately, running it gives the following error.
```
  File "/Users/y/w/iree-ios/iree-jax/iree/jax/program_api.py", line 224, in def_export_function
    raise TypeError(
TypeError: export function 'mul' missing default value annotation for parameter 'xx'
```
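The error implies that every parameter of an exported method (other than `self`) must carry a default value. The sketch below is a hypothetical re-creation of that kind of check in plain Python using `inspect.signature`; it is not the actual code in `program_api.py`, and `check_export_signature` is an invented name used only for illustration.

```python
import inspect


def check_export_signature(fn):
    """Hypothetical re-creation of the check behind the TypeError above.

    Every parameter after ``self`` must have a default value, which an
    exporter could use as a shape/dtype annotation.
    """
    params = list(inspect.signature(fn).parameters.values())[1:]  # skip self
    for p in params:
        if p.default is inspect.Parameter.empty:
            raise TypeError(
                f"export function {fn.__name__!r} missing default value "
                f"annotation for parameter {p.name!r}"
            )
    return [p.name for p in params]


# A plain class (NOT derived from Program) with the same method as the first attempt.
class TrivialKernel:
    def mul(self, xx, yy):
        return xx * yy


try:
    check_export_signature(TrivialKernel.mul)
except TypeError as err:
    print(err)  # export function 'mul' missing default value annotation for parameter 'xx'
```

Giving `xx` and `yy` defaults makes this hypothetical check pass, which is what the second attempt does with `Program.like(...)`.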
I guessed that the IREE compiler needs to know the shape and element type of `xx`, so I changed the program to the following:
```python
from iree.jax import Program
import jax.numpy as jnp

x = jnp.ones((4, 4), jnp.float32)
y = jnp.ones((4, 4), jnp.float32)


class TrivialKernel(Program):
    def mul(self, xx=Program.like(x), yy=Program.like(y)):
        return xx * yy


m = TrivialKernel()
# print(Program.get_mlir_module(m))
```
However, it gives the following error.
```
  File "/Users/y/w/iree-ios/jax-samples/2.py", line 9, in mul
    return xx * yy
  File "/Users/y/miniforge3/envs/iree-jax/lib/python3.10/site-packages/numpy/lib/mixins.py", line 21, in func
    return ufunc(self, other)
ValueError: object __array__ method not producing an array
```
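The traceback shows the `*` being dispatched through NumPy's `NDArrayOperatorsMixin`, which forwards it to the `np.multiply` ufunc; NumPy then asks each operand for a concrete array via `__array__`. A plausible reading (an assumption about iree.jax internals, not confirmed here) is that the `Program.like` placeholders only describe shape and dtype and cannot supply real data. This minimal sketch reproduces the same NumPy error with a hypothetical `Placeholder` class that is not the actual iree.jax type.

```python
import numpy as np
from numpy.lib.mixins import NDArrayOperatorsMixin


class Placeholder(NDArrayOperatorsMixin):
    """Hypothetical shape/dtype-only placeholder (NOT the iree.jax type)."""

    def __init__(self, shape, dtype):
        self.shape = tuple(shape)
        self.dtype = np.dtype(dtype)

    def __array__(self, dtype=None, copy=None):
        # There is no concrete data to hand over, so this returns a
        # non-array; NumPy rejects that with a ValueError.
        return None


xx = Placeholder((4, 4), np.float32)
yy = Placeholder((4, 4), np.float32)

try:
    xx * yy  # mixin __mul__ -> np.multiply(xx, yy) -> __array__()
except ValueError as err:
    print(err)
```

If this reading is right, the multiplication has to happen somewhere that traces symbolic values instead of asking NumPy for concrete arrays.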
Yi Wang commented
I successfully compiled the following file (`2.py`):
```python
from iree.jax import Program
import jax.numpy as jnp

x = jnp.ones((4, 4), jnp.float32)


class TrivialKernel(Program):
    def run(self, x=Program.like(x)):
        return self.mul(x)

    @Program.kernel
    def mul(x):
        return x * x


m = TrivialKernel()
print(Program.get_mlir_module(m))
```
with the following command:

```shell
python jax-samples/2.py | \
  ./build/compiler/install/bin/iree-compile \
  --iree-input-type=mhlo \
  --iree-hal-target-backends=vmvx \
  - > /tmp/a.vmfb
```
The script prints the following MLIR, which looks very close to `simple_mul.mlir`:
```mlir
module @trivial_kernel {
  func.func @run(%arg0: tensor<4x4xf32>) -> tensor<4x4xf32> {
    %0 = call @jit_mul$main(%arg0) : (tensor<4x4xf32>) -> tensor<4x4xf32>
    return %0 : tensor<4x4xf32>
  }
  func.func private @jit_mul$main(%arg0: tensor<4x4xf32> {mhlo.sharding = ""}) -> tensor<4x4xf32> {
    %0 = stablehlo.multiply %arg0, %arg0 : tensor<4x4xf32>
    return %0 : tensor<4x4xf32>
  }
}
```
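In MLIR's `func` dialect, a `func.func` without the `private` keyword is public, so in this module `run` is the exported entry point while the `@Program.kernel` function was lowered into the private helper `jit_mul$main`. The following is a quick text-based check over that module; it uses a regular expression rather than a real MLIR parser, and `list_functions` is an illustrative name, not part of IREE.

```python
import re

# The module printed above, pasted as a string.
MLIR_TEXT = """
module @trivial_kernel {
  func.func @run(%arg0: tensor<4x4xf32>) -> tensor<4x4xf32> {
    %0 = call @jit_mul$main(%arg0) : (tensor<4x4xf32>) -> tensor<4x4xf32>
    return %0 : tensor<4x4xf32>
  }
  func.func private @jit_mul$main(%arg0: tensor<4x4xf32> {mhlo.sharding = ""}) -> tensor<4x4xf32> {
    %0 = stablehlo.multiply %arg0, %arg0 : tensor<4x4xf32>
    return %0 : tensor<4x4xf32>
  }
}
"""

# Matches "func.func @name" and "func.func private @name" (but not "call @name").
FUNC_DECL = re.compile(r"func\.func\s+(private\s+)?@([\w$.]+)")


def list_functions(mlir_text):
    """Split function symbols into (public, private) by textual inspection."""
    public, private = [], []
    for match in FUNC_DECL.finditer(mlir_text):
        (private if match.group(1) else public).append(match.group(2))
    return public, private


print(list_functions(MLIR_TEXT))  # (['run'], ['jit_mul$main'])
```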
It seems that:

- the entry-point function (`run` in this example) cannot be decorated with `@Program.kernel`, and
- array operations (the elementwise `*` in this example) must be in a function decorated with `@Program.kernel`.