iree-org / iree-jax

The minimalist Program-derived class

wangkuiyi opened this issue

I would like to know how to derive a Python class from iree.jax.Program that compiles to the equivalent of the simple_mul example.

I tried the following.

from iree.jax import Program

class TrivialKernel(Program):
  def mul(self, xx, yy):
    return xx * yy

m = TrivialKernel()
# print(Program.get_mlir_module(m))

Unfortunately, running it gives the following error.

  File "/Users/y/w/iree-ios/iree-jax/iree/jax/program_api.py", line 224, in def_export_function
    raise TypeError(
TypeError: export function 'mul' missing default value annotation for parameter 'xx'
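The TypeError suggests that Program inspects the signature of each method it exports and requires every parameter after self to have a default value describing the input. The stdlib-only sketch below mimics that kind of check with inspect.signature; check_export_signature is a hypothetical helper written for illustration, not the actual code in program_api.py.

```python
import inspect


def check_export_signature(name, fn):
    """Hypothetical check: every parameter after `self` needs a default value."""
    params = list(inspect.signature(fn).parameters.values())[1:]  # skip `self`
    for p in params:
        if p.default is inspect.Parameter.empty:
            raise TypeError(
                f"export function {name!r} missing default value annotation "
                f"for parameter {p.name!r}"
            )


def mul(self, xx, yy):  # same signature as the first attempt
    return xx * yy


try:
    check_export_signature("mul", mul)
except TypeError as e:
    print(e)  # export function 'mul' missing default value annotation for parameter 'xx'
```

Giving each parameter a default value would make this sketched check pass, which is what the next attempt does.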

I guessed that the IREE compiler needs to know the shape and dtype of xx, so I changed the program to the following:

from iree.jax import Program
import jax.numpy as jnp

x = jnp.ones((4, 4), jnp.float32)
y = jnp.ones((4, 4), jnp.float32)

class TrivialKernel(Program):
  def mul(self, xx=Program.like(x), yy=Program.like(y)):
    return xx * yy

m = TrivialKernel()
# print(Program.get_mlir_module(m))

However, it gives the following error.

  File "/Users/y/w/iree-ios/jax-samples/2.py", line 9, in mul
    return xx * yy
  File "/Users/y/miniforge3/envs/iree-jax/lib/python3.10/site-packages/numpy/lib/mixins.py", line 21, in func
    return ufunc(self, other)
ValueError: object __array__ method not producing an array
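The traceback shows that xx * yy went through numpy's NDArrayOperatorsMixin, which forwards * to the np.multiply ufunc. Assuming, as that suggests, that the arguments of a non-kernel export method are placeholder objects that mix in this class without holding real data, numpy tries to convert them with __array__ and fails. SymbolicValue below is a hypothetical stand-in, not an iree-jax type; it reproduces the same ValueError with plain numpy.

```python
from numpy.lib.mixins import NDArrayOperatorsMixin


class SymbolicValue(NDArrayOperatorsMixin):
    """Hypothetical stand-in for a traced argument with no concrete data."""

    def __array__(self, dtype=None, copy=None):  # accepts numpy 2's `copy` keyword
        # A symbolic value has no contents, so it cannot return an ndarray.
        return None


xx, yy = SymbolicValue(), SymbolicValue()
try:
    xx * yy  # the mixin forwards this to np.multiply(xx, yy)
except ValueError as e:
    print(e)  # numpy complains that __array__ did not produce an array
```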

I successfully compiled the following file (2.py).

from iree.jax import Program
import jax.numpy as jnp

x = jnp.ones((4, 4), jnp.float32)

class TrivialKernel(Program):
  def run(self, x=Program.like(x)):
    return self.mul(x)
    
  @Program.kernel
  def mul(x):
    return x * x

m = TrivialKernel()
print(Program.get_mlir_module(m))

with the following command:

python jax-samples/2.py | \
   ./build/compiler/install/bin/iree-compile \
   --iree-input-type=mhlo \
   --iree-hal-target-backends=vmvx   \
   - > /tmp/a.vmfb

The script printed the following MLIR, which looks very close to simple_mul.mlir:

module @trivial_kernel {
  func.func @run(%arg0: tensor<4x4xf32>) -> tensor<4x4xf32> {
    %0 = call @jit_mul$main(%arg0) : (tensor<4x4xf32>) -> tensor<4x4xf32>
    return %0 : tensor<4x4xf32>
  }
  func.func private @jit_mul$main(%arg0: tensor<4x4xf32> {mhlo.sharding = ""}) -> tensor<4x4xf32> {
    %0 = stablehlo.multiply %arg0, %arg0 : tensor<4x4xf32>
    return %0 : tensor<4x4xf32>
  }
}

It seems that

  • the entry-point function (run in this example) cannot be decorated with @Program.kernel, and
  • array operations (the elementwise * in this example) must be placed in a function decorated with @Program.kernel.
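Both observations fit a design in which export methods only wire symbolic arguments together while the arithmetic is traced inside kernels. The toy tracer below is a hypothetical, stdlib-only model of that split, not iree-jax's implementation; Value, Tracer, kernel and KERNELS are invented names. It models only the second observation: the export function forwards its arguments to a kernel (mirroring the run → call @jit_mul$main structure of the MLIR above), using two operands like the original mul(xx, yy) attempt.

```python
class Value:
    """Hypothetical symbolic value; `tracer` is set only inside a kernel."""

    def __init__(self, name, tracer=None):
        self.name, self.tracer = name, tracer

    def __mul__(self, other):
        if self.tracer is None:
            raise ValueError(f"cannot apply '*' to {self.name} outside a kernel")
        return self.tracer.emit(f"multiply {self.name}, {other.name}")


class Tracer:
    """Collects the ops of one kernel body as text."""

    def __init__(self):
        self.ops = []

    def emit(self, text):
        result = Value(f"%{len(self.ops)}", self)
        self.ops.append(f"{result.name} = {text}")
        return result


KERNELS = {}  # kernel name -> traced op list


def kernel(fn):
    """Hypothetical decorator: traces the body and returns a call result."""

    def call(*args):
        tracer = Tracer()
        params = [Value(f"%arg{i}", tracer) for i in range(len(args))]
        result = fn(*params)
        KERNELS[fn.__name__] = tracer.ops + [f"return {result.name}"]
        arg_names = ", ".join(a.name for a in args)
        return Value(f"call @{fn.__name__}({arg_names})")

    return call


@kernel
def mul(a, b):
    return a * b  # traced, because a and b carry a Tracer


def run(x, y):  # export function: only forwards to the kernel
    return mul(x, y)


def bad_run(x, y):  # export function doing arithmetic directly
    return x * y


out = run(Value("%arg0"), Value("%arg1"))
print(out.name)        # call @mul(%arg0, %arg1)
print(KERNELS["mul"])  # ['%0 = multiply %arg0, %arg1', 'return %0']
```

In this toy, calling bad_run(Value("%arg0"), Value("%arg1")) raises ValueError, its counterpart of the numpy error seen with the second attempt.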