Handling reshape propagation for attention ops.
MaheshRavishankar opened this issue
Recently we saw a use case for propagating reshapes across attention ops the same way we propagate reshapes across Linalg ops. For now we added a one-off folder pattern (d2ca774) that mimics the end-state, but we should be able to reuse some of the same techniques as we have for Linalg ops.
To provide some context, this is the input IR we are looking at:
%attention = iree_linalg_ext.attention {
indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>,
affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>,
affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>,
affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>]}
ins(%arg0, %arg1, %arg2, %arg3 : tensor<?x?x?xf16>, tensor<?x?x?xf16>, tensor<?x?x?xf16>, f16)
outs(%empty : tensor<?x?x?xf16>) -> tensor<?x?x?xf16>
%split = arith.divsi %d0, %c2 : index
%expanded = tensor.expand_shape %attention [[0, 1], [2], [3]] output_shape [2, %split, %d1, %d4]
: tensor<?x?x?xf16> into tensor<2x?x?x?xf16>
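The semantics of that `tensor.expand_shape` can be modeled in a few lines of Python (an illustrative sketch, not IREE code): each reassociation group lists the expanded dims that must refold, by product, into the corresponding collapsed dim, so `[[0, 1], [2], [3]]` splits `d0` into `2 x %split` and passes the other dims through.

```python
def expand_shape(collapsed_shape, reassociation, output_shape):
    """Model tensor.expand_shape: each reassociation group of expanded
    dims must multiply back to the corresponding collapsed dim."""
    assert len(reassociation) == len(collapsed_shape)
    for group, dim in zip(reassociation, collapsed_shape):
        prod = 1
        for i in group:
            prod *= output_shape[i]
        assert prod == dim, f"group {group} does not refold to {dim}"
    return list(output_shape)

# Example values for the dynamic sizes: d0 = 6, d1 = 4, d4 = 8, split = d0 // 2.
d0, d1, d4 = 6, 4, 8
expanded = expand_shape([d0, d1, d4], [[0, 1], [2], [3]], [2, d0 // 2, d1, d4])
```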
If we increase the dimensionality of the attention op, we could make it generate the expanded output shape directly:
%expanded_arg0 = tensor.expand_shape %arg0 [[0, 1], [2], [3]] ...
%expanded_arg1 = tensor.expand_shape %arg1 [[0, 1], [2], [3]] ...
%expanded_arg2 = tensor.expand_shape %arg2 [[0, 1], [2], [3]] ...
%expanded_empty = tensor.expand_shape %empty [[0, 1], [2], [3]] ...
%expanded_attention = iree_linalg_ext.attention {
indexing_maps = [affine_map<(d0, d00, d1, d2, d3, d4) -> (d0, d00, d1, d2)>,
affine_map<(d0, d00, d1, d2, d3, d4) -> (d0, d00, d3, d2)>,
affine_map<(d0, d00, d1, d2, d3, d4) -> (d0, d00, d3, d4)>,
affine_map<(d0, d00, d1, d2, d3, d4) -> (d0, d00, d1, d4)>]}
ins(%expanded_arg0, %expanded_arg1, %expanded_arg2, %arg3 : tensor<2x?x?x?xf16>, tensor<2x?x?x?xf16>, tensor<2x?x?x?xf16>, f16)
outs(%expanded_empty : tensor<2x?x?x?xf16>) -> tensor<2x?x?x?xf16>
This is essentially similar to what the foldReshapeByExpansion transformation does for Linalg ops. For now we can borrow a lot from the implementation there and replicate it in IREE to be able to apply it to LinalgExt ops. (The pie-in-the-sky goal for LinalgExt ops is to move them into MLIR, but that's for a later time.) The load-bearing piece in the implementation for LinalgOps is the ExpansionInfo. It takes the reassociation maps of the consumer expand_shape operation (as well as the source collapsed and expanded shapes). This information is then used to compute:
(a) whether the op is expandable (here)
(b) the indexing map in the expanded op for every indexing map in the original op (here)
(c) the type of each operand in the expanded op, given the operand's type in the original op and the indexing map used to access it (here)
(d) the reassociation indices to use for the expand_shape ops that have to be generated with the original operands of the attention op as sources (here)
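Pieces (a)-(d) above can be sketched in Python (names and structure are illustrative, not the actual ExpansionInfo API; since the attention maps above are projected permutations, an indexing map is modeled as a list of loop-dim indices):

```python
def expand_info(num_loops, result_map, reassoc, expanded_shape):
    # (a) expandable here means: only loop dims that appear in the result
    #     indexing map are split; the other (reduction) dims stay untouched.
    group_sizes = {d: None for d in range(num_loops)}
    for group, d in zip(reassoc, result_map):
        group_sizes[d] = [expanded_shape[i] for i in group]
    # Assign contiguous expanded-dim numbers preserving the original loop
    # order (this is how d0 becomes (d0, d00) ahead of d1..d4).
    new_dims, next_dim = {}, 0
    for d in range(num_loops):
        n = len(group_sizes[d]) if group_sizes[d] is not None else 1
        new_dims[d] = list(range(next_dim, next_dim + n))
        next_dim += n
    return new_dims, group_sizes

def expand_map(indexing_map, new_dims):
    # (b) each original dim in a map is replaced by its expanded dims.
    return [nd for d in indexing_map for nd in new_dims[d]]

def expand_operand(indexing_map, operand_shape, new_dims, group_sizes):
    # (c) the expanded operand shape, and (d) the reassociation indices for
    #     the expand_shape generated with the original operand as source.
    shape, reassoc, pos = [], [], 0
    for d, size in zip(indexing_map, operand_shape):
        n = len(new_dims[d])
        reassoc.append(list(range(pos, pos + n)))
        shape.extend(group_sizes[d] if group_sizes[d] is not None else [size])
        pos += n
    return shape, reassoc

# Attention example: 5 loops d0..d4, result map (d0, d1, d4) = [0, 1, 4],
# consumer reassociation [[0, 1], [2], [3]] with (example) expanded result
# shape [2, 3, 4, 8], i.e. D0 = 6 split as 2 x 3, D1 = 4, D4 = 8.
new_dims, group_sizes = expand_info(5, [0, 1, 4], [[0, 1], [2], [3]], [2, 3, 4, 8])
q_map = expand_map([0, 1, 2], new_dims)  # Q map (d0, d1, d2)
q_shape, q_reassoc = expand_operand([0, 1, 2], [6, 4, 5], new_dims, group_sizes)
```

On the example above this yields the expanded Q map (d0, d00, d1, d2), the expanded Q shape [2, 3, 4, 5], and reassociation [[0, 1], [2], [3]] for the expand_shape feeding the expanded attention op.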
This logic needs to be replicated in IREE (for now) and used to generate the expanded attention op, the same way a LinalgOp is expanded to higher dimensions here.
Once this expansion is done, it unlocks more fusion opportunities. For example, after the reshape is propagated "up" through the attention op, the op can more easily fuse with the transpose operation here.
cc @Groverkss FYI since you were interested.