openxla / iree

A retargetable MLIR-based machine learning compiler and runtime toolkit.

Home Page: http://iree.dev/

[Flow] Fusion of `tensor.unpack` ops with reduction `linalg.generic` ops discussion

Max191 opened this issue

I have the following IR that I want to end up in a single dispatch:

#map = affine_map<(d0, d1) -> (d1, d0)>
#map1 = affine_map<(d0, d1) -> (d1)>
#map2 = affine_map<(d0, d1) -> (d0, d1)>
#map3 = affine_map<(d0, d1) -> (d0)>
module {
  func.func @unpack_reduction(%arg0: tensor<32x1x128x1x32xi32>, %arg1: tensor<32xf32>, %arg2: tensor<32xf32>, %arg3: tensor<4096x32xf32>, %arg4: tensor<4096x32xf32>) -> tensor<4096xf32> {
    %0 = tensor.empty() : tensor<32x1x4096xi32>
    %1 = tensor.empty() : tensor<4096xf32>
    %unpack = tensor.unpack %arg0 inner_dims_pos = [1, 2] inner_tiles = [1, 32] into %0 : tensor<32x1x128x1x32xi32> -> tensor<32x1x4096xi32>
    %collapsed = tensor.collapse_shape %unpack [[0, 1], [2]] : tensor<32x1x4096xi32> into tensor<32x4096xi32>
    %2 = linalg.generic {indexing_maps = [#map, #map1, #map1, #map2, #map2, #map3], iterator_types = ["parallel", "reduction"]} ins(%collapsed, %arg1, %arg2, %arg3, %arg4 : tensor<32x4096xi32>, tensor<32xf32>, tensor<32xf32>, tensor<4096x32xf32>, tensor<4096x32xf32>) outs(%1 : tensor<4096xf32>) {
    ^bb0(%in: i32, %in_0: f32, %in_1: f32, %in_2: f32, %in_3: f32, %out: f32):
      %3 = arith.sitofp %in : i32 to f32
      %4 = arith.mulf %3, %in_0 : f32
      %5 = arith.mulf %4, %in_2 : f32
      %6 = arith.mulf %in_3, %in_2 : f32
      %7 = arith.mulf %6, %in_1 : f32
      %8 = arith.subf %5, %7 : f32
      %9 = arith.addf %8, %out : f32
      linalg.yield %9 : f32
    } -> tensor<4096xf32>
    return %2 : tensor<4096xf32>
  }
}

Two things need to change for this to happen:

  1. The tensor.collapse_shape must be propagated through the linalg.generic op. This requires some changes to the FoldWithProducerReshapeOpByExpansion pattern upstream. Right now it doesn't support cases with reduction dimensions, but in cases where the collapsed/expanded dimension is a unit dimension, it is not difficult to support these cases. Unit dimension collapse_shape ops like this will actually be quite common for matvec workloads, because ExpandVectors will create these ops to turn matvec/vecmat into matmul.
  2. The tensor.unpack must be allowed to fuse with a linalg.generic op that has reduction dimensions. This would require some analysis of which dimensions of the unpacked shape are reduced in the linalg.generic consumer. When the reduced dimension exists in both the packed and unpacked shape (i.e., when the tensor.unpack op does not touch the reduced dimension), then this fusion should be possible. This type of reduction will probably be less common, since the only dimension that is typically left untouched by unpack is the batch dimension, but it will arise from quantized matmul workloads that get reassociated on CPU.
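To make item 1 more concrete, here is a minimal Python sketch (not the upstream C++ pattern) of how the generic op's loops could be expanded when one of its operands is produced by a collapse_shape, using the relaxed rule that a reduction loop may be expanded only if at most one of the resulting loops is non-unit. The function name `expand_loops_through_collapse` and the list-based encoding of the indexing map (result i reads loop `operand_map[i]`, assuming a projected permutation and static shapes) are hypothetical. Applied to the IR above, it gives the same loop structure that appears after FusionOfTensorOps in the dumps below: iterators parallel, reduction, reduction and the map (d1, d2, d0) for the unpacked operand.

```python
from math import prod


def expand_loops_through_collapse(operand_map, reassociation, expanded_shape,
                                  loop_sizes, iterator_types):
    """Hypothetical sketch: expand a generic's loops through a producer collapse_shape.

    operand_map:    result i of the collapsed operand's indexing map is loop operand_map[i],
                    e.g. (d0, d1) -> (d1, d0) is [1, 0].
    reassociation:  collapse_shape groups over the expanded (source) shape.
    expanded_shape: static shape of the collapse_shape source.
    """
    # Start with every loop unexpanded: loop l -> [loop_sizes[l]].
    expansion = [[size] for size in loop_sizes]
    for result_pos, loop in enumerate(operand_map):
        sizes = [expanded_shape[j] for j in reassociation[result_pos]]
        if prod(sizes) != loop_sizes[loop]:
            raise ValueError(f"group {result_pos} does not match loop d{loop}")
        non_unit = [s for s in sizes if s != 1]
        # Relaxed rule: a reduction loop may only gain unit-sized companions.
        if iterator_types[loop] == "reduction" and len(non_unit) > 1:
            raise ValueError(f"reduction loop d{loop} split into several non-unit loops")
        expansion[loop] = sizes
    # Number the new loops in order: old loop l maps to consecutive new indices.
    new_index, next_idx = [], 0
    for sizes in expansion:
        new_index.append(list(range(next_idx, next_idx + len(sizes))))
        next_idx += len(sizes)
    new_sizes = [s for sizes in expansion for s in sizes]
    new_iters = [iterator_types[l] for l, sizes in enumerate(expansion) for _ in sizes]
    # The collapsed operand's new map reads the expanded loops in group order.
    new_operand_map = [d for loop in operand_map for d in new_index[loop]]
    return new_sizes, new_iters, new_operand_map


# The example from this issue: map (d0, d1) -> (d1, d0) over the 32x1x4096 unpack result.
print(expand_loops_through_collapse([1, 0], [[0, 1], [2]], [32, 1, 4096],
                                    [4096, 32], ["parallel", "reduction"]))
```

Only the loop structure is modeled here; the other operands (e.g. the 32xf32 inputs) would get matching expand_shape ops derived from their own indexing maps, as in the FusionOfTensorOps dump.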

I made some makeshift changes to FoldWithProducerReshapeOpByExpansion, plus a dummy change in dispatch formation to enable the fusion, so I could look at the lowering. I got the following print-ir-after-all dump for the above function:
https://drive.google.com/file/d/1kijbflzG09UF8utV2y7t0fjtD0N7jL8X/view?usp=sharing

Aside from some sub-optimal vector tile selection, the dump looks okay to me. I think it would make sense to try to allow fusions like this to happen.

CC @MaheshRavishankar @hanhanW

Putting the critical snippet here so people don't need to look at the whole dump:

// -----// IR Dump After CSE (cse) //----- //
func.func @unpack_reduction(%arg0: !hal.buffer_view, %arg1: !hal.buffer_view, %arg2: !hal.buffer_view, %arg3: !hal.buffer_view, %arg4: !hal.buffer_view) -> !hal.buffer_view attributes {iree.abi.stub} {
  %0 = hal.tensor.import %arg0 "input 0" : !hal.buffer_view -> tensor<32x1x128x1x32xi32>
  %1 = hal.tensor.import %arg1 "input 1" : !hal.buffer_view -> tensor<32xf32>
  %2 = hal.tensor.import %arg2 "input 2" : !hal.buffer_view -> tensor<32xf32>
  %3 = hal.tensor.import %arg3 "input 3" : !hal.buffer_view -> tensor<4096x32xf32>
  %4 = hal.tensor.import %arg4 "input 4" : !hal.buffer_view -> tensor<4096x32xf32>
  %5 = tensor.empty() : tensor<32x1x4096xi32>
  %6 = tensor.empty() : tensor<4096xf32>
  %unpack = tensor.unpack %0 inner_dims_pos = [1, 2] inner_tiles = [1, 32] into %5 : tensor<32x1x128x1x32xi32> -> tensor<32x1x4096xi32>
  %collapsed = tensor.collapse_shape %unpack [[0, 1], [2]] : tensor<32x1x4096xi32> into tensor<32x4096xi32>
  %7 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d1, d0)>, affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"]} ins(%collapsed, %1, %2, %3, %4 : tensor<32x4096xi32>, tensor<32xf32>, tensor<32xf32>, tensor<4096x32xf32>, tensor<4096x32xf32>) outs(%6 : tensor<4096xf32>) {
  ^bb0(%in: i32, %in_0: f32, %in_1: f32, %in_2: f32, %in_3: f32, %out: f32):
    %9 = arith.sitofp %in : i32 to f32
    %10 = arith.mulf %9, %in_0 : f32
    %11 = arith.mulf %10, %in_2 : f32
    %12 = arith.mulf %in_3, %in_2 : f32
    %13 = arith.mulf %12, %in_1 : f32
    %14 = arith.subf %11, %13 : f32
    %15 = arith.addf %14, %out : f32
    linalg.yield %15 : f32
  } -> tensor<4096xf32>
  %8 = hal.tensor.export %7 "output 0" : tensor<4096xf32> -> !hal.buffer_view
  return %8 : !hal.buffer_view
}

// -----// IR Dump After FusionOfTensorOps (iree-flow-fusion-of-tensor-ops) //----- //
func.func @unpack_reduction(%arg0: !hal.buffer_view, %arg1: !hal.buffer_view, %arg2: !hal.buffer_view, %arg3: !hal.buffer_view, %arg4: !hal.buffer_view) -> !hal.buffer_view attributes {iree.abi.stub} {
  %0 = hal.tensor.import %arg0 "input 0" : !hal.buffer_view -> tensor<32x1x128x1x32xi32>
  %1 = hal.tensor.import %arg1 "input 1" : !hal.buffer_view -> tensor<32xf32>
  %2 = hal.tensor.import %arg2 "input 2" : !hal.buffer_view -> tensor<32xf32>
  %3 = hal.tensor.import %arg3 "input 3" : !hal.buffer_view -> tensor<4096x32xf32>
  %4 = hal.tensor.import %arg4 "input 4" : !hal.buffer_view -> tensor<4096x32xf32>
  %5 = tensor.empty() : tensor<32x1x4096xi32>
  %6 = tensor.empty() : tensor<4096xf32>
  %unpack = tensor.unpack %0 inner_dims_pos = [1, 2] inner_tiles = [1, 32] into %5 : tensor<32x1x128x1x32xi32> -> tensor<32x1x4096xi32>
  %expanded = tensor.expand_shape %1 [[0, 1]] : tensor<32xf32> into tensor<32x1xf32>
  %expanded_0 = tensor.expand_shape %2 [[0, 1]] : tensor<32xf32> into tensor<32x1xf32>
  %expanded_1 = tensor.expand_shape %3 [[0], [1, 2]] : tensor<4096x32xf32> into tensor<4096x32x1xf32>
  %expanded_2 = tensor.expand_shape %4 [[0], [1, 2]] : tensor<4096x32xf32> into tensor<4096x32x1xf32>
  %7 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d2, d0)>, affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0)>], iterator_types = ["parallel", "reduction", "reduction"]} ins(%unpack, %expanded, %expanded_0, %expanded_1, %expanded_2 : tensor<32x1x4096xi32>, tensor<32x1xf32>, tensor<32x1xf32>, tensor<4096x32x1xf32>, tensor<4096x32x1xf32>) outs(%6 : tensor<4096xf32>) {
  ^bb0(%in: i32, %in_3: f32, %in_4: f32, %in_5: f32, %in_6: f32, %out: f32):
    %9 = arith.sitofp %in : i32 to f32
    %10 = arith.mulf %9, %in_3 : f32
    %11 = arith.mulf %10, %in_5 : f32
    %12 = arith.mulf %in_6, %in_5 : f32
    %13 = arith.mulf %12, %in_4 : f32
    %14 = arith.subf %11, %13 : f32
    %15 = arith.addf %14, %out : f32
    linalg.yield %15 : f32
  } -> tensor<4096xf32>
  %8 = hal.tensor.export %7 "output 0" : tensor<4096xf32> -> !hal.buffer_view
  return %8 : !hal.buffer_view
}

If I read correctly, you modified the FusionOnTensors logic, so the result of the unpack op can be passed to generic ops?


Thanks for posting the IR snippet. That's right, I modified one of the upstream patterns (FoldWithProducerReshapeOpByExpansion) used in FusionOfTensorOps. Then I also modified the fusion logic in FormDispatchRegions so they can be fused:

// -----// IR Dump After FormDispatchRegions (iree-flow-form-dispatch-regions) //----- //
func.func @unpack_reduction(%arg0: !hal.buffer_view, %arg1: !hal.buffer_view, %arg2: !hal.buffer_view, %arg3: !hal.buffer_view, %arg4: !hal.buffer_view) -> !hal.buffer_view attributes {iree.abi.stub} {
  %0 = hal.tensor.import %arg0 "input 0" : !hal.buffer_view -> tensor<32x1x128x1x32xi32>
  %1 = hal.tensor.import %arg1 "input 1" : !hal.buffer_view -> tensor<32xf32>
  %2 = hal.tensor.import %arg2 "input 2" : !hal.buffer_view -> tensor<32xf32>
  %3 = hal.tensor.import %arg3 "input 3" : !hal.buffer_view -> tensor<4096x32xf32>
  %4 = hal.tensor.import %arg4 "input 4" : !hal.buffer_view -> tensor<4096x32xf32>
  %5 = tensor.empty() : tensor<32x1x4096xi32>
  %6 = tensor.empty() : tensor<4096xf32>
  %expanded = tensor.expand_shape %1 [[0, 1]] : tensor<32xf32> into tensor<32x1xf32>
  %expanded_0 = tensor.expand_shape %2 [[0, 1]] : tensor<32xf32> into tensor<32x1xf32>
  %expanded_1 = tensor.expand_shape %3 [[0], [1, 2]] : tensor<4096x32xf32> into tensor<4096x32x1xf32>
  %expanded_2 = tensor.expand_shape %4 [[0], [1, 2]] : tensor<4096x32xf32> into tensor<4096x32x1xf32>
  %7 = flow.dispatch.region -> (tensor<4096xf32>) {
    %unpack = tensor.unpack %0 inner_dims_pos = [1, 2] inner_tiles = [1, 32] into %5 : tensor<32x1x128x1x32xi32> -> tensor<32x1x4096xi32>
    %9 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d2, d0)>, affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0)>], iterator_types = ["parallel", "reduction", "reduction"]} ins(%unpack, %expanded, %expanded_0, %expanded_1, %expanded_2 : tensor<32x1x4096xi32>, tensor<32x1xf32>, tensor<32x1xf32>, tensor<4096x32x1xf32>, tensor<4096x32x1xf32>) outs(%6 : tensor<4096xf32>) {
    ^bb0(%in: i32, %in_3: f32, %in_4: f32, %in_5: f32, %in_6: f32, %out: f32):
      %10 = arith.sitofp %in : i32 to f32
      %11 = arith.mulf %10, %in_3 : f32
      %12 = arith.mulf %11, %in_5 : f32
      %13 = arith.mulf %in_6, %in_5 : f32
      %14 = arith.mulf %13, %in_4 : f32
      %15 = arith.subf %12, %14 : f32
      %16 = arith.addf %15, %out : f32
      linalg.yield %16 : f32
    } -> tensor<4096xf32>
    flow.return %9 : tensor<4096xf32>
  }
  %8 = hal.tensor.export %7 "output 0" : tensor<4096xf32> -> !hal.buffer_view
  return %8 : !hal.buffer_view
}

Reading through the issue, so leaving comments as I go.

Two things need to change for this to happen:

  1. The tensor.collapse_shape must be propagated through the linalg.generic op. This requires some changes to the FoldWithProducerReshapeOpByExpansion pattern upstream. Right now it doesn't support cases with reduction dimensions, but in cases where the collapsed/expanded dimension is a unit dimension, it is not difficult to support these cases. Unit dimension collapse_shape ops like this will actually be quite common for matvec workloads, because ExpandVectors will create these ops to turn matvec/vecmat into matmul.

It seems better that we just fuse the collapse_shape with the unpack in such cases. Essentially this is about removing spurious unit dimensions. That is where this reshape is coming from. It just gets stuck at the unpack.
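A minimal sketch of what folding this unit-dim collapse_shape into the unpack could look like, written in Python rather than MLIR C++. It assumes static shapes, no outer_dims_perm, and a dropped dimension that is either untiled or tiled by 1; `drop_unit_dim_from_unpack` is a hypothetical helper, not an existing IREE or MLIR API. For the example in this issue, dropping dim 1 turns the 32x1x128x1x32 -> 32x1x4096 unpack into a 32x128x32 -> 32x4096 unpack with inner_dims_pos = [1] and inner_tiles = [32], which is the same 32x4096 shape the collapse_shape produced.

```python
def drop_unit_dim_from_unpack(src_shape, inner_dims_pos, inner_tiles, dest_shape, dim):
    """Hypothetical sketch: remove unit dest dim `dim` from an unpack.

    Assumes no outer_dims_perm, so the source is laid out as the outer
    (tile-count) dims in dest order followed by the inner tiles in
    inner_dims_pos order.
    """
    rank = len(dest_shape)
    if dest_shape[dim] != 1 or src_shape[dim] != 1:
        raise ValueError(f"dim {dim} is not a unit dimension")
    drop_src = {dim}  # the outer dim that corresponds to `dim`
    new_pos, new_tiles = [], []
    for k, (pos, tile) in enumerate(zip(inner_dims_pos, inner_tiles)):
        if pos == dim:
            if tile != 1:
                raise ValueError(f"dim {dim} is tiled by {tile}, not by 1")
            drop_src.add(rank + k)  # the matching unit inner tile dim
        else:
            # Dest dims after `dim` shift down by one.
            new_pos.append(pos - 1 if pos > dim else pos)
            new_tiles.append(tile)
    new_src = [s for i, s in enumerate(src_shape) if i not in drop_src]
    new_dest = [s for i, s in enumerate(dest_shape) if i != dim]
    return new_src, new_pos, new_tiles, new_dest


print(drop_unit_dim_from_unpack([32, 1, 128, 1, 32], [1, 2], [1, 32], [32, 1, 4096], 1))
```

With the unit dimension gone, the generic can consume the unpack result directly and no reshape needs to be propagated through a reduction loop.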

  2. The tensor.unpack must be allowed to fuse with a linalg.generic op that has reduction dimensions. This would require some analysis of which dimensions of the unpacked shape are reduced in the linalg.generic consumer. When the reduced dimension exists in both the packed and unpacked shape (i.e., when the tensor.unpack op does not touch the reduced dimension), then this fusion should be possible. This type of reduction will probably be less common, since the only dimension that is typically left untouched by unpack is the batch dimension, but it will arise from quantized matmul workloads that get reassociated on CPU.

This part makes sense. Codifying this logic shouldn't be too hard.
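One way the check from item 2 could be codified, as a hedged Python sketch: a reduction loop may read an unpack result dimension only if the unpack does not retile that dimension. It additionally treats tiles of size 1 as leaving the layout unchanged, which is an assumption made here rather than something stated in the issue. The consumer's indexing map for the unpack result is encoded as `consumer_map[i] = loop index` (projected permutation and static tiles assumed), and `can_fuse_unpack_with_reduction` is a hypothetical name.

```python
def can_fuse_unpack_with_reduction(consumer_map, iterator_types, inner_dims_pos, inner_tiles):
    """Hypothetical legality check for fusing an unpack into a reducing generic.

    consumer_map[i] is the loop that indexes dimension i of the unpack result.
    """
    # Dims the unpack actually retiles; a tile of size 1 leaves the layout unchanged.
    retiled = {pos for pos, tile in zip(inner_dims_pos, inner_tiles) if tile != 1}
    for result_dim, loop in enumerate(consumer_map):
        # A reduction over a retiled dim would need the whole reduced extent
        # gathered from multiple packed tiles, so reject it.
        if iterator_types[loop] == "reduction" and result_dim in retiled:
            return False
    return True


# Expanded generic from the FusionOfTensorOps dump: (d0, d1, d2) -> (d1, d2, d0).
print(can_fuse_unpack_with_reduction([1, 2, 0], ["parallel", "reduction", "reduction"],
                                     [1, 2], [1, 32]))
```

In this example the only retiled dimension (the 4096 one, tiled by 32) is indexed by the parallel loop d0, so the reduction loops never touch it. A matmul-like consumer that reduces over the dimension being unpacked would be rejected.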

It seems better that we just fuse the collapse_shape with the unpack in such cases. Essentially this is about removing spurious unit dimensions. That is where this reshape is coming from. It just gets stuck at the unpack.

(Flying by) Could we fuse more generic collapse/expand shapes with pack/unpack ops?

It seems better that we just fuse the collapse_shape with the unpack in such cases.
(Flying by) Could we fuse more generic collapse/expand shapes with pack/unpack ops?

How do we codegen a dispatch that has unpack + reshapes? The reshape ops are not tileable, so they will block distribution and tiling, and may result in a huge stack allocation for holding temporary data. If we expect reshape ops to get folded away, why can't that happen before forming dispatches?

Maybe we should teach fusion to move reshape ops across pack/unpack ops, so they can naturally fuse with generic ops?

(I've been thinking about how we can update the program graph better, but I haven't figured it out yet. I need more time to look at the actual graph and prototype. I can prioritize adding folders to pack/unpack ops (#15604) and revisit whether data-layout propagation helps or not.)

It seems better that we just fuse the collapse_shape with the unpack in such cases. Essentially this is about removing spurious unit dimensions. That is where this reshape is coming from. It just gets stuck at the unpack.

(Flying by) Could we fuse more generic collapse/expand shapes with pack/unpack ops?

Maybe... not sure. I haven't thought deeply about it. These unit dimensions we definitely should be able to fuse, because they are just artifacts.

^ Agreed, unit dims should always be handled as well as possible at every layer. Though we should squash as many as we can early on with mechanisms like const-eval/IPO/global hoisting/etc., with our analysis during progressive lowering through stream it's possible for dynamic dimensions to become 1 (and similar) all the way to codegen, so we can't rely on them having been completely eliminated at flow or above.