[Flow] Enable reshape propagation through tensor.pad
Max191 opened this issue
When trying to fuse `tensor.pad` with its producers, reshapes can unnecessarily block fusion. The following IR, taken from VAE, is an example:
```mlir
%168 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1)>, affine_map<(d0, d1, d2, d3) -> (d0, d1)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%166, %167, %expanded_332, %expanded_333 : tensor<32x16x256x256xf32>, tensor<32xf32>, tensor<32x16xf32>, tensor<32x16xf32>) outs(%157 : tensor<32x16x256x256xf32>) {
^bb0(%in: f32, %in_618: f32, %in_619: f32, %in_620: f32, %out: f32):
  %384 = arith.divf %in_618, %cst_97 : f32
  %385 = arith.addf %384, %cst_93 : f32
  %386 = math.rsqrt %385 : f32
  %387 = arith.mulf %in, %386 : f32
  %388 = arith.mulf %387, %in_619 : f32
  %389 = arith.addf %388, %in_620 : f32
  %390 = arith.negf %389 : f32
  %391 = math.exp %390 : f32
  %392 = arith.addf %391, %cst_91 : f32
  %393 = arith.divf %cst_91, %392 : f32
  %394 = arith.mulf %393, %389 : f32
  linalg.yield %394 : f32
} -> tensor<32x16x256x256xf32>
%collapsed_334 = tensor.collapse_shape %168 [[0, 1], [2], [3]] : tensor<32x16x256x256xf32> into tensor<512x256x256xf32>
%padded_335 = tensor.pad %collapsed_334 low[0, 1, 1] high[0, 1, 1] {
^bb0(%arg3: index, %arg4: index, %arg5: index):
  tensor.yield %cst_64 : f32
} : tensor<512x256x256xf32> to tensor<512x258x258xf32>
```
The `tensor.pad` does not touch the collapsed dimension (the low and high padding on the leading `512` dim are both zero), so if the reshape were propagated through the `tensor.pad` op, the two ops could fuse into a dispatch.
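For illustration, a sketch of what the propagated form could look like (SSA names here are illustrative, not output of an actual transformation): the pad is applied to the 4-D `linalg.generic` result directly, with its padding amounts expanded to the unreassociated dims, and the collapse moves below it:

```mlir
// Hypothetical result of propagating the collapse_shape below the pad.
// The pad now consumes %168 directly, so the two can fuse into one dispatch;
// the collapse_shape becomes a metadata-only op on the padded result.
%padded = tensor.pad %168 low[0, 0, 1, 1] high[0, 0, 1, 1] {
^bb0(%arg3: index, %arg4: index, %arg5: index, %arg6: index):
  tensor.yield %cst_64 : f32
} : tensor<32x16x256x256xf32> to tensor<32x16x258x258xf32>
%collapsed = tensor.collapse_shape %padded [[0, 1], [2], [3]]
    : tensor<32x16x258x258xf32> into tensor<512x258x258xf32>
```

This rewrite is only valid because every reassociation group of size greater than one carries zero padding; a pad on a collapsed dimension could not be expressed on the expanded tensor.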
One way to do this would be to add reshape propagation patterns for `tensor.pad` (like the ones for `linalg.generic` in https://github.com/llvm/llvm-project/blob/af31883341a122a7285e9b4f0a034470024021eb/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp#L922), but it may be tricky to manage when the propagations should apply.