iree-org / iree

A retargetable MLIR-based machine learning compiler and runtime toolkit.

Home Page:http://iree.dev/

Geek Repo:Geek Repo

Github PK Tool:Github PK Tool

Torch-mlir input conversion pipeline incompatible with `--iree-execution-model=inline-dynamic`

zero9178 opened this issue · comments

Given a simple example such as:

class CompiledSimple(aot.CompiledModule):
    def forward(self, x=aot.AbstractTensor(1)):
        return aot.jittable(lambda x: x + x)(x)

exported = aot.export(CompiledSimple)
exported.print_readable()

Importing it with iree-turbine yields the following IR:

module @compiled_simple {
  func.func @forward(%arg0: tensor<1xf32>) -> tensor<1xf32> attributes {torch.args_schema = "[1, {\22type\22: \22builtins.tuple\22, \22context\22: \22null\22, \22children_spec\22: [{\22type\22: \22builtins.list\22, \22context\22: \22null\22, \22children_spec\22: [{\22type\22: null, \22context\22: null, \22children_spec\22: []}]}, {\22type\22: \22builtins.dict\22, \22context\22: \22[]\22, \22children_spec\22: []}]}]", torch.return_schema = "[1, {\22type\22: null, \22context\22: null, \22children_spec\22: []}]"} {
    %0 = torch_c.from_builtin_tensor %arg0 : tensor<1xf32> -> !torch.vtensor<[1],f32>
    %1 = call @"<lambda>"(%0) : (!torch.vtensor<[1],f32>) -> !torch.vtensor<[1],f32>
    %2 = torch_c.to_builtin_tensor %1 : !torch.vtensor<[1],f32> -> tensor<1xf32>
    return %2 : tensor<1xf32>
  }
  func.func private @"<lambda>"(%arg0: !torch.vtensor<[1],f32>) -> !torch.vtensor<[1],f32> {
    %int1 = torch.constant.int 1
    %0 = torch.aten.add.Tensor %arg0, %arg0, %int1 : !torch.vtensor<[1],f32>, !torch.vtensor<[1],f32>, !torch.int -> !torch.vtensor<[1],f32>
    return %0 : !torch.vtensor<[1],f32>
  }
}

This compiles without problems with, e.g., `iree-compile test.mlir --iree-hal-target-backends=llvm-cpu -o /dev/null`, but fails when using `hal_inline` via `iree-compile test.mlir --iree-hal-target-backends=llvm-cpu -o /dev/null --iree-execution-model=inline-dynamic`, with the following error:

