iree-org / iree

A retargetable MLIR-based machine learning compiler and runtime toolkit.

Home Page: http://iree.dev/


[LLVMCPU][UKernels] Pack and unpack ukernels not passing appropriate strides

Max191 opened this issue.

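The lowering of pack and unpack ops to ukernels does not pass all of the stride information the ukernels need. In the following IR, the innermost dimension of the pack source is strided (`%subview_1` has type `memref<?x16xbf16, strided<[1024, 64], offset: ?>>`), but the call to `@iree_uk_pack` passes only the outermost stride of each operand (`%strides#0` and `%strides_6#0`), so the inner stride of 64 is lost: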

```mlir
// -----// IR Dump After LowerUKernelOpsToCalls (iree-codegen-lower-ukernel-ops-to-calls) //----- //
module {
  func.func private @iree_uk_pack(memref<bf16>, index, index, memref<bf16>, index, index, index, index, index, index, index, index, i64, i32) -> i32 attributes {hal.import.bitcode = true, hal.import.fields = ["processor_data"], llvm.bareptr = true}
  func.func @main$async_dispatch_3_pack_bf16() attributes {translation_info = #iree_codegen.translation_info<CPUDataTiling>} {
    %c5_i32 = arith.constant 5 : i32
    %c2 = arith.constant 2 : index
    %c16 = arith.constant 16 : index
    %c0_i64 = arith.constant 0 : i64
    %c64 = arith.constant 64 : index
    %c8 = arith.constant 8 : index
    %c0 = arith.constant 0 : index
    %c247808 = arith.constant 247808 : index
    %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : memref<121x16x64xbf16>
    memref.assume_alignment %0, 64 : memref<121x16x64xbf16>
    %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c247808) : memref<64x8x8x16x2xbf16, strided<[2048, 256, 32, 2, 1], offset: 123904>>
    memref.assume_alignment %1, 64 : memref<64x8x8x16x2xbf16, strided<[2048, 256, 32, 2, 1], offset: 123904>>
    %workgroup_id_x = hal.interface.workgroup.id[0] : index
    %workgroup_count_x = hal.interface.workgroup.count[0] : index
    %workgroup_id_y = hal.interface.workgroup.id[1] : index
    %workgroup_count_y = hal.interface.workgroup.count[1] : index
    %2 = affine.apply affine_map<()[s0] -> (s0 * 2)>()[%workgroup_id_x]
    %3 = affine.apply affine_map<()[s0] -> (s0 * 2)>()[%workgroup_count_x]
    scf.for %arg0 = %workgroup_id_y to %c64 step %workgroup_count_y {
      scf.for %arg1 = %2 to %c8 step %3 {
        %subview = memref.subview %1[%arg0, %arg1, 0, 0, 0] [1, 2, 8, 16, 2] [1, 1, 1, 1, 1] : memref<64x8x8x16x2xbf16, strided<[2048, 256, 32, 2, 1], offset: 123904>> to memref<1x2x8x16x2xbf16, strided<[2048, 256, 32, 2, 1], offset: ?>>
        %4 = affine.min affine_map<(d0) -> (d0 * -16 + 121, 32)>(%arg1)
        %5 = affine.apply affine_map<(d0) -> (d0 * 16)>(%arg1)
        %subview_0 = memref.subview %0[%5, 0, %arg0] [%4, 16, 1] [1, 1, 1] : memref<121x16x64xbf16> to memref<?x16x1xbf16, strided<[1024, 64, 1], offset: ?>>
        %subview_1 = memref.subview %subview_0[0, 0, 0] [%4, 16, 1] [1, 1, 1] : memref<?x16x1xbf16, strided<[1024, 64, 1], offset: ?>> to memref<?x16xbf16, strided<[1024, 64], offset: ?>>
        %subview_2 = memref.subview %subview[0, 0, 0, 0, 0] [1, 2, 8, 16, 2] [1, 1, 1, 1, 1] : memref<1x2x8x16x2xbf16, strided<[2048, 256, 32, 2, 1], offset: ?>> to memref<2x8x16x2xbf16, strided<[256, 32, 2, 1], offset: ?>>
        %base_buffer, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %subview_1 : memref<?x16xbf16, strided<[1024, 64], offset: ?>> -> memref<bf16>, index, index, index, index, index
        %base_buffer_3, %offset_4, %sizes_5:4, %strides_6:4 = memref.extract_strided_metadata %subview_2 : memref<2x8x16x2xbf16, strided<[256, 32, 2, 1], offset: ?>> -> memref<bf16>, index, index, index, index, index, index, index, index, index
        %6 = func.call @iree_uk_pack(%base_buffer, %offset, %strides#0, %base_buffer_3, %offset_4, %strides_6#0, %4, %c16, %c2, %c8, %c16, %c2, %c0_i64, %c5_i32) : (memref<bf16>, index, index, memref<bf16>, index, index, index, index, index, index, index, index, i64, i32) -> i32
      }
    }
    return
  }
}
```
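To make the failure concrete, here is a minimal sketch (not IREE's ukernel code) of a reference 2D pack with inner tiles and no padding. `packReference` and `packOuterStrideOnly` are hypothetical helpers, and `float` stands in for `bf16`. The second function models what the ukernel is effectively asked to do in the IR above: it receives only the outer source stride and assumes an inner stride of 1. Feeding both a source that, like `%subview_1`, is a single column slice of a larger buffer (inner stride 2 instead of 64, to keep it small) shows the outputs diverge as soon as the column index is nonzero.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical reference pack (no padding, no outer_dims_perm): packs a
// rows x cols source into a (rows/tile0) x (cols/tile1) x tile0 x tile1
// destination. Source element (r, c) is read at
// offset + r * in_stride0 + c * in_stride1.
std::vector<float> packReference(const std::vector<float>& src, int64_t offset,
                                 int64_t in_stride0, int64_t in_stride1,
                                 int64_t rows, int64_t cols, int64_t tile0,
                                 int64_t tile1) {
  assert(rows % tile0 == 0 && cols % tile1 == 0);  // padding is not modeled
  const int64_t outer0 = rows / tile0;
  const int64_t outer1 = cols / tile1;
  std::vector<float> dst(outer0 * outer1 * tile0 * tile1);
  for (int64_t o0 = 0; o0 < outer0; ++o0)
    for (int64_t o1 = 0; o1 < outer1; ++o1)
      for (int64_t t0 = 0; t0 < tile0; ++t0)
        for (int64_t t1 = 0; t1 < tile1; ++t1) {
          const int64_t r = o0 * tile0 + t0;
          const int64_t c = o1 * tile1 + t1;
          // The destination is dense row-major, like %subview_2 in the IR.
          dst[((o0 * outer1 + o1) * tile0 + t0) * tile1 + t1] =
              src[offset + r * in_stride0 + c * in_stride1];
        }
  return dst;
}

// Models the call emitted with strided_outer_dims = 1: only the outer source
// stride reaches the kernel, so the inner stride is implicitly 1.
std::vector<float> packOuterStrideOnly(const std::vector<float>& src,
                                       int64_t offset, int64_t in_stride0,
                                       int64_t rows, int64_t cols,
                                       int64_t tile0, int64_t tile1) {
  return packReference(src, offset, in_stride0, /*in_stride1=*/1, rows, cols,
                       tile0, tile1);
}
```

For example, with a 32-element buffer whose element `i` holds the value `i`, a 4x4 source with strides `[8, 2]`, and 2x2 tiles, `packReference` writes `2` to packed position 1 while `packOuterStrideOnly` reads `1` from the wrong address.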

This happens because `strided_outer_dims` is set to 1 for all pack and unpack ukernels, so only the stride of the outermost dimension of each memref operand is passed:

```cpp
/*strided_outer_dims=*/rewriter.getIndexAttr(1));
```
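The effect of `strided_outer_dims` can be modeled as an address computation. The sketch below assumes, based on the call in the IR above rather than on the ukernel sources, that with `strided_outer_dims = k` the callee receives the strides of the first `k` dimensions and treats the remaining dimensions as dense row-major. `calleeLinearIndex` and `actualLinearIndex` are hypothetical helpers. With the types from the IR, the two agree for the destination (`strided<[256, 32, 2, 1]>`) but not for the source (`strided<[1024, 64]>`).

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical model of the callee's view: strides are known only for the
// first passedStrides.size() dimensions; the trailing dimensions are assumed
// to be dense row-major (innermost stride 1).
int64_t calleeLinearIndex(int64_t offset,
                          const std::vector<int64_t>& passedStrides,
                          const std::vector<int64_t>& sizes,
                          const std::vector<int64_t>& idx) {
  const std::size_t k = passedStrides.size();
  assert(k >= 1 && k <= sizes.size() && idx.size() == sizes.size());
  int64_t linear = offset;
  for (std::size_t d = 0; d < k; ++d) linear += idx[d] * passedStrides[d];
  // Dense row-major offset within the trailing (rank - k) dimensions.
  int64_t inner = 0;
  for (std::size_t d = k; d < sizes.size(); ++d) inner = inner * sizes[d] + idx[d];
  return linear + inner;
}

// Where the element actually lives, using every stride of the memref.
int64_t actualLinearIndex(int64_t offset, const std::vector<int64_t>& strides,
                          const std::vector<int64_t>& idx) {
  assert(strides.size() == idx.size());
  int64_t linear = offset;
  for (std::size_t d = 0; d < strides.size(); ++d) linear += idx[d] * strides[d];
  return linear;
}
```

With offset 0 and the dynamic leading size taken as 32, source element `(1, 3)` lives at `1*1024 + 3*64 = 1216`, while the callee computes `1*1024 + 3 = 1027`. For destination element `(1, 5, 7, 1)`, both give `431`.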

This is probably intentional, since the ukernels do not appear to support strided inner dimensions. Ideally, the pack ukernels would support strided cases and this lowering would be fixed to pass the additional strides. At the very least, though, the lowering should fail in these cases, because silently dropping the inner stride produces incorrect results.
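A check along the following lines could let the lowering bail out instead of miscompiling. This is a minimal sketch using a hypothetical helper, `trailingDimsAreDense`, that works on static sizes and strides. It returns true only when the dimensions whose strides are not passed are dense row-major, i.e. the innermost stride is 1 and each stride equals the product of the sizes inside it. Dynamic trailing sizes or strides would need to be treated conservatively, which is not modeled here.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical guard: true if passing only the first `stridedOuterDims`
// strides loses no information about the memref's layout.
bool trailingDimsAreDense(const std::vector<int64_t>& sizes,
                          const std::vector<int64_t>& strides,
                          std::size_t stridedOuterDims) {
  assert(sizes.size() == strides.size());
  assert(stridedOuterDims <= sizes.size());
  int64_t expected = 1;  // stride a dense row-major layout would have
  for (std::size_t d = sizes.size(); d-- > stridedOuterDims;) {
    // A unit-size dimension is never stepped over, so its stride is irrelevant.
    if (sizes[d] != 1 && strides[d] != expected) return false;
    expected *= sizes[d];
  }
  return true;
}
```

For the operands in the IR above, `trailingDimsAreDense({32, 16}, {1024, 64}, 1)` is false for the source, while `trailingDimsAreDense({2, 8, 16, 2}, {256, 32, 2, 1}, 1)` is true for the destination, so only the source would block the ukernel lowering.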