iree-org / iree

A retargetable MLIR-based machine learning compiler and runtime toolkit.

Home page: http://iree.dev/

Add a lowering of `vector.interleave` to `vector.shuffle`

bjacob opened this issue.

In LLVM integrate #17330, we have to locally revert llvm/llvm-project#89131 because it causes vector.interleave to be created instead of vector.shuffle, and some GPU codegen backends expect vector.shuffle and do not handle vector.interleave.

llvm/llvm-project#89131 is good in itself though, as vector.interleave is more constrained than general vector.shuffle. We just need a lowering pattern from vector.interleave to vector.shuffle to be inserted into codegen pipelines. Then we will be able to drop the local revert of llvm/llvm-project#89131.
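For 1-D vectors, such a lowering mostly reduces to computing a shuffle mask: vector.shuffle indexes into the concatenation of its two operands, so interleaving two length-n vectors selects lhs[i] and then rhs[i] for each i. The following is a minimal sketch of that index arithmetic in plain C++, not against MLIR's pattern APIs; the helpers `interleaveToShuffleMask` and `applyShuffle` are hypothetical and assume both operands have the same length n.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical helper: mask for rewriting a 1-D `vector.interleave %lhs, %rhs`
// (both of length n) as `vector.shuffle %lhs, %rhs [mask]`. Indices in [0, n)
// select from lhs; indices in [n, 2n) select from rhs.
std::vector<int64_t> interleaveToShuffleMask(int64_t n) {
  std::vector<int64_t> mask;
  mask.reserve(2 * n);
  for (int64_t i = 0; i < n; ++i) {
    mask.push_back(i);      // i-th element of lhs
    mask.push_back(n + i);  // i-th element of rhs
  }
  return mask;
}

// Hypothetical reference model of a 1-D shuffle, restricted here to two
// operands of equal length (vector.shuffle itself is more general).
template <typename T>
std::vector<T> applyShuffle(const std::vector<T> &lhs, const std::vector<T> &rhs,
                            const std::vector<int64_t> &mask) {
  assert(lhs.size() == rhs.size());
  const int64_t n = static_cast<int64_t>(lhs.size());
  std::vector<T> result;
  result.reserve(mask.size());
  for (int64_t idx : mask) {
    assert(idx >= 0 && idx < 2 * n);
    result.push_back(idx < n ? lhs[idx] : rhs[idx - n]);
  }
  return result;
}
```

For n = 4 this yields the mask [0, 4, 1, 5, 2, 6, 3, 7]. An actual rewrite pattern would presumably derive n from the operand's vector type and emit the shuffle with this mask.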

FYI @KoolJBlack @qedawkins @kuhar

Is this an MLIR issue or an IREE issue? To me it sounds like this should go into vector-to-spirv and vector-to-llvm?

I was thinking vector-to-vector: rewriting vector.interleave to vector.shuffle. This way only a single, backend-independent pattern is needed, and by construction we know current backends are happy with vector.shuffle, since that is what they currently get. It can go in this existing file:

https://github.com/llvm/llvm-project/blob/3bde7983986d8ce637f6bb506860859249787751/mlir/lib/Dialect/Vector/Transforms/LowerVectorInterleave.cpp#L4

Then there would also be an IREE-side change to insert that new pattern into codegen pipelines, and, also on the IREE side, we need to remember to drop the revert of llvm/llvm-project#89131 at the following LLVM integrate.

Oh OK, so this pattern is already there, and we just need to add it to IREE pipelines. Makes sense.

No, no: the file is there, but it only contains an unrelated UnrollInterleaveOp pattern. The pattern we need here does not exist yet; it needs to be created.

On the SPIR-V side, I don't think there's any better lowering we could use anyway; SPIR-V has its own shuffle ops.

@kuhar @qedawkins, does this look like what we discussed? llvm/llvm-project#91800
Next, I'm looking for where to put this in the SPIR-V pipeline on the IREE side. Maybe around here:

```cpp
  void runOnOperation() override {
    MLIRContext *context = &getContext();
    auto funcOp = getOperation();
    bool emitIntegerDotProdOps = supportsIntegerDotProductOps(funcOp);
    // First apply vectorization to generate vectors of the original tensor
    // shape for tensor.pad ops.
    {
      RewritePatternSet patterns(context);
      // Pull in additional vectorization patterns in IREE.
      populateVectorizePadPatterns(patterns);
      if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) {
        return signalPassFailure();
      }
    }
    debugPrint(funcOp, "after vectorizing tensor.pad");
    // Special peephole optimizations to clean up IR before further processing.
    {
      RewritePatternSet patterns(context);
      // Pull in patterns to shuffle broadcast/transpose ops around in order to
      // cancel them or embed into contract ops. Embedding in the flexible
      // contract ops will help to sustain the structure through various
      // transformations.
      vector::populateVectorReductionToContractPatterns(patterns);
      // Pull in patterns to canonicalize transfer ops.
      vector::populateVectorTransferPermutationMapLoweringPatterns(patterns);
      // Fold consumer add ops into the contraction op itself.
      vector::ContractionOp::getCanonicalizationPatterns(patterns, context);
      // Fold transpose ops if possible as we cannot unroll them later.
      vector::TransposeOp::getCanonicalizationPatterns(patterns, context);
      if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) {
        return signalPassFailure();
      }
    }
    debugPrint(funcOp, "after peephole optimization");
    // High dimension contraction can appear after vectorizing ops like 1-D
    // convolution. Those 1-D convolution ops typically have a leading unit
    // batch dimension. Try to drop that to map to matmul dimensions better.
    SmallVector<vector::ContractionOp> contractOps;
    funcOp.walk([&](vector::ContractionOp op) {
      if (op.getIteratorTypes().size() > 3)
        contractOps.push_back(op);
    });
    for (vector::ContractionOp op : contractOps) {
      OpBuilder builder(op);
      IRRewriter rewriter(builder);
      auto result = vector::castAwayContractionLeadingOneDim(
          op, /*maskingOp=*/nullptr, rewriter);
      if (succeeded(result)) {
        rewriter.replaceOp(op, *result);
      }
    }
    debugPrint(funcOp, "after trimming contract leading unit dims");
    // Fold tensor.extract_slice/insert_slice ops into transfer ops. This helps
    // to remove those tensor slice ops so that we can enable further vector op
    // transformations.
    {
      RewritePatternSet patterns(context);
      vector::TransferReadOp::getCanonicalizationPatterns(patterns, context);
      vector::TransferWriteOp::getCanonicalizationPatterns(patterns, context);
      populateVectorTransferTensorSliceTransforms(patterns);
      if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) {
        return signalPassFailure();
      }
    }
    debugPrint(funcOp, "after folding tensor extract/insert slice ops");
    // Lower vector.multi_reduction early if any operand is a transpose op.
    // The lowering itself generates transpose ops. This helps to cancel
    // transpose ops. vector.multi_reduction is arguably a higher level op and
    // the lowering also unrolls the multi_reduction op, so it makes sense to
    // happen before normal unrolling.
    {
      SmallVector<Operation *> reductionOps;
      funcOp.walk([&](vector::MultiDimReductionOp reductionOp) {
        if (llvm::any_of(reductionOp->getOperands(), [](Value operand) {
              return operand.getDefiningOp<vector::TransposeOp>();
            }))
          reductionOps.push_back(reductionOp);
        return WalkResult::advance();
      });
      RewritePatternSet patterns(context);
      vector::populateVectorMultiReductionLoweringPatterns(
          patterns, vector::VectorMultiReductionLowering::InnerParallel);
      if (failed(applyOpPatternsAndFold(reductionOps, std::move(patterns)))) {
        funcOp.emitOpError("vector lowering failed");
        return signalPassFailure();
      }
    }
    debugPrint(funcOp, "after lowering multi reduction ops");
    // Prepare for SPIR-V integer dot product lowering.
    if (emitIntegerDotProdOps) {
      RewritePatternSet patterns(context);
      vector::populateVectorContractCanonicalizeMatmulToMMT(
          patterns, detectI8ToI32Matmul);
      if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) {
        return signalPassFailure();
      }
      debugPrint(funcOp, "after preparing for SPIR-V dot product lowering");
    }
    // Then unroll vectors to native vector size. We try to use 128-bit
    // vectors for memory access and 4/2/1 vector sizes for computation.
    {
      RewritePatternSet patterns(context);
      populateVectorUnrollPatterns(patterns, emitIntegerDotProdOps);
      if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) {
        return signalPassFailure();
      }
    }
    debugPrint(funcOp, "after unrolling vector ops");
    // Lower reduction-unrolled vector contract ops. All reduction dimensions
    // of such contract ops are one, so we can convert them into elementwise
    // ops.
    {
      RewritePatternSet patterns(context);
      auto options =
          vector::VectorTransformsOptions().setVectorTransformsOptions(
              vector::VectorContractLowering::ParallelArith);
      vector::populateVectorContractLoweringPatterns(patterns, options);
      // The pattern can generate transpose ops. Try to fold them if possible
      // to avoid lowering them into extract/insert later.
      vector::TransposeOp::getCanonicalizationPatterns(patterns, context);
      // It also generates broadcast/extract ops. Clean them up too.
      vector::BroadcastOp::getCanonicalizationPatterns(patterns, context);
      vector::ExtractOp::getCanonicalizationPatterns(patterns, context);
      if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) {
        return signalPassFailure();
      }
    }
    debugPrint(funcOp, "after lowering size-1 reduction contract ops");
    // Now lower vector transpose given we have handled vector patterns that may
    // generate transpose ops in previous steps. This converts transpose ops
    // into extract and insert pairs, in preparation for further
    // transformations to canonicalize/cancel.
    {
      RewritePatternSet patterns(context);
      auto options =
          vector::VectorTransformsOptions().setVectorTransposeLowering(
              vector::VectorTransposeLowering::EltWise);
      vector::populateVectorTransposeLoweringPatterns(patterns, options);
      vector::populateVectorShapeCastLoweringPatterns(patterns);
      if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) {
        return signalPassFailure();
      }
    }
    debugPrint(funcOp, "after lowering transpose ops");
    // Next run canonicalization to cast away leading size-1 dimensions. They
    // can be generated from vector unrolling and generally cause issues to
    // cancel corresponding read/write or insert/extract op pairs. This also
    // needs to happen before hoisting, where we would make certain vectors
    // loop carried. Once that's done, it's hard to handle the leading size-1
    // dimensions across regions.
    {
      RewritePatternSet patterns(context);
      // We need to pull in casting away leading one dims to allow cancelling
      // some read/write ops.
      vector::populateCastAwayVectorLeadingOneDimPatterns(patterns);
      // We may have vector.insert_strided_slice inserting 1-D native vectors
      // into n-D larger vectors with the above. Break that down too. This is a
      // companion transformation of unrolling.
      vector::populateVectorInsertExtractStridedSliceDecompositionPatterns(
          patterns);
      vector::ExtractOp::getCanonicalizationPatterns(patterns, context);
      // Trimming leading unit dims may generate broadcast/shape_cast ops. Clean
      // them up.
      vector::BroadcastOp::getCanonicalizationPatterns(patterns, context);
      vector::ShapeCastOp::getCanonicalizationPatterns(patterns, context);
      vector::TransferReadOp::getCanonicalizationPatterns(patterns, context);
      vector::TransferWriteOp::getCanonicalizationPatterns(patterns, context);
      populateVectorTransferTensorSliceTransforms(patterns);
      if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) {
        return signalPassFailure();
      }
    }
    debugPrint(funcOp, "after trimming leading unit dims");
    // Lower vector reduction to SPIR-V integer dot product.
    if (emitIntegerDotProdOps) {
      RewritePatternSet patterns(context);
      populateVectorReductionToSPIRVDotProductPatterns(patterns);
      if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) {
        return signalPassFailure();
      }
      debugPrint(funcOp, "after lowering to SPIR-V dot product");
    }
  }
};
```
Where exactly in that big function should it go?

Otherwise, here also seems fine:

`vector::populateVectorInsertExtractStridedSliceDecompositionPatterns(`

Nice, that looks great to me. On the IREE side, I think it would take a combination of two things. First, add the vector.interleave n-D -> 1-D lowering here:

`vector::TransposeOp::getCanonicalizationPatterns(patterns, context);`

Then the interleave-to-shuffle pattern can go either in the same place or somewhere near here:

`patterns, vector::VectorMultiReductionLowering::InnerParallel);`

Basically, after decomposing to 1-D and before unrolling to vectors of 1/2/3/4 elements.

> Otherwise here also seems fine:
>
> `vector::populateVectorInsertExtractStridedSliceDecompositionPatterns(`

Correct me if I'm wrong, but I was thinking it had to happen before unrolling to <= 4 elements, unless interleave already implements the unrolling interface?

My hope would be that unrolling could already break it down to the source vectors in some cases, but I haven't checked whether it supports the unrolling interface.

> Basically after decomposing to 1d and before unrolling to 1/2/3/4 vector elements.

This is also a good option if it doesn't support unrolling.

Can we put it upstream in https://github.com/llvm/llvm-project/blob/e9f53e4095d8a8600b5c5d445c73e2d5a6f45abb/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp#L812 ?

I have updated the PR with that, thanks for the tip!

> Nice, that looks great to me. In terms of the IREE side, I think some combination of adding the vector.interleave n-d -> 1-d here:

@kuhar @qedawkins I am just trying to fix the immediate issue that forces us to carry a local LLVM revert here. As far as I have seen so far, this issue only involves 1-D vectors.

Makes sense. Then I would say any point before this one is likely to work:

`populateVectorUnrollPatterns(patterns, emitIntegerDotProdOps);`

So should the upstream PR do it in VectorToSPIRV.cpp#L812 or not?

If it should do that upstream, then nothing needs to be done on the IREE side, right?

I'm guessing we'll still need to do something on the IREE side because, in addition to the requirement that all vectors be 1-D, on SPIR-V they must also be at most 4 elements wide, and the unrolling to <= 4 elements currently happens in IREE. So I think VectorToSPIRV will be too late.
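To make the <= 4-element constraint concrete: a 1-D interleave of two length-n vectors can be split into pieces that each interleave matching slices of the two operands, so every piece stays within the width limit. The sketch below shows only that index bookkeeping, assuming an even maximum width such as 4. `InterleaveChunk`, `unrollInterleave` and `interleaveByChunks` are hypothetical names; this is not a description of IREE's or MLIR's unrolling infrastructure, and whether vector.interleave implements the unrolling interface was left unchecked in this thread.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical piece of an unrolled 1-D interleave: it interleaves
// lhs[srcOffset, srcOffset + srcSize) with the same slice of rhs and writes
// 2 * srcSize elements starting at dstOffset in the full result.
struct InterleaveChunk {
  int64_t srcOffset;
  int64_t srcSize;
  int64_t dstOffset;
};

// Split an interleave of two length-n vectors into pieces that each produce
// at most maxWidth elements. maxWidth is assumed even (e.g. 4).
std::vector<InterleaveChunk> unrollInterleave(int64_t n, int64_t maxWidth) {
  assert(n >= 0 && maxWidth >= 2 && maxWidth % 2 == 0);
  const int64_t step = maxWidth / 2;  // elements taken from each operand
  std::vector<InterleaveChunk> chunks;
  for (int64_t off = 0; off < n; off += step) {
    const int64_t size = std::min(step, n - off);
    chunks.push_back({off, size, 2 * off});
  }
  return chunks;
}

// Reference check: rebuild the full interleaved result piece by piece.
template <typename T>
std::vector<T> interleaveByChunks(const std::vector<T> &lhs,
                                  const std::vector<T> &rhs,
                                  int64_t maxWidth) {
  assert(lhs.size() == rhs.size());
  std::vector<T> result(2 * lhs.size());
  for (const InterleaveChunk &c :
       unrollInterleave(static_cast<int64_t>(lhs.size()), maxWidth)) {
    for (int64_t i = 0; i < c.srcSize; ++i) {
      result[c.dstOffset + 2 * i] = lhs[c.srcOffset + i];
      result[c.dstOffset + 2 * i + 1] = rhs[c.srcOffset + i];
    }
  }
  return result;
}
```

For n = 5 and a maximum width of 4, this gives three pieces producing 4, 4 and 2 result elements; each piece could then be expressed as its own small shuffle.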

But that sounds like solving a more general problem than my immediate concern of flattening our LLVM integration. The upstream change that we have to carry a local revert of doesn't change the number of elements in a vector or the rank of a vector; it just changes vector.shuffle on 1-D vectors into vector.interleave on 1-D vectors.
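The 1-D condition described here can be sketched the same way: a vector.shuffle of two operands of equal length n can be expressed as vector.interleave of the same operands when its mask is [0, n, 1, n+1, ..., n-1, 2n-1]. `isInterleaveMask` is a hypothetical helper that only illustrates this condition; it is not the code from llvm/llvm-project#89131.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical helper: true if `mask`, applied to two 1-D operands of length
// n each, is exactly the interleave pattern [0, n, 1, n + 1, ...].
bool isInterleaveMask(const std::vector<int64_t> &mask, int64_t n) {
  assert(n > 0 && "1-D vector operands must be non-empty");
  if (static_cast<int64_t>(mask.size()) != 2 * n)
    return false;
  for (int64_t i = 0; i < n; ++i) {
    // Even positions take lhs[i]; odd positions take rhs[i] (index n + i).
    if (mask[2 * i] != i || mask[2 * i + 1] != n + i)
      return false;
  }
  return true;
}
```

Running the mask from the forward lowering through this check returns true, which is the round trip this issue is about: the upstream change produces vector.interleave from such shuffles, and the proposed pattern turns it back into vector.shuffle for backends that only handle the latter.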