test.mlir:2:3: error: failed to legalize operation 'hal.devices.get' that was explicitly marked illegal
  func.func @forward(%arg0: tensor<1xf32>) -> tensor<1xf32> attributes {torch.args_schema = "[1, {\22type\22: \22builtins.tuple\22, \22context\22: \22null\22, \22children_spec\22: [{\22type\22: \22builtins.list\22, \22context\22: \22null\22, \22children_spec\22: [{\22type\22: null, \22context\22: null, \22children_spec\22: []}]}, {\22type\22: \22builtins.dict\22, \22context\22: \22[]\22, \22children_spec\22: []}]}]", torch.return_schema = "[1, {\22type\22: null, \22context\22: null, \22children_spec\22: []}]"} {
  ^
test.mlir:2:3: note: see current operation: %2 = "hal.devices.get"(%1) : (index) -> !hal.device
test.mlir:1:1: error: conversion to the hal_inline + hal_loader dialects failed
module @compiled_simple {
^
test.mlir:1:1: note: see current operation: 
"builtin.module"() <{sym_name = "compiled_simple"}> ({
  "hal.executable"() ({
    "hal.executable.variant"() ({
      "hal.executable.export"() ({
      ^bb0(%arg0: !hal.device):
        %0 = "arith.constant"() <{value = 1 : index}> : () -> index
        "hal.return"(%0, %0, %0) : (index, index, index) -> ()
      }) {hal.interface.bindings = [#hal.interface.binding<0, 0>, #hal.interface.binding<0, 1>], layout = #hal.pipeline.layout<push_constants = 0, sets = [<0, bindings = [<0, storage_buffer, ReadOnly>, <1, storage_buffer>]>]>, ordinal = 0 : index, sym_name = "forward$async_dispatch_0_generic", translation_info = #iree_codegen.translation_info<CPUDefault>} : () -> ()
      "builtin.module"() ({
        "llvm.func"() <{CConv = #llvm.cconv<ccc>, arg_attrs = [{llvm.align = 16 : i64, llvm.noalias}, {llvm.align = 16 : i64, llvm.noalias}, {llvm.align = 16 : i64, llvm.noalias}], function_type = !llvm.func<i32 (ptr, ptr, ptr)>, linkage = #llvm.linkage<external>, sym_name = "forward$async_dispatch_0_generic", visibility_ = 0 : i64}> ({
        ^bb0(%arg0: !llvm.ptr, %arg1: !llvm.ptr, %arg2: !llvm.ptr):
          %0 = "llvm.mlir.constant"() <{value = 0 : i32}> : () -> i32
          %1 = "llvm.mlir.constant"() <{value = 63 : index}> : () -> i64
          %2 = "llvm.mlir.constant"() <{value = 0 : index}> : () -> i64
          %3 = "llvm.load"(%arg1) <{ordering = 0 : i64}> : (!llvm.ptr) -> !llvm.struct<"iree_hal_executable_dispatch_state_v0_t", (i32, i32, i16, i16, i32, i32, i16, i8, i8, ptr, ptr, ptr)>
          %4 = "llvm.extractvalue"(%3) <{position = array<i64: 10>}> : (!llvm.struct<"iree_hal_executable_dispatch_state_v0_t", (i32, i32, i16, i16, i32, i32, i16, i8, i8, ptr, ptr, ptr)>) -> !llvm.ptr
          %5 = "llvm.load"(%4) <{ordering = 0 : i64}> : (!llvm.ptr) -> !llvm.ptr
          %6 = "llvm.ptrtoint"(%5) : (!llvm.ptr) -> i64
          %7 = "llvm.and"(%6, %1) : (i64, i64) -> i64
          %8 = "llvm.icmp"(%7, %2) <{predicate = 0 : i64}> : (i64, i64) -> i1
          "llvm.intr.assume"(%8) : (i1) -> ()
          %9 = "llvm.load"(%arg1) <{ordering = 0 : i64}> : (!llvm.ptr) -> !llvm.struct<"iree_hal_executable_dispatch_state_v0_t", (i32, i32, i16, i16, i32, i32, i16, i8, i8, ptr, ptr, ptr)>
          %10 = "llvm.extractvalue"(%9) <{position = array<i64: 10>}> : (!llvm.struct<"iree_hal_executable_dispatch_state_v0_t", (i32, i32, i16, i16, i32, i32, i16, i8, i8, ptr, ptr, ptr)>) -> !llvm.ptr
          %11 = "llvm.getelementptr"(%10) <{elem_type = !llvm.ptr, rawConstantIndices = array<i32: 1>}> : (!llvm.ptr) -> !llvm.ptr
          %12 = "llvm.load"(%11) <{ordering = 0 : i64}> : (!llvm.ptr) -> !llvm.ptr
          %13 = "llvm.ptrtoint"(%12) : (!llvm.ptr) -> i64
          %14 = "llvm.and"(%13, %1) : (i64, i64) -> i64
          %15 = "llvm.icmp"(%14, %2) <{predicate = 0 : i64}> : (i64, i64) -> i1
          "llvm.intr.assume"(%15) : (i1) -> ()
          %16 = "llvm.load"(%5) <{ordering = 0 : i64}> : (!llvm.ptr) -> f32
          %17 = "llvm.fadd"(%16, %16) <{fastmathFlags = #llvm.fastmath<contract>}> : (f32, f32) -> f32
          "llvm.store"(%17, %12) <{ordering = 0 : i64}> : (f32, !llvm.ptr) -> ()
          "llvm.return"(%0) : (i32) -> ()
        }) : () -> ()
      }) {llvm.data_layout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-i128:128-f80:128-n8:16:32:64-S128", llvm.target_triple = "x86_64-unknown-unknown-eabi-elf"} : () -> ()
      "hal.executable.variant_end"() : () -> ()
    }) {sym_name = "embedded_elf_x86_64", target = #hal.executable.target<"llvm-cpu", "embedded-elf-x86_64", {cpu = "generic", cpu_features = "", data_layout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-i128:128-f80:128-n8:16:32:64-S128", native_vector_size = 16 : i64, target_triple = "x86_64-unknown-unknown-eabi-elf"}>} : () -> ()
    "hal.executable_end"() : () -> ()
  }) {sym_name = "forward$async_dispatch_0", sym_visibility = "private"} : () -> ()
  "util.func"() <{function_type = (!hal.buffer_view, !hal.fence, !hal.fence) -> !hal.buffer_view, inlining_policy = #util.inline.never, sym_name = "forward$async"}> ({
  ^bb0(%arg0: !hal.buffer_view, %arg1: !hal.fence, %arg2: !hal.fence):
    %0 = "arith.constant"() <{value = 4 : index}> : () -> index
    %1 = "arith.constant"() <{value = 0 : index}> : () -> index
    %2 = "arith.constant"() <{value = 1 : index}> : () -> index
    %3 = "hal.element_type"() {type = f32} : () -> i32
    %4 = "hal.encoding_type"() : () -> i32
    "hal.buffer_view.assert"(%arg0, %3, %4, %2) {message = "tensor"} : (!hal.buffer_view, i32, i32, index) -> ()
    %5 = "stream.tensor.import"(%arg0, %0) <{result_encoding = tensor<1xf32>}> : (!hal.buffer_view, index) -> !stream.resource<external>
    %6 = "stream.timepoint.import"(%arg1) : (!hal.fence) -> !stream.timepoint
    %7:2 = "stream.resource.alloca"(%0, %6) : (index, !stream.timepoint) -> (!stream.resource<external>, !stream.timepoint)
    %8 = "stream.cmd.execute"(%5, %7#0, %0, %0, %7#1) <{operandSegmentSizes = array<i32: 2, 2, 1>}> ({
    ^bb0(%arg3: !stream.resource<external>, %arg4: !stream.resource<external>):
      "stream.cmd.dispatch"(%arg3, %arg4, %0, %0, %1, %1, %0, %0) <{entry_points = [@forward$async_dispatch_0::@embedded_elf_x86_64::@forward$async_dispatch_0_generic], operandSegmentSizes = array<i32: 0, 0, 2, 2, 2, 2>, resource_accesses = [1 : i32, 2 : i32]}> : (!stream.resource<external>, !stream.resource<external>, index, index, index, index, index, index) -> ()
      "stream.yield"() : () -> ()
    }) : (!stream.resource<external>, !stream.resource<external>, index, index, !stream.timepoint) -> !stream.timepoint
    "stream.timepoint.chain_external"(%8, %arg2) : (!stream.timepoint, !hal.fence) -> ()
    %9 = "stream.tensor.export"(%7#0, %0) <{source_encoding = tensor<1xf32>}> : (!stream.resource<external>, index) -> !hal.buffer_view
    "util.return"(%9) : (!hal.buffer_view) -> ()
  }) {iree.abi.model = "coarse-fences", iree.abi.stub} : () -> ()
  "util.func"() <{function_type = (!hal.buffer_view) -> !hal.buffer_view, sym_name = "forward"}> ({
  ^bb0(%arg0: !hal.buffer_view):
    %0 = "arith.constant"() <{value = -1 : i32}> : () -> i32
    %1 = "arith.constant"() <{value = 0 : index}> : () -> index
    %2 = "hal.devices.get"(%1) : (index) -> !hal.device
    %3 = "util.null"() : () -> !hal.fence
    %4 = "hal.fence.create"(%2) {flags = 0 : i32} : (!hal.device) -> !hal.fence
    %5 = "util.call"(%arg0, %3, %4) <{callee = @forward$async}> : (!hal.buffer_view, !hal.fence, !hal.fence) -> !hal.buffer_view
    %6 = "hal.fence.await"(%0, %4) : (i32, !hal.fence) -> i32
    "util.return"(%5) : (!hal.buffer_view) -> ()
  }) {iree.abi.stub} : () -> ()
}) {hal.device.targets = [#hal.device.target<"llvm-cpu", [#hal.executable.target<"llvm-cpu", "embedded-elf-x86_64", {cpu = "generic", cpu_features = "", data_layout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-i128:128-f80:128-n8:16:32:64-S128", native_vector_size = 16 : i64, target_triple = "x86_64-unknown-unknown-eabi-elf"}>]>]} : () -> ()

As far as I can tell, the issue stems from the `hal.fence.create` operation. It requires a device, which is obtained through a `hal.devices.get` operation, and the conversion pass to `hal_inline` does not expect that operation as input. Instead, the pass expects the `stream` dialect and related ops, which it then converts into a synchronous inline implementation. That implementation does not need the fences that the StreamToHAL conversion would normally create.

The `hal.fence.create` operation is present in the IR because of the `torch-iree-func-conversion` pass:

Value timeoutMillis = rewriter.create<arith::ConstantIntOp>(loc, -1, 32);
Value device = IREE::HAL::DeviceType::resolveAny(loc, rewriter);
Value waitFence = rewriter.create<IREE::Util::NullOp>(
loc, rewriter.getType<IREE::HAL::FenceType>());
Value signalFence = rewriter.create<IREE::HAL::FenceCreateOp>(
loc, rewriter.getType<IREE::HAL::FenceType>(), device,
IREE::HAL::FenceFlagBitfield::None);

The pass creates a synchronous wrapper around the async function, using manually created fences: a null wait fence and a newly created signal fence. This appears to be fundamentally incompatible with the conversion pass to `hal_inline`, unless that pass could remove fences, which does not seem trivial in general.

Yeah, I think the current `FuncConversion` is only designed to work with the full HAL. To use the inline HAL, we'll need a mode in which the conversion generates only the synchronous function, without any of the full HAL types (fences, devices, etc.).