[LLVMCPU][UKernels] Pack and unpack ukernels not passing appropriate strides
Max191 opened this issue
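The lowering of pack and unpack ops to ukernels drops stride information that the ukernels need. In the following IR, for example, the innermost dimension of the pack source (%subview_1, with layout strided<[1024, 64]>) has stride 64. The call to @iree_uk_pack passes only that operand's outermost stride (%strides#0), so the ukernel never learns that the inner dimension is strided: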
// -----// IR Dump After LowerUKernelOpsToCalls (iree-codegen-lower-ukernel-ops-to-calls) //----- //
module {
  func.func private @iree_uk_pack(memref<bf16>, index, index, memref<bf16>, index, index, index, index, index, index, index, index, i64, i32) -> i32 attributes {hal.import.bitcode = true, hal.import.fields = ["processor_data"], llvm.bareptr = true}
  func.func @main$async_dispatch_3_pack_bf16() attributes {translation_info = #iree_codegen.translation_info<CPUDataTiling>} {
    %c5_i32 = arith.constant 5 : i32
    %c2 = arith.constant 2 : index
    %c16 = arith.constant 16 : index
    %c0_i64 = arith.constant 0 : i64
    %c64 = arith.constant 64 : index
    %c8 = arith.constant 8 : index
    %c0 = arith.constant 0 : index
    %c247808 = arith.constant 247808 : index
    %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : memref<121x16x64xbf16>
    memref.assume_alignment %0, 64 : memref<121x16x64xbf16>
    %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c247808) : memref<64x8x8x16x2xbf16, strided<[2048, 256, 32, 2, 1], offset: 123904>>
    memref.assume_alignment %1, 64 : memref<64x8x8x16x2xbf16, strided<[2048, 256, 32, 2, 1], offset: 123904>>
    %workgroup_id_x = hal.interface.workgroup.id[0] : index
    %workgroup_count_x = hal.interface.workgroup.count[0] : index
    %workgroup_id_y = hal.interface.workgroup.id[1] : index
    %workgroup_count_y = hal.interface.workgroup.count[1] : index
    %2 = affine.apply affine_map<()[s0] -> (s0 * 2)>()[%workgroup_id_x]
    %3 = affine.apply affine_map<()[s0] -> (s0 * 2)>()[%workgroup_count_x]
    scf.for %arg0 = %workgroup_id_y to %c64 step %workgroup_count_y {
      scf.for %arg1 = %2 to %c8 step %3 {
        %subview = memref.subview %1[%arg0, %arg1, 0, 0, 0] [1, 2, 8, 16, 2] [1, 1, 1, 1, 1] : memref<64x8x8x16x2xbf16, strided<[2048, 256, 32, 2, 1], offset: 123904>> to memref<1x2x8x16x2xbf16, strided<[2048, 256, 32, 2, 1], offset: ?>>
        %4 = affine.min affine_map<(d0) -> (d0 * -16 + 121, 32)>(%arg1)
        %5 = affine.apply affine_map<(d0) -> (d0 * 16)>(%arg1)
        %subview_0 = memref.subview %0[%5, 0, %arg0] [%4, 16, 1] [1, 1, 1] : memref<121x16x64xbf16> to memref<?x16x1xbf16, strided<[1024, 64, 1], offset: ?>>
        %subview_1 = memref.subview %subview_0[0, 0, 0] [%4, 16, 1] [1, 1, 1] : memref<?x16x1xbf16, strided<[1024, 64, 1], offset: ?>> to memref<?x16xbf16, strided<[1024, 64], offset: ?>>
        %subview_2 = memref.subview %subview[0, 0, 0, 0, 0] [1, 2, 8, 16, 2] [1, 1, 1, 1, 1] : memref<1x2x8x16x2xbf16, strided<[2048, 256, 32, 2, 1], offset: ?>> to memref<2x8x16x2xbf16, strided<[256, 32, 2, 1], offset: ?>>
        %base_buffer, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %subview_1 : memref<?x16xbf16, strided<[1024, 64], offset: ?>> -> memref<bf16>, index, index, index, index, index
        %base_buffer_3, %offset_4, %sizes_5:4, %strides_6:4 = memref.extract_strided_metadata %subview_2 : memref<2x8x16x2xbf16, strided<[256, 32, 2, 1], offset: ?>> -> memref<bf16>, index, index, index, index, index, index, index, index, index
        %6 = func.call @iree_uk_pack(%base_buffer, %offset, %strides#0, %base_buffer_3, %offset_4, %strides_6#0, %4, %c16, %c2, %c8, %c16, %c2, %c0_i64, %c5_i32) : (memref<bf16>, index, index, memref<bf16>, index, index, index, index, index, index, index, index, i64, i32) -> i32
      }
    }
    return
  }
}
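To make the mismatch concrete, here is a minimal C sketch of element addressing for the source view %subview_1 (memref<?x16xbf16, strided<[1024, 64], offset: ?>>). Both helper names are hypothetical and not part of IREE. The second function models what the callee effectively computes when, as in the call above, only the outermost stride is passed and the inner stride is implicitly 1.

```c
#include <assert.h>
#include <stdint.h>

/* Hypothetical helpers (not part of IREE): element offsets into a 2-D
 * strided view such as memref<?x16xbf16, strided<[1024, 64], offset: ?>>. */

/* Correct addressing: each dimension contributes its own stride. */
static int64_t offset_all_strides(int64_t offset, int64_t stride0,
                                  int64_t stride1, int64_t i, int64_t j) {
  assert(i >= 0 && j >= 0);
  return offset + i * stride0 + j * stride1;
}

/* Addressing implied when only the outer stride is forwarded:
 * the inner dimension is silently assumed to have stride 1. */
static int64_t offset_outer_stride_only(int64_t offset, int64_t stride0,
                                        int64_t i, int64_t j) {
  assert(i >= 0 && j >= 0);
  return offset + i * stride0 + j;
}
```

With offset 0 and strides [1024, 64], element (1, 3) lives at 1024 + 3 * 64 = 1216, but the outer-stride-only computation yields 1027. Only the first element of each row (j = 0) is read from the right place.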
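This happens because strided_outer_dims is set to 1 for all pack and unpack ukernels. As a result, only the outermost stride of each buffer is forwarded to the ukernel (%strides#0 and %strides_6#0 in the IR above), and the remaining dimensions are assumed to be contiguous. That assumption holds for the destination here (inner sizes [8, 16, 2] with strides [32, 2, 1]), but not for the source, whose inner dimension has stride 64.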
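Forwarding only the outer stride (strided_outer_dims = 1) is safe only when every remaining dimension is densely packed in row-major order. The C sketch below shows one way to express that condition. inner_dims_contiguous is a hypothetical helper, not an IREE function. A real check would live in the C++ lowering and would also have to treat dynamic inner strides as unknown, and therefore unsafe, which this sketch ignores.

```c
#include <assert.h>
#include <stdbool.h>
#include <stdint.h>

/* Hypothetical check (not an IREE API): returns true if dimensions
 * [num_strided_outer_dims, rank) of a strided buffer are laid out
 * contiguously in row-major order, i.e. forwarding only the outer strides
 * loses no addressing information. */
static bool inner_dims_contiguous(const int64_t *sizes, const int64_t *strides,
                                  int rank, int num_strided_outer_dims) {
  assert(num_strided_outer_dims >= 0 && num_strided_outer_dims <= rank);
  int64_t expected_stride = 1;
  for (int d = rank - 1; d >= num_strided_outer_dims; --d) {
    /* A unit-size dimension can carry any stride without affecting addresses. */
    if (sizes[d] != 1 && strides[d] != expected_stride) return false;
    expected_stride *= sizes[d];
  }
  return true;
}
```

Applied to the IR above, the destination (sizes [2, 8, 16, 2], strides [256, 32, 2, 1]) passes the check, while the source (inner size 16, inner stride 64) fails it. In that case, the lowering could report a failure instead of emitting the call.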
This is probably intentional, since it seems the ukernels do not support cases where inner dimensions are strided. Ideally, the pack ukernels would support strided cases and this lowering would then be fixed to pass the full strides. At the very least, though, the lowering should fail for such inputs, because it currently produces incorrect results.
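Supporting strided cases amounts to honoring every source stride when gathering elements into tiles. The C sketch below is a simplified reference pack written under stated assumptions. reference_pack_2d is hypothetical and is not IREE's iree_uk_pack. It uses float in place of bf16, identity outer-dimension order, and no inner-tile transposition, so it shows the addressing idea rather than the real ukernel's interface or data layout.

```c
#include <assert.h>
#include <stdint.h>

/* Hypothetical reference (not IREE's iree_uk_pack): packs a 2-D source of
 * size rows x cols into tiles of tile0 x tile1, padding out-of-bounds
 * elements with pad. Both source strides are honored, so a non-unit inner
 * stride is read correctly. The destination is a dense
 * [ceil(rows/tile0)][ceil(cols/tile1)][tile0][tile1] buffer.
 * float stands in for bf16 to keep the sketch portable. */
static void reference_pack_2d(const float *src, int64_t rows, int64_t cols,
                              int64_t src_stride0, int64_t src_stride1,
                              int64_t tile0, int64_t tile1, float pad,
                              float *dst) {
  assert(tile0 > 0 && tile1 > 0);
  const int64_t outer0 = (rows + tile0 - 1) / tile0;
  const int64_t outer1 = (cols + tile1 - 1) / tile1;
  for (int64_t o0 = 0; o0 < outer0; ++o0)
    for (int64_t o1 = 0; o1 < outer1; ++o1)
      for (int64_t t0 = 0; t0 < tile0; ++t0)
        for (int64_t t1 = 0; t1 < tile1; ++t1) {
          const int64_t i = o0 * tile0 + t0; /* source row */
          const int64_t j = o1 * tile1 + t1; /* source column */
          const int64_t d = ((o0 * outer1 + o1) * tile0 + t0) * tile1 + t1;
          dst[d] = (i < rows && j < cols)
                       ? src[i * src_stride0 + j * src_stride1]
                       : pad;
        }
}
```

For example, a 3x2 view with strides (1, 4) over a 4x4 row-major buffer holding 10 * row + col is a transposed view whose element (i, j) equals 10 * j + i. Packed into 2x2 tiles with pad -1, it yields {0, 10, 1, 11, 2, 12, -1, -1}. Dropping src_stride1 would instead read adjacent elements from the wrong row.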