iree-org / iree

A retargetable MLIR-based machine learning compiler and runtime toolkit.

Home page: http://iree.dev/

Number of dims and results of reindexed AffineMap doesn't match on Vectorization

jinchen62 opened this issue

What happened?

dispatch: https://gist.github.com/jinchen62/5e2af98f9b5bfc3b55e949f964459815
error log: https://gist.github.com/jinchen62/df2038b5a43ed4680804a3d7d0647d95

The failing op, dumped at https://github.com/iree-org/iree/blob/main/compiler/src/iree/compiler/Codegen/Common/GenericVectorization.cpp#L336, is:

%10 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (0, d0)>], iterator_types = ["parallel", "reduction"]} ins(%extracted_slice : tensor<1x4xf32>) outs(%arg2 : tensor<1x1xf32>) attrs = {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[1, 0], [1, 0], [0, 4], [0, 0]]>} {
^bb0(%in: f32, %out: f32):
%11 = arith.addf %in, %out : f32
linalg.yield %11 : f32
} -> tensor<1x1xf32>

At the failing assertion (https://github.com/llvm/llvm-project/blob/main/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp#L474), the map is reindexed from (d0, d1) -> (0, d0) to (d0) -> (0, d0), so the number of dims (1) no longer matches the number of results (2).
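
To make the failure mode concrete, here is a toy Python model of the reindexing step, assuming it behaves like MLIR's compressUnusedDims (drop dims that no result uses, then renumber the rest). `ToyAffineMap`, `Dim`, `compress_unused_dims`, and `reindex_indexing_map` are hypothetical names, not MLIR's API. The point it shows: compressing (d0, d1) -> (0, d0) leaves one dim but two results, because the constant 0 still occupies a result slot, whereas the folded map (d0, d1) -> (d0) compresses cleanly.

```python
from dataclasses import dataclass
from typing import Tuple, Union


@dataclass(frozen=True)
class Dim:
    """A dimension expression d<index> in a toy affine map."""
    index: int


# A map result is either a constant (e.g. the 0 in (0, d0)) or a Dim.
Result = Union[int, Dim]


@dataclass(frozen=True)
class ToyAffineMap:
    num_dims: int
    results: Tuple[Result, ...]

    def __str__(self) -> str:
        dims = ", ".join(f"d{i}" for i in range(self.num_dims))
        res = ", ".join(
            f"d{r.index}" if isinstance(r, Dim) else str(r) for r in self.results
        )
        return f"({dims}) -> ({res})"


def compress_unused_dims(m: ToyAffineMap) -> ToyAffineMap:
    """Drop dims that no result refers to and renumber the remaining ones."""
    used = sorted({r.index for r in m.results if isinstance(r, Dim)})
    renumber = {old: new for new, old in enumerate(used)}
    results = tuple(
        Dim(renumber[r.index]) if isinstance(r, Dim) else r for r in m.results
    )
    return ToyAffineMap(len(used), results)


def reindex_indexing_map(m: ToyAffineMap) -> ToyAffineMap:
    """Compress the map, then enforce the dims == results invariant."""
    res = compress_unused_dims(m)
    if res.num_dims != len(res.results):
        raise AssertionError(
            f"{res.num_dims} dims but {len(res.results)} results in {res}"
        )
    return res


if __name__ == "__main__":
    bad = ToyAffineMap(2, (0, Dim(0)))   # (d0, d1) -> (0, d0)
    print(compress_unused_dims(bad))     # (d0) -> (0, d0)
    good = ToyAffineMap(2, (Dim(0),))    # (d0, d1) -> (d0)
    print(reindex_indexing_map(good))    # (d0) -> (d0)
    try:
        reindex_indexing_map(bad)
    except AssertionError as err:
        print("invariant violated:", err)
```

In this model the constant result is the only thing that breaks the invariant, which is why the discussion below focuses on where the 0 comes from rather than on the vectorizer itself.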

Steps to reproduce your issue

Run iree-compile --iree-input-demote-i64-to-i32 --iree-hal-target-backends=llvm-cpu dispatch_1.mlir -o test.vmfb 2> dump.mlir with IREE built from top of main (ToM).

What component(s) does this issue relate to?

No response

Version information

No response

Additional context

No response

Inlining the MLIR input below. At first, I thought that (d0, d1) -> (0, d0) was generated during codegen, but that is not the case: the (d0, d1) -> (0, d0) affine_map is already present in codegen's input. @jinchen62 do you know how the input is generated? It would be very helpful if you could track it back to a small set of linalg ops or tosa/torch ops. The 0 should be folded away by FoldUnitExtentDimsPass at the global-opt or flow level, i.e., the map should be (d0, d1) -> (d0) by the time it reaches codegen.

hal.executable public @main_graph$async_dispatch_1 {
  hal.executable.variant public @embedded_elf_x86_64 target(<"llvm-cpu", "embedded-elf-x86_64", {cpu = "generic", cpu_features = "", data_layout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-i128:128-f80:128-n8:16:32:64-S128", native_vector_size = 16 : i64, target_triple = "x86_64-unknown-unknown-eabi-elf"}>) {
    hal.executable.export public @main_graph$async_dispatch_1_generic_9x1024_f32 ordinal(0) layout(#hal.pipeline.layout<push_constants = 2, sets = [<0, bindings = [<0, storage_buffer, ReadOnly>, <1, storage_buffer>]>]>) attributes {hal.interface.bindings = [#hal.interface.binding<0, 0>, #hal.interface.binding<0, 1>]} {
    ^bb0(%arg0: !hal.device):
      %x, %y, %z = flow.dispatch.workgroup_count_from_slice 
      hal.return %x, %y, %z : index, index, index
    }
    builtin.module {
      func.func @main_graph$async_dispatch_1_generic_9x1024_f32() {
        %cst = arith.constant 0.000000e+00 : f32
        %0 = hal.interface.constant.load[0] : i32
        %1 = hal.interface.constant.load[1] : i32
        %2 = arith.index_castui %0 : i32 to index
        %3 = arith.index_castui %1 : i32 to index
        %4 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%2) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<9x1024xf32>>
        %5 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%3) : !flow.dispatch.tensor<writeonly:tensor<1x9xf32>>
        %6 = flow.dispatch.tensor.load %4, offsets = [0, 0], sizes = [9, 1024], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<9x1024xf32>> -> tensor<9x1024xf32>
        %7 = tensor.empty() : tensor<1x9xf32>
        %8 = linalg.fill ins(%cst : f32) outs(%7 : tensor<1x9xf32>) -> tensor<1x9xf32>
        %9 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (0, d0)>], iterator_types = ["parallel", "reduction"]} ins(%6 : tensor<9x1024xf32>) outs(%8 : tensor<1x9xf32>) {
        ^bb0(%in: f32, %out: f32):
          %10 = arith.addf %in, %out : f32
          linalg.yield %10 : f32
        } -> tensor<1x9xf32>
        flow.dispatch.tensor.store %9, %5, offsets = [0, 0], sizes = [1, 9], strides = [1, 1] : tensor<1x9xf32> -> !flow.dispatch.tensor<writeonly:tensor<1x9xf32>>
        return
      }
    }
  }
}
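
The fold described above can be illustrated with a toy model: an operand dim of static size 1 that the indexing map addresses with the constant 0 carries no information, so the dim and its map result can be dropped together, turning (d0, d1) -> (0, d0) on tensor<1x9xf32> into (d0, d1) -> (d0) on a 9-element tensor. This is a simplified sketch under that assumption, not FoldUnitExtentDimsPass itself; `drop_unit_result_dims` is a hypothetical helper. It also shows that a dynamic (`?`) leading dim cannot be dropped this way, since it is not known to be 1.

```python
from typing import List, Optional, Sequence, Tuple


def drop_unit_result_dims(
    shape: Sequence[Optional[int]], results: Sequence[str]
) -> Tuple[List[Optional[int]], List[str]]:
    """Toy model: drop operand dims of static size 1 that the map indexes
    with the constant "0". None stands for a dynamic "?" size, which is not
    known to be 1 and is therefore kept."""
    new_shape: List[Optional[int]] = []
    new_results: List[str] = []
    for size, expr in zip(shape, results):
        if size == 1 and expr == "0":
            continue  # unit dim addressed by a constant: nothing is lost
        new_shape.append(size)
        new_results.append(expr)
    return new_shape, new_results


if __name__ == "__main__":
    # outs above: tensor<1x9xf32> with (d0, d1) -> (0, d0)
    print(drop_unit_result_dims([1, 9], ["0", "d0"]))     # ([9], ['d0'])
    # hypothetical dynamic leading dim: tensor<?x9xf32>
    print(drop_unit_result_dims([None, 9], ["0", "d0"]))  # ([None, 9], ['0', 'd0'])
```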

I worked with @jinchen62 and we got a smaller repro: https://gist.github.com/hanhanW/b3652f5887b93fb8f0df6c6c39c1ef87

To repro, run iree-opt --pass-pipeline="builtin.module(util.func(iree-flow-fold-unit-extent-dims))" ~/repro.mlir.

Then you'll see affine_map<(d0, d1) -> (0, d0)> in the result.

#map2 = affine_map<(d0, d1) -> (d0, d1)>
#map8 = affine_map<(d0, d1) -> (0, d0)>
// ...
    %29 = linalg.generic {indexing_maps = [#map2, #map8], iterator_types = ["parallel", "reduction"]} ins(%collapsed_12 : tensor<9x1024xf32>) outs(%28 : tensor<?x9xf32>) {
    ^bb0(%in: f32, %out: f32):
      %35 = arith.addf %in, %out : f32
      linalg.yield %35 : f32
    } -> tensor<?x9xf32>
// ...

Actually, the input reduction op looks weird: the sizes of d0 don't match. One is 1 and the other is ?. It looks like there is a bug in the frontend lowering. @jinchen62 you can add -mlir-print-debuginfo to iree-compile, and it will tell you where the op was lowered from. My guess is that there is a bug in the XXX->Linalg lowering.

#map5 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
#map10 = affine_map<(d0, d1, d2) -> (d0, d1, 0)>
    %25 = tensor.empty(%12) : tensor<?x9x1xf32>
    %26 = linalg.fill ins(%cst_7 : f32) outs(%25 : tensor<?x9x1xf32>) -> tensor<?x9x1xf32>
    %27 = linalg.generic {indexing_maps = [#map5, #map10], iterator_types = ["parallel", "parallel", "reduction"]} ins(%24 : tensor<1x9x1024xf32>) outs(%26 : tensor<?x9x1xf32>) {
    ^bb0(%in: f32, %out: f32):
      %31 = arith.addf %in, %out : f32
      linalg.yield %31 : f32
    } -> tensor<?x9x1xf32>
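
The mismatch pointed out above can be checked mechanically: walk each operand's shape alongside its indexing map and record which size every operand implies for each loop dim. The sketch below is a heuristic written for this thread, not the Linalg verifier (the IR above evidently made it through the pipeline); `implied_loop_sizes` and `mismatched_loop_dims` are hypothetical helpers. Applied to the generic op above, it reports that d0 is 1 according to the input but `?` according to the output.

```python
from typing import Dict, List, Optional, Sequence, Tuple

# (shape, map results); None is a dynamic "?" size, "0" is a constant result.
Operand = Tuple[Sequence[Optional[int]], Sequence[str]]


def implied_loop_sizes(operands: Sequence[Operand]) -> Dict[str, List[Optional[int]]]:
    """For each loop dim, collect the size that every operand implies for it."""
    sizes: Dict[str, List[Optional[int]]] = {}
    for shape, results in operands:
        for size, expr in zip(shape, results):
            if expr.startswith("d"):  # skip constant results such as "0"
                sizes.setdefault(expr, []).append(size)
    return sizes


def mismatched_loop_dims(operands: Sequence[Operand]) -> List[str]:
    """Loop dims on which operands disagree (static vs static, or static vs ?)."""
    return sorted(d for d, s in implied_loop_sizes(operands).items() if len(set(s)) > 1)


if __name__ == "__main__":
    reduction = [
        ((1, 9, 1024), ["d0", "d1", "d2"]),  # ins:  tensor<1x9x1024xf32>, #map5
        ((None, 9, 1), ["d0", "d1", "0"]),   # outs: tensor<?x9x1xf32>,   #map10
    ]
    print(implied_loop_sizes(reduction))  # {'d0': [1, None], 'd1': [9, 9], 'd2': [1024]}
    print(mismatched_loop_dims(reduction))  # ['d0']
```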

Smaller repro: https://gist.github.com/jinchen62/91e216fb39abbb9ba4c0461346d2bb5a

Command:
iree-opt --pass-pipeline="builtin.module(util.func(iree-flow-fold-unit-extent-dims))" repro.mlir
or
iree-compile --iree-hal-target-backends=llvm-cpu repro.mlir -o test.vmfb --mlir-print-ir-after-all 2> dump.mlir

@jinchen62 did you get a chance to see which op is generating the IR? The generic op looks invalid to me, as I explained in the comment above.

I think it's

%237 = torch.aten.sum.dim_IntList %235, %236, %true, %none : !torch.vtensor<[?,9,1024],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[?,9,1],f32>

I'd suggest checking for bugs in the torch -> linalg lowering, or in the lowering from other high-level dialects to torch.

Torch-level repro: https://gist.github.com/jinchen62/601cfce290b81e037383fc49b604a68a

iree-compile --iree-input-demote-i64-to-i32 --iree-hal-target-backends=llvm-cpu --iree-util-zero-fill-elided-attrs repro_torch.mlir -o test.vmfb

Part of the dump for the torch repro:
After ExpandOps (memref-expand) -> After Canonicalizer (canonicalize)
https://gist.github.com/jinchen62/ae856e42b0660d0b41426e910039fb9a

@hanhanW I think that with a tensor.cast op, the reduction op you found weird should compile fine, as at line 381 of the dump. But after the Canonicalizer pass, the cast appears to be missing, as at line 817. The following is a compilable repro; without the cast op at the end, it fails with the same error we are facing. Does that make sense?

#map = affine_map<(d0) -> (d0)>
#map1 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
#map2 = affine_map<(d0, d1, d2) -> (d0, d1, 0)>
module {
  func.func @repro2(%arg0: tensor<1x9x1024xf32>) -> tensor<1x9x1xf32> {
    %cst = arith.constant dense<[false, true]> : tensor<2xi1>
    %cst_0 = arith.constant dense<1> : tensor<2xi32>
    %cst_1 = arith.constant dense<[1, -1]> : tensor<2xi32>
    %cst_2 = arith.constant 0.000000e+00 : f32
    %0 = tensor.empty() : tensor<2xi32>
    %1 = linalg.generic {indexing_maps = [#map, #map, #map, #map], iterator_types = ["parallel"]} ins(%cst, %cst_0, %cst_1 : tensor<2xi1>, tensor<2xi32>, tensor<2xi32>) outs(%0 : tensor<2xi32>) {
    ^bb0(%in: i1, %in_3: i32, %in_4: i32, %out: i32):
      %6 = arith.select %in, %in_3, %in_4 : i32
      linalg.yield %6 : i32
    } -> tensor<2xi32>
    %extracted_slice = tensor.extract_slice %1[0] [1] [1] : tensor<2xi32> to tensor<1xi32>
    %collapsed = tensor.collapse_shape %extracted_slice [] : tensor<1xi32> into tensor<i32>
    %extracted = tensor.extract %collapsed[] : tensor<i32>
    %2 = arith.index_cast %extracted : i32 to index
    %3 = tensor.empty(%2) : tensor<?x9x1xf32>
    %4 = linalg.fill ins(%cst_2 : f32) outs(%3 : tensor<?x9x1xf32>) -> tensor<?x9x1xf32>
    %5 = linalg.generic {indexing_maps = [#map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg0 : tensor<1x9x1024xf32>) outs(%4 : tensor<?x9x1xf32>) {
    ^bb0(%in: f32, %out: f32):
      %6 = arith.addf %in, %out : f32
      linalg.yield %6 : f32
    } -> tensor<?x9x1xf32>
    %cast = tensor.cast %5 : tensor<?x9x1xf32> to tensor<1x9x1xf32>
    return %cast : tensor<1x9x1xf32>
  }
}

I'm not convinced that tensor.cast is the issue. Some shape-inference passes/patterns in MLIR dialects create tensor.cast ops to spell out static shapes. Given that hint, the compiler is smart enough to fold the shape information into the linalg op, which seems reasonable to me. The patterns and passes operate at the Linalg level, so what I suspect is that the frontend is generating invalid ops.

I'm not sure why we're still triaging the issue at the model level; perhaps I did not make it clear. Let me put it this way: instead of compiling the whole model, are you able to compile a single %237 = torch.aten.sum.dim_IntList %235, %236, %true, %none : !torch.vtensor<[?,9,1024],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[?,9,1],f32> op?

I don't think it's a lowering issue. The torch.aten.sum.dim_IntList op compiles on its own, and I traced back up through the onnx->torch lowering and didn't find a lowering bug in any ONNX op.

@raikonenfnu and I think there might be an optimization bug in the canonicalize pass that runs after memref-expand. When the tensor.cast op is folded, we saw the generic op with the reduction dim change from ins(%146 : tensor<?x9x1024xf32>) outs(%149 : tensor<?x?x1xf32>) to ins(%55 : tensor<1x9x1024xf32>) outs(%57 : tensor<?x9x1xf32>). We would expect it to change to either ins(%55 : tensor<?x9x1024xf32>) outs(%57 : tensor<?x9x1xf32>) or ins(%55 : tensor<1x9x1024xf32>) outs(%57 : tensor<1x9x1xf32>). The dump IR is here.
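
As a toy illustration of the second expected outcome (not a model of how MLIR's canonicalization patterns are actually implemented), the sketch below shares every statically known loop extent with all operands of the generic op through their indexing maps. Once the input's d0 is known to be 1, the output is refined from ?x9x1 to 1x9x1 as well, so no loop dim is static in one operand and dynamic in another. `refine_through_loop_dims` is a hypothetical helper; sizes use None for `?`.

```python
from typing import Dict, List, Optional, Sequence, Tuple

Operand = Tuple[Tuple[Optional[int], ...], List[str]]


def refine_through_loop_dims(operands: Sequence[Operand]) -> List[Operand]:
    """Toy model: propagate every statically known loop extent to all operands.

    None is a dynamic "?" size; results that are not loop dims (e.g. the
    constant "0") keep their original size.
    """
    known: Dict[str, int] = {}
    for shape, results in operands:
        for size, expr in zip(shape, results):
            if expr.startswith("d") and size is not None:
                prev = known.setdefault(expr, size)
                if prev != size:
                    raise ValueError(f"{expr}: conflicting static sizes {prev} and {size}")
    refined: List[Operand] = []
    for shape, results in operands:
        new_shape = tuple(
            known.get(expr, size) if expr.startswith("d") else size
            for size, expr in zip(shape, results)
        )
        refined.append((new_shape, list(results)))
    return refined


if __name__ == "__main__":
    after_cast_fold = [
        ((1, 9, 1024), ["d0", "d1", "d2"]),  # ins:  tensor<1x9x1024xf32>
        ((None, 9, 1), ["d0", "d1", "0"]),   # outs: tensor<?x9x1xf32>
    ]
    for shape, _ in refine_through_loop_dims(after_cast_fold):
        print(shape)  # (1, 9, 1024) then (1, 9, 1)
```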

@jinchen62 So what's the plan to fix this issue? The bart-large model needs this anyway.