iree-org / iree

A retargetable MLIR-based machine learning compiler and runtime toolkit.

Home Page:http://iree.dev/

Geek Repo:Geek Repo

Github PK Tool:Github PK Tool

Compilation error for SHARK-TestSuite (onnx/models/RAFT_vaiq_int8)

IanWood1 opened this issue · comments

EDIT (also added to reproduction steps):
The problem occurs in the canonicalization run inside LLVMCPUVectorTransferLowering. It can be reproduced by compiling https://gist.github.com/IanWood1/59153bb58858c69b0569a6a6f39e3289 with:

What happened?

Compilation fails due to excessive stack allocations

SHARK-TestSuite/e2eshark/test-run/onnx/models/RAFT_vaiq_int8/RAFT_vaiq_int8.default.onnx.linalg.mlir:6453:13: error: 'func.func' op exceeded stack allocation limit of 32768 bytes for function. Got 401408 bytes
    %1024 = linalg.generic {indexing_maps = [#map20, #map2], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%extracted_slice_440 : tensor<1024x7x7x1xf32>) outs(%1020 : tensor<1024x7x7x1xi8>) {
            ^

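A long run of consecutive `vector.extract` and `vector.store` ops is generated, one pair per 2-element row of a `vector<1024x7x7x2xf32>`; see the verbose output below. That whole vector occupies 1024 × 7 × 7 × 2 × 4 bytes = 401408 bytes, which is exactly the stack size reported in the error above.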

Verbose output
    "vector.store"(%151559, %1030, %0, %1015, %1019, %1025) <{nontemporal = false}> : (vector<2xf32>, memref<1024x7x7x2xf32>, index, index, index, index) -> ()
    %151560 = "vector.extract"(%101387) <{static_position = array<i64: 1023, 6, 3>}> : (vector<1024x7x7x2xf32>) -> vector<2xf32>
    "vector.store"(%151560, %1030, %0, %1015, %1018, %1025) <{nontemporal = false}> : (vector<2xf32>, memref<1024x7x7x2xf32>, index, index, index, index) -> ()
    %151561 = "vector.extract"(%101387) <{static_position = array<i64: 1023, 6, 4>}> : (vector<1024x7x7x2xf32>) -> vector<2xf32>
    "vector.store"(%151561, %1030, %0, %1015, %1017, %1025) <{nontemporal = false}> : (vector<2xf32>, memref<1024x7x7x2xf32>, index, index, index, index) -> ()
    %151562 = "vector.extract"(%101387) <{static_position = array<i64: 1023, 6, 5>}> : (vector<1024x7x7x2xf32>) -> vector<2xf32>
    "vector.store"(%151562, %1030, %0, %1015, %1016, %1025) <{nontemporal = false}> : (vector<2xf32>, memref<1024x7x7x2xf32>, index, index, index, index) -> ()
    %151563 = "vector.extract"(%101387) <{static_position = array<i64: 1023, 6, 6>}> : (vector<1024x7x7x2xf32>) -> vector<2xf32>
    "vector.store"(%151563, %1030, %0, %1015, %1015, %1025) <{nontemporal = false}> : (vector<2xf32>, memref<1024x7x7x2xf32>, index, index, index, index) -> ()
    %151564 = "memref.subview"(%1030) <{operandSegmentSizes = array<i32: 1, 0, 0, 0>, static_offsets = array<i64: 0, 0, 0, 1>, static_sizes = array<i64: 1024, 7, 7, 1>, static_strides = array<i64: 1, 1, 1, 1>}> : (memref<1024x7x7x2xf32>) -> memref<1024x7x7xf32, strided<[98, 14, 2], offset: 1>>
    %151565 = "hal.interface.workgroup.id"() {dimension = 0 : index} : () -> index
    %151566 = "affine.apply"(%151565) <{map = affine_map<()[s0] -> (s0 * 128)>}> : (index) -> index
    %151567 = "memref.subview"(%1032, %151566) <{operandSegmentSizes = array<i32: 1, 1, 0, 0>, static_offsets = array<i64: -9223372036854775808, 0, 0>, static_sizes = array<i64: 128, 7, 7>, static_strides = array<i64: 1, 1, 1>}> : (memref<1024x7x7xi8>, index) -> memref<128x7x7xi8, strided<[49, 7, 1], offset: ?>>
    "scf.for"(%1025, %1026, %1027) ({
    ^bb0(%arg0: index):
      "scf.for"(%1025, %1028, %1027) ({
      ^bb0(%arg1: index):
        %151568 = "arith.addi"(%arg0, %151566) <{overflowFlags = #arith.overflow<none>}> : (index, index) -> index
        %151569 = "scf.for"(%1025, %1028, %1027, %1023) ({
        ^bb0(%arg2: index, %arg3: vector<7xf32>):
          %151578 = "memref.load"(%151564, %151568, %arg1, %arg2) <{nontemporal = false}> : (memref<1024x7x7xf32, strided<[98, 14, 2], offset: 1>>, index, index, index) -> f32
          %151579 = "vector.insertelement"(%151578, %arg3, %arg2) : (f32, vector<7xf32>, index) -> vector<7xf32>
          "scf.yield"(%151579) : (vector<7xf32>) -> ()
        }) : (index, index, index, vector<7xf32>) -> vector<7xf32>
        %151570 = "arith.divf"(%151569, %1024) <{fastmath = #arith.fastmath<none>}> : (vector<7xf32>, vector<7xf32>) -> vector<7xf32>
        %151571 = "math.round"(%151570) <{fastmath = #arith.fastmath<none>}> : (vector<7xf32>) -> vector<7xf32>
        %151572 = "arith.addf"(%151571, %1023) <{fastmath = #arith.fastmath<none>}> : (vector<7xf32>, vector<7xf32>) -> vector<7xf32>
        %151573 = "arith.cmpf"(%151572, %1022) <{fastmath = #arith.fastmath<none>, predicate = 11 : i64}> : (vector<7xf32>, vector<7xf32>) -> vector<7xi1>
        %151574 = "arith.cmpf"(%151572, %1021) <{fastmath = #arith.fastmath<none>, predicate = 9 : i64}> : (vector<7xf32>, vector<7xf32>) -> vector<7xi1>
        %151575 = "arith.select"(%151573, %1022, %151572) : (vector<7xi1>, vector<7xf32>, vector<7xf32>) -> vector<7xf32>
        %151576 = "arith.select"(%151574, %1021, %151575) : (vector<7xi1>, vector<7xf32>, vector<7xf32>) -> vector<7xf32>
        %151577 = "arith.fptosi"(%151576) : (vector<7xf32>) -> vector<7xi8>
        "vector.store"(%151577, %151567, %arg0, %arg1, %1025) <{nontemporal = false}> : (vector<7xi8>, memref<128x7x7xi8, strided<[49, 7, 1], offset: ?>>, index, index, index) -> ()
        "scf.yield"() : () -> ()
      }) : (index, index, index) -> ()
      "scf.yield"() : () -> ()
    }) : (index, index, index) -> ()
    "func.return"() : () -> ()
  }) {translation_info = #iree_codegen.translation_info<CPUDoubleTilingExpert>} : () -> ()
}) : () -> ()
"hal.executable.variant_end"() : () -> ()
}) {sym_name = "embedded_elf_x86_64", target = #hal.executable.target<"llvm-cpu", "embedded-elf-x86_64", {cpu = "znver3", cpu_features = "+prfchw,-cldemote,+avx,+aes,+sahf,+pclmul,-xop,+crc32,+xsaves,-avx512fp16,-usermsr,-sm4,-egpr,+sse4.1,-avx512ifma,+x
save,-avx512pf,+sse4.2,-tsxldtrk,-ptwrite,-widekl,-sm3,+invpcid,+64bit,+xsavec,-avx10.1-512,-avx512vpopcntdq,+cmov,-avx512vp2intersect,-avx512cd,+movbe,-avxvnniint8,-avx512er,-ccmp,-amx-int8,-kl,-avx10.1-256,-sha512,-avxvnni,-rtm,+adx,+avx2,-hreset,-movd
iri,-serialize,+vpclmulqdq,-avx512vl,-uintr,-cf,+clflushopt,-raoint,-cmpccxadd,+bmi,-amx-tile,+sse,-gfni,-avxvnniint16,-amx-fp16,-ndd,+xsaveopt,+rdrnd,-avx512f,-amx-bf16,-avx512bf16,-avx512vnni,-push2pop2,+cx8,-avx512bw,+sse3,-pku,+fsgsbase,+clzero,-mwai
tx,-lwp,+lzcnt,+sha,-movdir64b,-ppx,-wbnoinvd,-enqcmd,-prefetchwt1,-avxneconvert,-tbm,-pconfig,-amx-complex,+ssse3,+cx16,+bmi2,+fma,+popcnt,-avxifma,+f16c,-avx512bitalg,+rdpru,+clwb,+mmx,+sse2,+rdseed,-avx512vbmi2,-prefetchi,+rdpid,-fma4,-avx512vbmi,+shs
tk,+vaes,-waitpkg,-sgx,+fxsr,-avx512dq,+sse4a", data_layout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-i128:128-f80:128-n8:16:32:64-S128", native_vector_size = 32 : i64, target_triple = "x86_64-unknown-unknown-eabi-elf"}>} : () -> ()
  %1081 = linalg.generic {indexing_maps = [#map20, #map2], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%extracted_slice_458 : tensor<1024x7x7x1xf32>) outs(%1020 : tensor<1024x7x7x1xi8>) {
          ^

Steps to reproduce your issue

Note that this issue also occurs with some other models; for example, onnx/models/RAFT_vaiq_int8 in the SHARK-TestSuite hits a similar failure. To reproduce, set up the test suite and run:

python run.py --cachedir=/path/to/.cache/ -t onnx/models/RAFT_vaiq_int8/ -m onnx -c /path/to/torch-mlir/build/ -i /path/to/iree-build/ --torchtolinalg

with up-to-date builds of torch-mlir and IREE.

Originally posted by @zjgarvey in #17226 (comment)

Minimal repro

// Identity indexing map over a rank-4 iteration space.
#map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
// Reads index 0 of the last dimension. For the size-1 inputs it is used with
// below, this addresses the same element as the identity map.
#map20 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, 0)>
// Minimal repro for the llvm-cpu compilation failure described in this issue.
// The function:
//   1. converts the i8 input to f32 and scales it by 2.0,
//   2. splits the last dimension into two 1024x7x7x1 slices (offsets 0 and 1),
//   3. applies the same elementwise divide / round / add / clamp / fptosi-to-i8
//      sequence to each slice, and
//   4. concatenates the two i8 results back along dim 3.
// According to the issue text, compiling this with the llvm-cpu backend fails
// with "exceeded stack allocation limit"; the problem was traced to the
// canonicalization run inside LLVMCPUVectorTransferLowering.
func.func public @jit_eval_174(%1104 :  tensor<1024x7x7x2xi8>) -> tensor<1024x7x7x2xi8> {
    // Scalar constants used by the elementwise bodies below.
    %cst_9 = arith.constant 2.00 : f32
    %cst_4 = arith.constant 4.00 : f32
    %cst_0 = arith.constant 0.00 : f32
    %cst_1 = arith.constant 1.00 : f32

    // Elementwise i8 -> i32 (sign-extend) -> f32, then multiply by %cst_9 (2.0).
    %1015 = tensor.empty() : tensor<1024x7x7x2xf32>
    %1105 = linalg.generic {indexing_maps = [#map2, #map2], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%1104 : tensor<1024x7x7x2xi8>) outs(%1015 : tensor<1024x7x7x2xf32>) {
    ^bb0(%in: i8, %out: f32):
      %3555 = arith.extsi %in : i8 to i32
      %3556 = arith.sitofp %3555 : i32 to f32
      %3557 = arith.mulf %3556, %cst_9 : f32
      linalg.yield %3557 : f32
    } -> tensor<1024x7x7x2xf32>

    // NOTE(review): %cst_218 is never used in this function; presumably it was
    // left over when reducing the original model.
    %cst_218 = arith.constant dense<1.000000e+00> : tensor<f32>

    // %1020 is the (never-read) destination for both i8 generics below.
    // NOTE(review): %1022 is never used in this function.
    %1020 = tensor.empty() : tensor<1024x7x7x1xi8>
    %1022 = tensor.empty() : tensor<1024x7x7x1xf32>
    // Split the last dimension of %1105 into its two size-1 slices.
    %extracted_slice_466 = tensor.extract_slice %1105[0, 0, 0, 0] [1024, 7, 7, 1] [1, 1, 1, 1] : tensor<1024x7x7x2xf32> to tensor<1024x7x7x1xf32>
    %extracted_slice_467 = tensor.extract_slice %1105[0, 0, 0, 1] [1024, 7, 7, 1] [1, 1, 1, 1] : tensor<1024x7x7x2xf32> to tensor<1024x7x7x1xf32>

    // Per element of slice 0: divide by 2.0, round, add 4.0, clamp into
    // [0.0, 1.0] via cmpf/select, then convert to i8. The %out block argument
    // is not read by the body.
     %1106 = linalg.generic {indexing_maps = [#map20, #map2], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%extracted_slice_466 : tensor<1024x7x7x1xf32>) outs(%1020 : tensor<1024x7x7x1xi8>) {
    ^bb0(%in: f32, %out: i8):
      %3555 = arith.divf %in, %cst_9 : f32
      %3556 = math.round %3555 : f32
      %3557 = arith.addf %3556, %cst_4 : f32
      %3558 = arith.cmpf ult, %3557, %cst_0 : f32
      %3559 = arith.cmpf ugt, %3557, %cst_1 : f32
      %3560 = arith.select %3558, %cst_0, %3557 : f32
      %3561 = arith.select %3559, %cst_1, %3560 : f32
      %3562 = arith.fptosi %3561 : f32 to i8
      linalg.yield %3562 : i8
    } -> tensor<1024x7x7x1xi8>

    // Same computation as %1106, applied to slice 1 of the last dimension.
    %1108 = linalg.generic {indexing_maps = [#map20, #map2], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%extracted_slice_467 : tensor<1024x7x7x1xf32>) outs(%1020 : tensor<1024x7x7x1xi8>) {
    ^bb0(%in: f32, %out: i8):
      %3555 = arith.divf %in, %cst_9 : f32
      %3556 = math.round %3555 : f32
      %3557 = arith.addf %3556, %cst_4 : f32
      %3558 = arith.cmpf ult, %3557, %cst_0 : f32
      %3559 = arith.cmpf ugt, %3557, %cst_1 : f32
      %3560 = arith.select %3558, %cst_0, %3557 : f32
      %3561 = arith.select %3559, %cst_1, %3560 : f32
      %3562 = arith.fptosi %3561 : f32 to i8
      linalg.yield %3562 : i8
    } -> tensor<1024x7x7x1xi8>

 
    // Concatenate along dim 3 with the slice-1 result (%1108) first, so the two
    // channels come out in swapped order relative to the input.
    %concat_468 = tensor.concat dim(3) %1108, %1106 : (tensor<1024x7x7x1xi8>, tensor<1024x7x7x1xi8>) -> tensor<1024x7x7x2xi8>
    return %concat_468: tensor<1024x7x7x2xi8>
} 

run with:

iree-compile --iree-hal-target-backends=llvm-cpu --iree-input-demote-i64-to-i32 path/to/mlir.mlir

Additional context

#17226
#17341

Here is a second (more concise) example and the corresponding iree-compile logs:
MLIR
logs

Also, here are the logs from the original repro: https://gist.github.com/IanWood1/cc3e732c49796b4ce9e0300824b57b3e