iree-org / iree

A retargetable MLIR-based machine learning compiler and runtime toolkit.

Home Page: http://iree.dev/

[Flow] Enable reshape propagation through tensor.pad

Max191 opened this issue

When trying to fuse tensor.pad with its producers, reshapes can block fusion unnecessarily. The following IR from VAE is an example of this:

  %168 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1)>, affine_map<(d0, d1, d2, d3) -> (d0, d1)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%166, %167, %expanded_332, %expanded_333 : tensor<32x16x256x256xf32>, tensor<32xf32>, tensor<32x16xf32>, tensor<32x16xf32>) outs(%157 : tensor<32x16x256x256xf32>) {
  ^bb0(%in: f32, %in_618: f32, %in_619: f32, %in_620: f32, %out: f32):
    %384 = arith.divf %in_618, %cst_97 : f32
    %385 = arith.addf %384, %cst_93 : f32
    %386 = math.rsqrt %385 : f32
    %387 = arith.mulf %in, %386 : f32
    %388 = arith.mulf %387, %in_619 : f32
    %389 = arith.addf %388, %in_620 : f32
    %390 = arith.negf %389 : f32
    %391 = math.exp %390 : f32
    %392 = arith.addf %391, %cst_91 : f32
    %393 = arith.divf %cst_91, %392 : f32
    %394 = arith.mulf %393, %389 : f32
    linalg.yield %394 : f32
  } -> tensor<32x16x256x256xf32>
  %collapsed_334 = tensor.collapse_shape %168 [[0, 1], [2], [3]] : tensor<32x16x256x256xf32> into tensor<512x256x256xf32>
  %padded_335 = tensor.pad %collapsed_334 low[0, 1, 1] high[0, 1, 1] {
  ^bb0(%arg3: index, %arg4: index, %arg5: index):
    tensor.yield %cst_64 : f32
  } : tensor<512x256x256xf32> to tensor<512x258x258xf32>

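The tensor.pad only pads dimensions that the tensor.collapse_shape leaves untouched (the collapsed dimension 0 has zero low and high padding), so if the reshape were propagated through the tensor.pad op, the linalg.generic and the tensor.pad could fuse into a single dispatch.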

One way to do this would be to add reshape propagation patterns for tensor.pad (like https://github.com/llvm/llvm-project/blob/af31883341a122a7285e9b4f0a034470024021eb/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp#L922), but it may be tricky to control where these propagations apply.
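A propagation pattern for tensor.pad would first need a legality check: moving a tensor.collapse_shape below a tensor.pad is only straightforward when every collapsed reassociation group has zero padding, in which case the pad amounts carry over to the expanded dimensions unchanged. The following is a minimal Python sketch of that check, assuming static shapes and a conservative rule (any padding on a multi-dimension group bails out). The function names `propagate_collapse_through_pad` and `collapse` are hypothetical illustrations, not an MLIR or IREE API.

```python
from typing import List, Optional, Sequence, Tuple


def propagate_collapse_through_pad(
    src_shape: Sequence[int],
    reassociation: Sequence[Sequence[int]],
    low: Sequence[int],
    high: Sequence[int],
) -> Optional[Tuple[List[int], List[int], List[int]]]:
    """Hypothetical legality check for rewriting pad(collapse(x)) as collapse(pad(x)).

    Returns (expanded_low, expanded_high, expanded_padded_shape), or None when
    a collapsed group is padded and the rewrite is not attempted.
    Assumes static shapes.
    """
    if not (len(reassociation) == len(low) == len(high)):
        raise ValueError("pad rank must match the number of reassociation groups")
    exp_low = [0] * len(src_shape)
    exp_high = [0] * len(src_shape)
    for group, lo, hi in zip(reassociation, low, high):
        if len(group) == 1:
            # Singleton group: the pad maps directly onto the expanded dimension.
            exp_low[group[0]] = lo
            exp_high[group[0]] = hi
        elif lo != 0 or hi != 0:
            # Padding a collapsed group cannot, in general, be expressed as a
            # pad on the expanded dimensions, so conservatively bail out.
            return None
    padded = [d + l + h for d, l, h in zip(src_shape, exp_low, exp_high)]
    return exp_low, exp_high, padded


def collapse(shape: Sequence[int], reassociation: Sequence[Sequence[int]]) -> List[int]:
    """Shape produced by collapsing `shape` with the given reassociation groups."""
    out = []
    for group in reassociation:
        size = 1
        for d in group:
            size *= shape[d]
        out.append(size)
    return out


if __name__ == "__main__":
    # The VAE example from this issue: [[0, 1], [2], [3]], low/high = [0, 1, 1].
    reassoc = [[0, 1], [2], [3]]
    exp_low, exp_high, padded = propagate_collapse_through_pad(
        [32, 16, 256, 256], reassoc, [0, 1, 1], [0, 1, 1]
    )
    print(exp_low, exp_high, padded)
    print(collapse(padded, reassoc))
```

For the IR above, the pad becomes low[0, 0, 1, 1] high[0, 0, 1, 1] on the 32x16x256x256 tensor, and collapsing the resulting 32x16x258x258 tensor yields the same 512x258x258 type as the original tensor.pad, so the linalg.generic and tensor.pad would then sit directly next to each other.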