Packing the result of broadcast introduces huge memory allocation
hanhanW opened this issue
#map = affine_map<(d0, d1, d2) -> (d1, d2)>
#map1 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
module {
util.func public @turbine_llm_mmtfp_3d_8640_3200_f32f16(%arg0: tensor<?x?x3200xf32>, %arg1: tensor<8640x3200xf16>) -> tensor<?x?x8640xf32> {
%cst = arith.constant 0.000000e+00 : f32
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%dim = tensor.dim %arg0, %c0 : tensor<?x?x3200xf32>
%dim_0 = tensor.dim %arg0, %c1 : tensor<?x?x3200xf32>
%0 = tensor.empty(%dim) : tensor<?x8640x3200xf16>
%1 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%arg1 : tensor<8640x3200xf16>) outs(%0 : tensor<?x8640x3200xf16>) {
^bb0(%in: f16, %out: f16):
linalg.yield %in : f16
} -> tensor<?x8640x3200xf16>
%2 = tensor.empty(%dim, %dim_0) : tensor<?x?x8640xf32>
%3 = linalg.fill ins(%cst : f32) outs(%2 : tensor<?x?x8640xf32>) -> tensor<?x?x8640xf32>
%4 = linalg.batch_matmul_transpose_b ins(%arg0, %1 : tensor<?x?x3200xf32>, tensor<?x8640x3200xf16>) outs(%3 : tensor<?x?x8640xf32>) -> tensor<?x?x8640xf32>
util.return %4 : tensor<?x?x8640xf32>
}
}
Coming from #17022 and a discord discussion, we are seeing a pack(broadcast) -> mmt4d
pattern. This is bad because we allocate a big buffer for the broadcast -> pack
dispatch and then pass the result to the mmt4d
kernel. What's happening today is:
Set encodings on matmul operands:
%bcast = linalg.generic ins(%src) ... // broadcast for batch dimension
%lhs = set_encoding(%original_lhs)
%rhs = set_encoding(%bcast)
%gemm = linalg.batch_matmul ins(%lhs, %rhs) ...
If we write it in a materialized form, it is:
%bcast = linalg.generic ins(%src) ... // broadcast for batch dimension
%lhs = tensor.pack %original_lhs
%rhs = tensor.pack %bcast
%gemm = linalg.batch_mmt4d ins(%lhs, %rhs) ...
Dispatch formation results in:
dispatch {
%bcast ...
%rhs = tensor.pack %bcast
return %rhs
}
This is why we have the big memory allocation. However, it is not a hard limit for the data-tiling path. What we can do here is set encodings on the source of the broadcast instead. This allows us to swap the broadcast
and the set_encoding/tensor.pack
ops, which results in:
%packed_src = tensor.pack %src
%rhs = linalg.generic ins(%packed_src) ... // broadcast for batch dimension
%lhs = tensor.pack %original_lhs
%gemm = linalg.batch_mmt4d ins(%lhs, %rhs) ...
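Before materialization, the same swap can be expressed at the encoding level; a sketch in the same pseudocode notation as above (not actual compiler output):

```mlir
%packed_src = set_encoding(%src)
%rhs = linalg.generic ins(%packed_src) ... // broadcast for batch dimension
%lhs = set_encoding(%original_lhs)
%gemm = linalg.batch_matmul ins(%lhs, %rhs) ...
```

Because set_encoding now consumes %src directly, the pack in its own dispatch only ever touches the un-broadcast source.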
We should be able to make dispatch formation result in
dispatch {
tensor.pack %src
}
dispatch {
%rhs = linalg.generic ...
%gemm = linalg.batch_mmt4d ins(%lhs, %rhs) ...
}
In this context, the memory allocation is much smaller because the allocated buffer no longer carries the batch dimension. The remaining action item is how we codegen the broadcast + batch_mmt4d
dispatch. It can follow the existing batch_mmt4d codegen: we tile the batch dimension with size=1 and leverage the mmt4d codegen/ukernel path.
After TileAndFuse with batch_size=1:
for (int i = 0; i < batch_size; i += 1) {
%lhs_slice = tensor.extract_slice %lhs …
%rhs_slice = linalg.generic(%rhs_wo_broadcast) … -> tensor<1xN0xK0xN1xK1xf16>
%res = batch_mmt4d(%lhs_slice, %rhs_slice)
}
After batch_mmt4d -> mmt4d decomposition:
for (int i = 0; i < batch_size; i += 1) {
%lhs_slice = tensor.extract_slice %lhs … -> tensor<1xM0xK0xM1xK1xf16>
%rhs_slice = linalg.generic(%rhs_wo_broadcast) … -> tensor<1xN0xK0xN1xK1xf16>
%lhs_wo_batch = tensor.extract_slice %lhs_slice … -> tensor<M0xK0xM1xK1xf16>
%rhs_wo_batch = tensor.extract_slice %rhs_slice … -> tensor<N0xK0xN1xK1xf16>
%res = mmt4d(%lhs_wo_batch, %rhs_wo_batch)
}
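Since the broadcast in this example only materializes the batch dimension, the per-iteration %rhs_slice is the packed source itself; assuming cleanup patterns fold the unit-batch slices away, the body should reduce to the following (a hedged sketch in the same pseudocode style, with %rhs_wo_broadcast standing for the packed, un-broadcast source):

```mlir
for (int i = 0; i < batch_size; i += 1) {
  // Slice one batch out of the LHS; the RHS needs no slice because the
  // broadcast only repeated it along the batch dimension.
  %lhs_wo_batch = tensor.extract_slice %lhs … -> tensor<M0xK0xM1xK1xf16>
  %res = mmt4d(%lhs_wo_batch, %rhs_wo_broadcast)
}
```

At that point each iteration works on batch-free tiles, so the mmt4d ukernel path applies directly.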
With this flow, we should be able to get rid of the huge memory allocation.