Torch-mlir input conversion pipeline incompatible with `--iree-execution-model=inline-dynamic`
zero9178 opened this issue · comments
Given a simple example such as:
class CompiledSimple(aot.CompiledModule):
def forward(self, x=aot.AbstractTensor(1)):
return aot.jittable(lambda x: x + x)(x)
exported = aot.export(CompiledSimple)
exported.print_readable()
imported using iree-turbine, yields the following IR:
module @compiled_simple {
func.func @forward(%arg0: tensor<1xf32>) -> tensor<1xf32> attributes {torch.args_schema = "[1, {\22type\22: \22builtins.tuple\22, \22context\22: \22null\22, \22children_spec\22: [{\22type\22: \22builtins.list\22, \22context\22: \22null\22, \22children_spec\22: [{\22type\22: null, \22context\22: null, \22children_spec\22: []}]}, {\22type\22: \22builtins.dict\22, \22context\22: \22[]\22, \22children_spec\22: []}]}]", torch.return_schema = "[1, {\22type\22: null, \22context\22: null, \22children_spec\22: []}]"} {
%0 = torch_c.from_builtin_tensor %arg0 : tensor<1xf32> -> !torch.vtensor<[1],f32>
%1 = call @"<lambda>"(%0) : (!torch.vtensor<[1],f32>) -> !torch.vtensor<[1],f32>
%2 = torch_c.to_builtin_tensor %1 : !torch.vtensor<[1],f32> -> tensor<1xf32>
return %2 : tensor<1xf32>
}
func.func private @"<lambda>"(%arg0: !torch.vtensor<[1],f32>) -> !torch.vtensor<[1],f32> {
%int1 = torch.constant.int 1
%0 = torch.aten.add.Tensor %arg0, %arg0, %int1 : !torch.vtensor<[1],f32>, !torch.vtensor<[1],f32>, !torch.int -> !torch.vtensor<[1],f32>
return %0 : !torch.vtensor<[1],f32>
}
}
This compiles perfectly fine with e.g. iree-compile test.mlir --iree-hal-target-backends=llvm-cpu -o /dev/null
but fails when trying to use hal_inline
via iree-compile test.mlir --iree-hal-target-backends=llvm-cpu -o /dev/null --iree-execution-model=inline-dynamic
with the following error:
test.mlir:2:3: error: failed to legalize operation 'hal.devices.get' that was explicitly marked illegal
func.func @forward(%arg0: tensor<1xf32>) -> tensor<1xf32> attributes {torch.args_schema = "[1, {\22type\22: \22builtins.tuple\22, \22context\22: \22null\22, \22children_spec\22: [{\22type\22: \22builtins.list\22, \22context\22: \22null\22, \22children_spec\22: [{\22type\22: null, \22context\22: null, \22children_spec\22: []}]}, {\22type\22: \22builtins.dict\22, \22context\22: \22[]\22, \22children_spec\22: []}]}]", torch.return_schema = "[1, {\22type\22: null, \22context\22: null, \22children_spec\22: []}]"} {
^
test.mlir:2:3: note: see current operation: %2 = "hal.devices.get"(%1) : (index) -> !hal.device
test.mlir:1:1: error: conversion to the hal_inline + hal_loader dialects failed
module @compiled_simple {
^
test.mlir:1:1: note: see current operation:
"builtin.module"() <{sym_name = "compiled_simple"}> ({
"hal.executable"() ({
"hal.executable.variant"() ({
"hal.executable.export"() ({
^bb0(%arg0: !hal.device):
%0 = "arith.constant"() <{value = 1 : index}> : () -> index
"hal.return"(%0, %0, %0) : (index, index, index) -> ()
}) {hal.interface.bindings = [#hal.interface.binding<0, 0>, #hal.interface.binding<0, 1>], layout = #hal.pipeline.layout<push_constants = 0, sets = [<0, bindings = [<0, storage_buffer, ReadOnly>, <1, storage_buffer>]>]>, ordinal = 0 : index, sym_name = "forward$async_dispatch_0_generic", translation_info = #iree_codegen.translation_info<CPUDefault>} : () -> ()
"builtin.module"() ({
"llvm.func"() <{CConv = #llvm.cconv<ccc>, arg_attrs = [{llvm.align = 16 : i64, llvm.noalias}, {llvm.align = 16 : i64, llvm.noalias}, {llvm.align = 16 : i64, llvm.noalias}], function_type = !llvm.func<i32 (ptr, ptr, ptr)>, linkage = #llvm.linkage<external>, sym_name = "forward$async_dispatch_0_generic", visibility_ = 0 : i64}> ({
^bb0(%arg0: !llvm.ptr, %arg1: !llvm.ptr, %arg2: !llvm.ptr):
%0 = "llvm.mlir.constant"() <{value = 0 : i32}> : () -> i32
%1 = "llvm.mlir.constant"() <{value = 63 : index}> : () -> i64
%2 = "llvm.mlir.constant"() <{value = 0 : index}> : () -> i64
%3 = "llvm.load"(%arg1) <{ordering = 0 : i64}> : (!llvm.ptr) -> !llvm.struct<"iree_hal_executable_dispatch_state_v0_t", (i32, i32, i16, i16, i32, i32, i16, i8, i8, ptr, ptr, ptr)>
%4 = "llvm.extractvalue"(%3) <{position = array<i64: 10>}> : (!llvm.struct<"iree_hal_executable_dispatch_state_v0_t", (i32, i32, i16, i16, i32, i32, i16, i8, i8, ptr, ptr, ptr)>) -> !llvm.ptr
%5 = "llvm.load"(%4) <{ordering = 0 : i64}> : (!llvm.ptr) -> !llvm.ptr
%6 = "llvm.ptrtoint"(%5) : (!llvm.ptr) -> i64
%7 = "llvm.and"(%6, %1) : (i64, i64) -> i64
%8 = "llvm.icmp"(%7, %2) <{predicate = 0 : i64}> : (i64, i64) -> i1
"llvm.intr.assume"(%8) : (i1) -> ()
%9 = "llvm.load"(%arg1) <{ordering = 0 : i64}> : (!llvm.ptr) -> !llvm.struct<"iree_hal_executable_dispatch_state_v0_t", (i32, i32, i16, i16, i32, i32, i16, i8, i8, ptr, ptr, ptr)>
%10 = "llvm.extractvalue"(%9) <{position = array<i64: 10>}> : (!llvm.struct<"iree_hal_executable_dispatch_state_v0_t", (i32, i32, i16, i16, i32, i32, i16, i8, i8, ptr, ptr, ptr)>) -> !llvm.ptr
%11 = "llvm.getelementptr"(%10) <{elem_type = !llvm.ptr, rawConstantIndices = array<i32: 1>}> : (!llvm.ptr) -> !llvm.ptr
%12 = "llvm.load"(%11) <{ordering = 0 : i64}> : (!llvm.ptr) -> !llvm.ptr
%13 = "llvm.ptrtoint"(%12) : (!llvm.ptr) -> i64
%14 = "llvm.and"(%13, %1) : (i64, i64) -> i64
%15 = "llvm.icmp"(%14, %2) <{predicate = 0 : i64}> : (i64, i64) -> i1
"llvm.intr.assume"(%15) : (i1) -> ()
%16 = "llvm.load"(%5) <{ordering = 0 : i64}> : (!llvm.ptr) -> f32
%17 = "llvm.fadd"(%16, %16) <{fastmathFlags = #llvm.fastmath<contract>}> : (f32, f32) -> f32
"llvm.store"(%17, %12) <{ordering = 0 : i64}> : (f32, !llvm.ptr) -> ()
"llvm.return"(%0) : (i32) -> ()
}) : () -> ()
}) {llvm.data_layout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-i128:128-f80:128-n8:16:32:64-S128", llvm.target_triple = "x86_64-unknown-unknown-eabi-elf"} : () -> ()
"hal.executable.variant_end"() : () -> ()
}) {sym_name = "embedded_elf_x86_64", target = #hal.executable.target<"llvm-cpu", "embedded-elf-x86_64", {cpu = "generic", cpu_features = "", data_layout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-i128:128-f80:128-n8:16:32:64-S128", native_vector_size = 16 : i64, target_triple = "x86_64-unknown-unknown-eabi-elf"}>} : () -> ()
"hal.executable_end"() : () -> ()
}) {sym_name = "forward$async_dispatch_0", sym_visibility = "private"} : () -> ()
"util.func"() <{function_type = (!hal.buffer_view, !hal.fence, !hal.fence) -> !hal.buffer_view, inlining_policy = #util.inline.never, sym_name = "forward$async"}> ({
^bb0(%arg0: !hal.buffer_view, %arg1: !hal.fence, %arg2: !hal.fence):
%0 = "arith.constant"() <{value = 4 : index}> : () -> index
%1 = "arith.constant"() <{value = 0 : index}> : () -> index
%2 = "arith.constant"() <{value = 1 : index}> : () -> index
%3 = "hal.element_type"() {type = f32} : () -> i32
%4 = "hal.encoding_type"() : () -> i32
"hal.buffer_view.assert"(%arg0, %3, %4, %2) {message = "tensor"} : (!hal.buffer_view, i32, i32, index) -> ()
%5 = "stream.tensor.import"(%arg0, %0) <{result_encoding = tensor<1xf32>}> : (!hal.buffer_view, index) -> !stream.resource<external>
%6 = "stream.timepoint.import"(%arg1) : (!hal.fence) -> !stream.timepoint
%7:2 = "stream.resource.alloca"(%0, %6) : (index, !stream.timepoint) -> (!stream.resource<external>, !stream.timepoint)
%8 = "stream.cmd.execute"(%5, %7#0, %0, %0, %7#1) <{operandSegmentSizes = array<i32: 2, 2, 1>}> ({
^bb0(%arg3: !stream.resource<external>, %arg4: !stream.resource<external>):
"stream.cmd.dispatch"(%arg3, %arg4, %0, %0, %1, %1, %0, %0) <{entry_points = [@forward$async_dispatch_0::@embedded_elf_x86_64::@forward$async_dispatch_0_generic], operandSegmentSizes = array<i32: 0, 0, 2, 2, 2, 2>, resource_accesses = [1 : i32, 2 : i32]}> : (!stream.resource<external>, !stream.resource<external>, index, index, index, index, index, index) -> ()
"stream.yield"() : () -> ()
}) : (!stream.resource<external>, !stream.resource<external>, index, index, !stream.timepoint) -> !stream.timepoint
"stream.timepoint.chain_external"(%8, %arg2) : (!stream.timepoint, !hal.fence) -> ()
%9 = "stream.tensor.export"(%7#0, %0) <{source_encoding = tensor<1xf32>}> : (!stream.resource<external>, index) -> !hal.buffer_view
"util.return"(%9) : (!hal.buffer_view) -> ()
}) {iree.abi.model = "coarse-fences", iree.abi.stub} : () -> ()
"util.func"() <{function_type = (!hal.buffer_view) -> !hal.buffer_view, sym_name = "forward"}> ({
^bb0(%arg0: !hal.buffer_view):
%0 = "arith.constant"() <{value = -1 : i32}> : () -> i32
%1 = "arith.constant"() <{value = 0 : index}> : () -> index
%2 = "hal.devices.get"(%1) : (index) -> !hal.device
%3 = "util.null"() : () -> !hal.fence
%4 = "hal.fence.create"(%2) {flags = 0 : i32} : (!hal.device) -> !hal.fence
%5 = "util.call"(%arg0, %3, %4) <{callee = @forward$async}> : (!hal.buffer_view, !hal.fence, !hal.fence) -> !hal.buffer_view
%6 = "hal.fence.await"(%0, %4) : (i32, !hal.fence) -> i32
"util.return"(%5) : (!hal.buffer_view) -> ()
}) {iree.abi.stub} : () -> ()
}) {hal.device.targets = [#hal.device.target<"llvm-cpu", [#hal.executable.target<"llvm-cpu", "embedded-elf-x86_64", {cpu = "generic", cpu_features = "", data_layout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-i128:128-f80:128-n8:16:32:64-S128", native_vector_size = 16 : i64, target_triple = "x86_64-unknown-unknown-eabi-elf"}>]>]} : () -> ()
As far as I can tell, the issue stems from the existence of the hal.fence.create
operation, which necessitates the hal.devices.get
operation, which the conversion pass to hal_inline
does not expect as input. Rather, it expects the stream dialect and friends, which it would then convert to a synchronous inline implementation that does not need the fences usually created by the StreamToHAL
conversion.
The reason the hal.fence.create
operation is in the IR is due to the torch-iree-func-conversion
pass:
iree/compiler/plugins/input/Torch/InputConversion/FuncConversion.cpp
Lines 472 to 478 in d2dd9e2
It creates a synchronous wrapper using manually created fences. This is seemingly fundamentally incompatible with the conversion pass to
hal_inline
unless the latter were to be made capable of removing fences (which does not seem trivial in general).

Yeah, I think the current FuncConversion is only designed to work with the full HAL. To use the inline HAL we'll need a mode for the conversion to switch to only generating the synchronous function without any of the full HAL types (fences/devices/etc.).