LinalgExt ops don't support fusion
IanWood1 opened this issue · comments
IREE currently lacks support for fusing LinalgExt operations with either other LinalgExt operations or standard Linalg operations. The current fusion implementation relies on indexing maps to determine which operations can be successfully fused. See the example below:
%3 = tensor.empty() : tensor<4x1xi32>
%4 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel"]} ins(%expanded : tensor<4x1xi64>) outs(%3 : tensor<4x1xi32>) {
^bb0(%in: i64, %out: i32):
%10 = arith.trunci %in : i64 to i32
linalg.yield %10 : i32
} -> tensor<4x1xi32>
%5 = iree_linalg_ext.scatter dimension_map = [0] unique_indices(false) ins(%expanded_0, %4 : tensor<4x1x16x8x128xf32>, tensor<4x1xi32>) outs(%2 : tensor<8192x16x8x128xf32>) {
^bb0(%arg5: f32, %arg6: f32):
iree_linalg_ext.yield %arg5 : f32
} -> tensor<8192x16x8x128xf32>
This results in the following dispatch formation (note no fusion):
%3 = tensor.empty() : tensor<4x1xi32>
%4 = flow.dispatch.region -> (tensor<4x1xi32>) {
%10 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%expanded_0 : tensor<4x1xi64>) outs(%3 : tensor<4x1xi32>) {
^bb0(%in: i64, %out: i32):
%11 = arith.trunci %in : i64 to i32
linalg.yield %11 : i32
} -> tensor<4x1xi32>
flow.return %10 : tensor<4x1xi32>
}
%5 = flow.dispatch.region -> (tensor<8192x16x8x128xf32>) {
%10 = iree_linalg_ext.scatter dimension_map = [0] unique_indices(false) ins(%expanded, %4 : tensor<4x1x16x8x128xf32>, tensor<4x1xi32>) outs(%2 : tensor<8192x16x8x128xf32>) {
^bb0(%arg5: f32, %arg6: f32):
iree_linalg_ext.yield %arg5 : f32
} -> tensor<8192x16x8x128xf32>
flow.return %10 : tensor<8192x16x8x128xf32>
}
Immediate solution
Add functionality to FormDispatchRegions to get indexing maps for specific LinalgExt ops. This would be a quick and easy way to make those ops fusable, and it would lay the groundwork for a long-term solution.
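As a rough illustration of the idea (not IREE code), the Python sketch below models a fusion check that bails out when an op has no indexing maps, and shows how a per-op lookup table for LinalgExt ops would change the outcome for the trunci/scatter pair above. Everything here is hypothetical: the names `get_indexing_maps`, `can_fuse`, and `EXT_OP_MAPS`, the identity maps assumed for scatter's updates and indices operands, and the "both maps are identity" rule, which is a deliberate simplification of whatever the real pass checks.

```python
from typing import Dict, List, Optional, Tuple

# Toy affine map: entry i is the iteration dimension read by result i,
# so (d0, d1) -> (d0, d1) is written (0, 1).
ToyMap = Tuple[int, ...]

# Ops that already carry indexing maps, as linalg.generic does.
# One map per operand: inputs first, then outputs.
GENERIC_MAPS: Dict[str, List[Optional[ToyMap]]] = {
    "trunci_generic": [(0, 1), (0, 1)],
}

# HYPOTHETICAL per-op table for LinalgExt ops. None means "no map for this
# operand". ASSUMPTION: the identity maps for scatter's updates and indices
# operands are placeholders chosen for illustration only.
EXT_OP_MAPS: Dict[str, List[Optional[ToyMap]]] = {
    "scatter": [(0, 1, 2, 3, 4), (0, 1), None],  # updates, indices, output
}


def get_indexing_maps(op_name: str, use_ext_table: bool) -> Optional[List[Optional[ToyMap]]]:
    """Return the op's indexing maps, or None if the op exposes none."""
    if op_name in GENERIC_MAPS:
        return list(GENERIC_MAPS[op_name])
    if use_ext_table and op_name in EXT_OP_MAPS:
        return list(EXT_OP_MAPS[op_name])
    return None


def is_identity(m: Optional[ToyMap]) -> bool:
    return m is not None and m == tuple(range(len(m)))


def can_fuse(producer: str, consumer: str, consumer_operand: int, use_ext_table: bool) -> bool:
    """Toy rule: fuse only if both ops expose maps and the producer's result
    map and the consumer's map for the fused operand are both identity."""
    producer_maps = get_indexing_maps(producer, use_ext_table)
    consumer_maps = get_indexing_maps(consumer, use_ext_table)
    if producer_maps is None or consumer_maps is None:
        return False  # no indexing maps, so the check cannot reason about the op
    return is_identity(producer_maps[-1]) and is_identity(consumer_maps[consumer_operand])


# Operand 1 of the scatter is the indices tensor produced by the generic op.
print(can_fuse("trunci_generic", "scatter", 1, use_ext_table=False))  # False
print(can_fuse("trunci_generic", "scatter", 1, use_ext_table=True))   # True
```

Without the table, the scatter op is opaque to the check and the two ops stay in separate dispatches, as in the listing above. With the table, the same check can treat the scatter like any other op that has indexing maps.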
Long term solution
Implement. Add to LinalgInterfaces for LinalgExt ops TilingInterface
- ...
include "mlir/Interfaces/DestinationStyleOpInterface.td"
Linalg TilingInterfaceImpl
LinalgExt TilingInterfaceImpl
Thanks @IanWood1 for capturing this. Just to amend your long term solution: we will probably end up adding this to TilingInterface, which already has the notion of iteration spaces. TBD though.