Allow eliding unused results from dispatches as canonicalizations
MaheshRavishankar opened this issue
Computations like argmax, when lowered to Linalg, need two result values: one tracking the maximum value and one tracking its position. Even if the maximum value is not used later on, it is needed during the Linalg computation. The dispatch-region canonicalizations remove results that are not used, so for a dispatch region with an argmax where the max value isn't used, there is only one result buffer. Bufferization then ends up creating a new stack allocation to track the maximum value, which leads to a stack overflow (see #8411 (comment)).
The hal.executable for one such case is shown below.
func @main_dispatch_89() {
  %cst = arith.constant -3.40282347E+38 : f32
  %c0_i32 = arith.constant 0 : i32
  %c0 = arith.constant 0 : index
  %c2097152 = arith.constant 2097152 : index
  %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) offset(%c2097152) alignment(64) : !flow.dispatch.tensor<readonly:512x512x32xf32>
  %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) offset(%c0) alignment(64) : !flow.dispatch.tensor<writeonly:512x512xi32>
  %2 = flow.dispatch.tensor.load %0, offsets = [0, 0, 0], sizes = [512, 512, 32], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:512x512x32xf32> -> tensor<512x512x32xf32>
  %3 = linalg.init_tensor [512, 512] : tensor<512x512xf32>
  %4 = linalg.init_tensor [512, 512] : tensor<512x512xi32>
  %5 = linalg.fill ins(%c0_i32 : i32) outs(%4 : tensor<512x512xi32>) -> tensor<512x512xi32>
  %6 = linalg.fill ins(%cst : f32) outs(%3 : tensor<512x512xf32>) -> tensor<512x512xf32>
  %7:2 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%2 : tensor<512x512x32xf32>) outs(%5, %6 : tensor<512x512xi32>, tensor<512x512xf32>) {
  ^bb0(%arg0: f32, %arg1: i32, %arg2: f32):
    %8 = linalg.index 2 : index
    %9 = arith.index_cast %8 : index to i32
    %10 = arith.cmpf ogt, %arg0, %arg2 : f32
    %11 = arith.select %10, %arg0, %arg2 : f32
    %12 = arith.select %10, %9, %arg1 : i32
    linalg.yield %12, %11 : i32, f32
  } -> (tensor<512x512xi32>, tensor<512x512xf32>)
  flow.dispatch.tensor.store %7#0, %1, offsets = [0, 0], sizes = [512, 512], strides = [1, 1] : tensor<512x512xi32> -> !flow.dispatch.tensor<writeonly:512x512xi32>
  return
}
Since %7#1 is not a result of the dispatch, there is no flow.dispatch.tensor.store to write that result to, which leads to stack allocations after bufferization.
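To make the failure mode concrete, here is a hand-written sketch (not actual compiler output; the buffer names and elided operands are illustrative) of the shape bufferization produces when the max-value result has no output binding:

```mlir
// Hypothetical bufferized form: only the index result has a binding, so the
// f32 accumulator is materialized as a function-local allocation.
%out_idx = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) offset(%c0) alignment(64) : memref<512x512xi32>
// 512 x 512 x 4 bytes = 1 MiB of stack space for the elided max-value result.
%scratch_max = memref.alloca() : memref<512x512xf32>
linalg.generic { /* same maps and iterator types as above */ }
    ins(%in : memref<512x512x32xf32>)
    outs(%out_idx, %scratch_max : memref<512x512xi32>, memref<512x512xf32>) {
  // same body: the index update still reads and writes the max accumulator
}
```

A 1 MiB memref.alloca can easily exceed the small stacks used for dispatch execution, which matches the overflow reported in #8411.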
Dropping the elision of unused results from dispatch regions solves the issue, but it is more of a workaround (WAR) w.r.t. the current lowering of argmax-like operations to Linalg. A better representation (or a new LinalgExt op) might address this issue in the future. For now, the WAR is used to address the compilation failures.
This issue tracks re-enabling the canonicalization once the backends can handle the computation without requiring stack allocations.
https://reviews.llvm.org/D123632 is an attempt to make linalg.generic ops drop results that are unused, but it doesn't address the specific case here.
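A likely reason that pattern does not help here (my reading, not stated in the review): a result can only be dropped when its accumulator block argument is otherwise dead in the body, but in this generic op the index update reads the running max. Repeating the relevant body lines from the dispatch above:

```mlir
^bb0(%arg0: f32, %arg1: i32, %arg2: f32):
  // %arg2 is the max accumulator for the "unused" f32 result...
  %10 = arith.cmpf ogt, %arg0, %arg2 : f32
  %11 = arith.select %10, %arg0, %arg2 : f32
  // ...and the index result depends on the comparison against it,
  // so the f32 accumulator cannot be elided from the op.
  %12 = arith.select %10, %9, %arg1 : i32
  linalg.yield %12, %11 : i32, f32
```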
Hey folks, this seems to block running a key BERT workload from Torch-MLIR. Can we prioritize it?
What's blocking here? Can you give more details?
All backends hit the same issue. We allow bounded allocations in the CPU and CUDA backends, but that cannot simply be applied to the VMVX backend, because VMVX explicitly disallows alloca ops. To make it work on VMVX, we may have to address this issue.
@MaheshRavishankar see #9156 and #9155
Closing this particular issue. It was meant to track a canonicalization that was supposed to be turned off; it isn't being turned off anymore :)