openxla / iree

A retargetable MLIR-based machine learning compiler and runtime toolkit.

Home Page: http://iree.dev/

Permute elementwise operations to make fusions better

MaheshRavishankar opened this issue

After elementwise fusion, the BERT model has quite a few cases of the following pattern:

%gemm = linalg.matmul
%generic = linalg.generic {
    indexing_maps = [affine_map<(d0, d1) -> (d1, d0)>, affine_map<(d0, d1) -> (d0, d1)>],
    iterator_types = ["parallel", "parallel"]}
    ins(%gemm, ... : ) outs(...) {...}

These don't get fused into a dispatch region, because the condition for tile + fuse is that the consumer accesses the result of the gemm through an identity indexing map (i.e. affine_map<(d0, d1) -> (d0, d1)>). Here the consumer uses affine_map<(d0, d1) -> (d1, d0)>, which is a permutation.

However, the following sequence is semantically equivalent to the above:

%gemm = linalg.matmul
%generic = linalg.generic {
    indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d1, d0)>],
    iterator_types = ["parallel", "parallel"]}
    ins(%gemm, ... : ) outs(...) {...}

In this form the ops will get tiled and fused. This is a simple loop interchange at the linalg level that could be done as a pre-processing step after elementwise fusion and before dispatch region formation. The transformation that implements the interchange already exists here:
https://github.com/llvm/llvm-project/blob/3e678cb77264907fbc2899c291ce23af308073ff/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h#L239
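
To make the equivalence concrete, here is a minimal Python sketch that models the two generic ops as plain loop nests over nested lists. It is only an illustration, not how linalg implements the interchange; the `body` payload and both function names are hypothetical stand-ins for the elided parts of the snippets above.

```python
# Toy model of the two linalg.generic forms, using nested lists as tensors.
# `body` is a hypothetical elementwise payload; the real one is elided above.

def body(x):
    return x + 1.0

def permuted_input_form(gemm):
    """Input read via (d0, d1) -> (d1, d0); output written via identity."""
    m, n = len(gemm), len(gemm[0])
    out = [[0.0] * m for _ in range(n)]
    for d0 in range(n):
        for d1 in range(m):
            out[d0][d1] = body(gemm[d1][d0])
    return out

def identity_input_form(gemm):
    """Input read via identity; output written via (d0, d1) -> (d1, d0)."""
    m, n = len(gemm), len(gemm[0])
    out = [[0.0] * m for _ in range(n)]
    for d0 in range(m):
        for d1 in range(n):
            out[d1][d0] = body(gemm[d0][d1])
    return out

gemm = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]
# Both loop nests produce the same transposed result.
assert permuted_input_form(gemm) == identity_input_form(gemm)
```

The only difference between the two functions is which loop index walks the rows of `gemm`, which is exactly what the interchange changes.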

The transformation/pass has to apply the interchange to generic ops that
a) have a gemm/convolution/reduction operation as a producer, and
b) access the result of that producer through a permutation instead of an identity.

By my count, this will reduce the number of dispatches by 12 in the MiniLM model.
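
The two conditions, and the effect of the interchange on the indexing maps, can be sketched in Python as follows. This is a simplified model under stated assumptions, not the MLIR API: an indexing map is represented as a tuple of loop-dimension indices (so `(1, 0)` stands for `(d0, d1) -> (d1, d0)`), the producer kinds are hypothetical string labels, and iterator types are ignored because all loops here are parallel.

```python
# Hypothetical labels for the producer kinds named in condition (a).
FUSABLE_PRODUCERS = {"matmul", "conv", "reduction"}

def is_permutation(m):
    return sorted(m) == list(range(len(m)))

def is_identity(m):
    return list(m) == list(range(len(m)))

def should_interchange(producer_kind, producer_operand_map):
    """Conditions (a) and (b): fusable producer, read via a non-identity permutation."""
    return (producer_kind in FUSABLE_PRODUCERS
            and is_permutation(producer_operand_map)
            and not is_identity(producer_operand_map))

def invert(m):
    """Inverse permutation: maps each loop dim to the result position using it."""
    inv = [0] * len(m)
    for result_pos, dim in enumerate(m):
        inv[dim] = result_pos
    return tuple(inv)

def interchange_maps(maps, producer_operand_index):
    """Relabel loop dims so the producer operand's map becomes the identity."""
    relabel = invert(maps[producer_operand_index])
    return [tuple(relabel[d] for d in m) for m in maps]

# The pattern from this issue: input map (d1, d0), output map (d0, d1).
maps = [(1, 0), (0, 1)]
if should_interchange("matmul", maps[0]):
    maps = interchange_maps(maps, 0)
# maps is now [(0, 1), (1, 0)]: identity on the input, permutation on the output.
```

Relabeling with the inverse of the producer operand's map is one way to pick the interchange; every other map in the op is rewritten with the same relabeling, so the permutation moves onto the output.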

There is another condition for fusing matmul + genericOp: the input and output indexing maps should be identical. The matmul output is 128x384, and the genericOp adds a constant and transposes it to 384x128. There is a comment about bufferization and vectorization. But from 10,000 ft, with an additional buffer for the 384x128 result, the loop should be able to transpose the matmul output into the new buffer.
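
As a rough illustration of that 10,000 ft view, the Python sketch below computes the matmul tile by tile and writes each element, plus the constant, straight into a transposed output buffer. It is a toy model only: the function name, the tile sizes, and the scalar inner loop are assumptions for illustration, and it says nothing about how bufferization or vectorization would handle this in practice.

```python
def fused_matmul_add_transpose(a, b, constant, tile_m=2, tile_n=2):
    """Tiled matmul whose consumer (add constant + transpose) is fused in."""
    m, k, n = len(a), len(b), len(b[0])
    # The additional buffer holding the transposed N x M result.
    out = [[0.0] * m for _ in range(n)]
    for i0 in range(0, m, tile_m):          # tile loops over the matmul output
        for j0 in range(0, n, tile_n):
            for i in range(i0, min(i0 + tile_m, m)):
                for j in range(j0, min(j0 + tile_n, n)):
                    acc = 0.0
                    for p in range(k):
                        acc += a[i][p] * b[p][j]
                    # Identity access to the matmul result, permuted write.
                    out[j][i] = acc + constant
    return out

a = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]   # 3 x 2
b = [[1.0, 0.0, 2.0], [0.0, 1.0, 3.0]]     # 2 x 3
result = fused_matmul_add_transpose(a, b, 10.0)
```

The matmul loops keep their original order; only the store into the output buffer is permuted.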

Yeah, it would be good to see what happens in practice. On the CPU side all the dispatches are now vectorized, so I would expect there to be no additional stacks. The CUDA side works slightly differently; we might get some help to address those as well. Not sure about the Vulkan side. That can be tricky. We can make the change and see what the fallout on the backends will be. At the very least we can guard these changes at the flow level with a flag until all backends can handle them being on by default.

It looks like we need some checking for Vulkan:
//iree/test/e2e/models:check_vulkan-spirv_vulkan_bert_encoder_unrolled_fake_weights.mlir FAILED TO BUILD

For the record, I discussed the issue in more detail with Mahesh. The reason we want to swap the input and output indexing maps of the genericOp is that keeping the input indexing map as the identity leaves the matmul code as is. A permuted indexing map on a genericOp input would force the matmul code to be transposed, which we don't want.

The main work for the fusion logic is done by #9103.

There are remaining issues for the ARM (with mmt4d) and CUDA backends; they will be handled separately.