iree-org / iree

A retargetable MLIR-based machine learning compiler and runtime toolkit.

Home Page: http://iree.dev/


Packing the result of broadcast introduces huge memory allocation

hanhanW opened this issue

#map = affine_map<(d0, d1, d2) -> (d1, d2)>
#map1 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
module {
  util.func public @turbine_llm_mmtfp_3d_8640_3200_f32f16(%arg0: tensor<?x?x3200xf32>, %arg1: tensor<8640x3200xf16>) -> tensor<?x?x8640xf32> {
    %cst = arith.constant 0.000000e+00 : f32
    %c0 = arith.constant 0 : index
    %c1 = arith.constant 1 : index
    %dim = tensor.dim %arg0, %c0 : tensor<?x?x3200xf32>
    %dim_0 = tensor.dim %arg0, %c1 : tensor<?x?x3200xf32>
    %0 = tensor.empty(%dim) : tensor<?x8640x3200xf16>
    %1 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%arg1 : tensor<8640x3200xf16>) outs(%0 : tensor<?x8640x3200xf16>) {
    ^bb0(%in: f16, %out: f16):
      linalg.yield %in : f16
    } -> tensor<?x8640x3200xf16>
    %2 = tensor.empty(%dim, %dim_0) : tensor<?x?x8640xf32>
    %3 = linalg.fill ins(%cst : f32) outs(%2 : tensor<?x?x8640xf32>) -> tensor<?x?x8640xf32>
    %4 = linalg.batch_matmul_transpose_b ins(%arg0, %1 : tensor<?x?x3200xf32>, tensor<?x8640x3200xf16>) outs(%3 : tensor<?x?x8640xf32>) -> tensor<?x?x8640xf32>
    util.return %4 : tensor<?x?x8640xf32>
  }
}

Coming from #17022 and a Discord discussion, we are seeing a pack(broadcast) -> mmt4d pattern. This is bad because we allocate a big buffer for the broadcast -> pack dispatch and pass the result to the mmt4d kernel. What happens today is:

Set encodings on matmul operands:

%bcast = linalg.generic ins(%src) ... // broadcast for batch dimension
%lhs = set_encoding(%original_lhs)
%rhs = set_encoding(%bcast)
%gemm = linalg.batch_matmul ins(%lhs, %rhs) ...

If we write it in a materialized form, it is:

%bcast = linalg.generic ins(%src) ... // broadcast for batch dimension
%lhs = tensor.pack %original_lhs
%rhs = tensor.pack %bcast
%gemm = linalg.batch_mmt4d ins(%lhs, %rhs) ...

Dispatch formation results in:

dispatch {
  %bcast ...
  %rhs = tensor.pack %bcast
  return %rhs
}
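
For the running example, that dispatch materializes the full broadcasted RHS and then packs it, so both buffers scale with the batch dimension (each batch element is 8640x3200xf16, about 55 MB). A rough sketch reusing the names from the IR at the top; the inner_dims_pos/inner_tiles values [16, 1] and the %pack_dest name are illustrative assumptions, not the tiles IREE actually picks for a given target:

%bcast = linalg.generic ... ins(%arg1 : tensor<8640x3200xf16>)
                            outs(%0 : tensor<?x8640x3200xf16>) ...
%pack_dest = tensor.empty(%dim) : tensor<?x540x3200x16x1xf16>
%rhs = tensor.pack %bcast inner_dims_pos = [1, 2] inner_tiles = [16, 1]
    into %pack_dest : tensor<?x8640x3200xf16> -> tensor<?x540x3200x16x1xf16>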

This is why we have the big memory allocation. However, it is not a hard limit for the data-tiling path. What we can do here is set encodings on the source of the broadcast. This allows us to swap the broadcast and the set_encoding/tensor.pack op, which results in:

%packed_src = tensor.pack %src
%rhs = linalg.generic ins(%packed_src) ... // broadcast for batch dimension
%lhs = tensor.pack %original_lhs

%gemm = linalg.batch_mmt4d ins(%lhs, %rhs) ...
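
At the encoding level, before materialization, the same change reads as follows; this is a sketch in the pseudo-IR used above:

%encoded_src = set_encoding(%src)
%rhs = linalg.generic ins(%encoded_src) ... // broadcast for batch dimension
%lhs = set_encoding(%original_lhs)
%gemm = linalg.batch_matmul ins(%lhs, %rhs) ...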

We should be able to make dispatch formation result in

dispatch {
  tensor.pack %src
}

dispatch {
  %rhs = linalg.generic ...
  %gemm = linalg.batch_mmt4d ins(%lhs, %rhs) ...
}
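
With that split, the first dispatch only packs the batch-less 8640x3200xf16 source, i.e. one buffer of roughly 55 MB total instead of one per batch element. A sketch, again with illustrative [16, 1] inner tiles and a placeholder %dest init tensor:

%packed_src = tensor.pack %arg1 inner_dims_pos = [0, 1] inner_tiles = [16, 1]
    into %dest : tensor<8640x3200xf16> -> tensor<540x3200x16x1xf16>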

In this context, the memory allocation is much smaller because we don't allocate it with the batch dimension. The remaining action item is how we codegen the broadcast + batch_mmt4d dispatch. It can be handled like the existing batch_mmt4d codegen: we tile the batch dimension with size=1 and then lower to the mmt4d codegen/ukernel path.

After TileAndFuse with batch_size=1:

for (int i = 0; i < batch_size; i += 1) {
  %lhs_slice = tensor.extract_slice %lhs …
  %rhs_slice = linalg.generic(%rhs_wo_broadcast) … -> tensor<1xN0xK0xN1xK1xf16>
  %res = batch_mmt4d(%lhs_slice, %rhs_slice)
} 

After batch_mmt4d -> mmt4d decomposition:

for (int i = 0; i < batch_size; i += 1) {
  %lhs_slice = tensor.extract_slice %lhs … -> tensor<1xM0xK0xM1xK1xf16>
  %rhs_slice = linalg.generic(%rhs_wo_broadcast) … -> tensor<1xN0xK0xN1xK1xf16>
  %lhs_wo_batch = tensor.extract_slice %lhs_slice … -> tensor<M0xK0xM1xK1xf16>
  %rhs_wo_batch = tensor.extract_slice %rhs_slice … -> tensor<N0xK0xN1xK1xf16>
  %res = mmt4d(%lhs_wo_batch, %rhs_wo_batch)
}
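
For concreteness, here is a hedged MLIR sketch of that decomposed loop. The static shapes (batch=4, M0=2, N0=540, K0=3200, 16x16 / 16x1 inner tiles) and the f32-only element types are made up for illustration; %lhs, %rhs_wo_broadcast, and %acc stand for the packed LHS, the packed broadcast source (%packed_src above), and the zero-filled accumulator. The batch-of-1 slice and the batch-drop slice of the LHS are folded into one rank-reducing extract_slice for brevity:

#map_src  = affine_map<(d0, d1, d2, d3, d4) -> (d1, d2, d3, d4)>
#map_dest = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>

%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c4 = arith.constant 4 : index  // batch size
%result = scf.for %b = %c0 to %c4 step %c1 iter_args(%acc_iter = %acc)
    -> (tensor<4x2x540x16x16xf32>) {
  // One batch of the packed LHS; the unit batch dim is dropped directly.
  %lhs_slice = tensor.extract_slice %lhs[%b, 0, 0, 0, 0] [1, 2, 3200, 16, 1] [1, 1, 1, 1, 1]
      : tensor<4x2x3200x16x1xf32> to tensor<2x3200x16x1xf32>
  // Re-materialize one batch of the RHS from the packed, batch-less source.
  %rhs_init = tensor.empty() : tensor<1x540x3200x16x1xf32>
  %rhs_slice = linalg.generic {indexing_maps = [#map_src, #map_dest],
      iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]}
      ins(%rhs_wo_broadcast : tensor<540x3200x16x1xf32>)
      outs(%rhs_init : tensor<1x540x3200x16x1xf32>) {
  ^bb0(%in: f32, %out: f32):
    linalg.yield %in : f32
  } -> tensor<1x540x3200x16x1xf32>
  %rhs_wo_batch = tensor.extract_slice %rhs_slice[0, 0, 0, 0, 0] [1, 540, 3200, 16, 1] [1, 1, 1, 1, 1]
      : tensor<1x540x3200x16x1xf32> to tensor<540x3200x16x1xf32>
  // Accumulator tile for this batch, updated by a plain mmt4d.
  %acc_slice = tensor.extract_slice %acc_iter[%b, 0, 0, 0, 0] [1, 2, 540, 16, 16] [1, 1, 1, 1, 1]
      : tensor<4x2x540x16x16xf32> to tensor<2x540x16x16xf32>
  %mm = linalg.mmt4d
      ins(%lhs_slice, %rhs_wo_batch : tensor<2x3200x16x1xf32>, tensor<540x3200x16x1xf32>)
      outs(%acc_slice : tensor<2x540x16x16xf32>) -> tensor<2x540x16x16xf32>
  %updated = tensor.insert_slice %mm into %acc_iter[%b, 0, 0, 0, 0] [1, 2, 540, 16, 16] [1, 1, 1, 1, 1]
      : tensor<2x540x16x16xf32> into tensor<4x2x540x16x16xf32>
  scf.yield %updated : tensor<4x2x540x16x16xf32>
}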

With this flow, we should be able to get rid of the huge memory allocation.