iree-org / iree

A retargetable MLIR-based machine learning compiler and runtime toolkit.

Home Page:http://iree.dev/

Data tiling encodings: simplify `round_dims_to` and implement narrower padding for narrow dimensions

bjacob opened this issue

Context: #17545, the rebasing of #16890 past #17077, has been difficult. It comes down to two different fields of the Encoding attribute that have partially overlapping and fuzzy semantics. This issue is about resolving all that and, along the way, completing an aspect of the intended design (avoiding over-allocation of certain buffers) that was not yet implemented.

The two fields of the Encoding attribute that we are talking about are:

  1. The matmul_narrow_{M,N} fields, which preexist and which #17545 is concerned with.
  2. The round_dims_to field introduced in #17077.

The current semantics of these fields are that:

  1. The matmul_narrow_{M,N} fields are just hints that this matmul has some narrow dimension, which may affect tile size selection (including for matrix operands where this narrow dimension doesn't participate; for instance, a narrow-M case like vecmat can still lead to a different tile choice for the RHS matrix, whose shape does not involve the M dimension).
  2. The round_dims_to field is an array attribute, enumerating the dimensions in the order of the iterators, e.g. [B,] M, N, K. It informs Stream of the maximum tile sizes that the operands of this matmul may need to be padded to, and it is used to ensure that buffer allocations are large enough to accommodate that padding.

At the moment, round_dims_to array entries are all initialized to the same padFactor value given as a pass option. So the potential benefit of having this as an array (adjusting this padding amount for narrow dimensions) is not yet reaped, while the cost (having to correctly handle this array in things like what #17545 is doing) is being paid already.

In fact, if we started populating round_dims_to with narrower values for narrow dimensions, we would be encoding the information of "this is a narrow dimension" twice, in round_dims_to and in the matmul_narrow_{M,N} attribute.

Proposal:

  1. Rename round_dims_to to max_padding. This makes it clear what it is used for.
  2. Change max_padding to be a single integer attribute, not an array. Its meaning is "the general-case padding amount, outside of any narrow-dimension cases".
  3. In places that currently consume the getRoundDimsToArray() value, also check the matmul_narrow_{M,N} attribute. If either is defined, let it override the max_padding value for that dimension, rounded up to the next power of two. Example: if max_padding=16 and matmul_narrow_M=3, round matmul_narrow_M to the next power of two, which is 4, and use that instead of max_padding for the M dimension.
  4. In SetEncoding, there is another place where the value 16 is hardcoded:
    const int64_t kNarrowThreshold = 16;
    The pass's padFactor option value should be passed down to there and used instead.
  5. In CPUMaterializeEncodingPass.cpp, notice how each return case in the enumerateMatmulTile* functions returns a list of TileMxNxK triples where the M values (the first entry of each triple) are powers of two, including all smaller powers of two down to 1. With the changes that we are discussing here, this is becoming a hard requirement: this is what ensures that the M dimension never gets rounded to more than the next power of two. So, in the top-level enumerateMatmulTileMxNxK function, before returning the value, we should probably assert that this requirement is satisfied, to catch it if we ever forget about it.
  6. There is currently a discrepancy between the current round_dims_to value 16 and the fact that in some cases in enumerateMatmulTileMxNxK we return TileMxNxK values exceeding 16. These tiles are being discarded at the moment since #17077 was merged:
    if (tile.M > mUB || tile.N > nUB || tile.K > kUB) {
      LLVM_DEBUG(llvm::dbgs() << "[" << DEBUG_TYPE << "]: tile (";
                 llvm::interleaveComma(
                     ArrayRef<int64_t>{tile.M, tile.N, tile.K}, llvm::dbgs());
                 llvm::dbgs()
                 << ") is skipped because it is not valid for upper_bound (";
                 llvm::interleaveComma(ArrayRef<int64_t>{mUB, nUB, kUB},
                                       llvm::dbgs());
                 llvm::dbgs() << ")\n");
      continue;
    }
    Now that (thanks to step 3 above) we no longer round narrow dimensions by much, it no longer costs as much to increase that round_dims_to (now called max_padding) value a bit. So my suggested trade-off would be: increase the padFactor used in SetEncoding from 16 to 32, which is needed for the tiles that we really care about; and in the CPUMaterializeEncoding pass, for all tiles enumerated in enumerateMatmulTile*, clamp all values to a maximum of 32 so that these tiles don't get discarded anymore. The above-linked code doing debug-logging and continue; can then become an error (propagated to the caller).
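
Steps 3, 5 and 6 of the proposal can be sketched roughly as follows. This is an illustration under stated assumptions, not IREE code: the `TileMxNxK` field layout is assumed, and `nextPowerOfTwo`, `effectivePadding`, `hasAllPowerOfTwoMSizes` and `clampTiles` are hypothetical helper names introduced here.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <optional>
#include <vector>

// Mirrors the TileMxNxK triple named in the issue; the field layout is assumed.
struct TileMxNxK {
  int64_t M, N, K;
};

// Smallest power of two that is >= x, for x >= 1.
int64_t nextPowerOfTwo(int64_t x) {
  assert(x >= 1);
  int64_t p = 1;
  while (p < x)
    p *= 2;
  return p;
}

// Step 3 (hypothetical helper): padding amount for one dimension. A defined
// narrow size overrides maxPadding, rounded up to the next power of two.
int64_t effectivePadding(int64_t maxPadding,
                         std::optional<int64_t> narrowSize) {
  if (narrowSize)
    return nextPowerOfTwo(*narrowSize);
  return maxPadding;
}

// Step 5 (hypothetical check): every M value must be a power of two, and all
// smaller powers of two down to 1 must be present. An empty list fails
// because it lacks M = 1.
bool hasAllPowerOfTwoMSizes(const std::vector<TileMxNxK> &tiles) {
  int64_t maxM = 0;
  for (const TileMxNxK &t : tiles) {
    if (t.M < 1 || (t.M & (t.M - 1)) != 0)
      return false; // M is not a power of two.
    maxM = std::max(maxM, t.M);
  }
  for (int64_t m = 1; m <= maxM; m *= 2) {
    bool found = std::any_of(tiles.begin(), tiles.end(),
                             [m](const TileMxNxK &t) { return t.M == m; });
    if (!found)
      return false; // A smaller power of two is missing.
  }
  return !tiles.empty();
}

// Step 6 (hypothetical helper): clamp every tile dimension to maxPadding so
// that no enumerated tile exceeds the upper bound and gets discarded.
void clampTiles(std::vector<TileMxNxK> &tiles, int64_t maxPadding) {
  for (TileMxNxK &t : tiles) {
    t.M = std::min(t.M, maxPadding);
    t.N = std::min(t.N, maxPadding);
    t.K = std::min(t.K, maxPadding);
  }
}
```

With the issue's own numbers, `effectivePadding(16, 3)` yields 4, while a dimension with no narrow hint keeps the general value 16. In an assertion-based version of step 5, `hasAllPowerOfTwoMSizes` would be called in the top-level enumeration function just before returning.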

@bjacob and @lialan, can we discuss this a bit more tomorrow? I know this was filed a couple of weeks ago, but I only got around to looking at it now (following from the PR that was sent out). I just want to clarify some things about changing round_dims_to from an array to a single scalar. In my mind, having an array of round_dims_to makes more sense, and I'd rather drop matmul_narrow_M/N because that is not generic enough IMO.

Superseded by #17729.