iree-org / iree

A retargetable MLIR-based machine learning compiler and runtime toolkit.

Home Page: http://iree.dev/

Slow Softmax at top of MobileBert/int8 profile

bjacob opened this issue

Profiling MobileBert/int8/experimental(mmt4d)/dotprod, where matmuls themselves are relatively fast, makes the rest show more prominently in profiles.

According to Tracy, ~50% of time is being spent in a dispatch that appears to be a Softmax. At least it plausibly looks like one: it performs some table lookups, a sum-reduction, then evaluates a degree-15 polynomial approximation of a math function and multiplies the results together, as a Softmax would. And we know that MobileBert contains lots of Softmax.

TOSA "source" code of the slow loop

My own hand-deciphering of that TOSA code into pseudo-C (where I got to understand that it's evaluating a degree-15 polynomial, etc).

Disassembly from Tracy: [screenshot not reproduced]

We can see that it's scalar code, not SIMD: the x and w registers are scalar registers (64-bit and 32-bit respectively).

Getting this code to vectorize properly is likely to require some measure of explicit vectorization using at least ARM NEON intrinsics. The reason is that the efficient lowering of these fixed-point multiplications depends on fine details of the target ISA. As the target is ARM NEON, the fixed-point multiplications of the form

(int64(a) * int64(b)) >> 31

should be explicitly vectorized as

vqdmulhq_s32(a, b)
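A minimal Python model (not NEON code, and not taken from IREE) makes the claimed equivalence concrete. It assumes the documented per-lane semantics of vqdmulhq_s32, i.e. saturate((2*a*b) >> 32) to int32; the helper names below are illustrative only.

```python
from typing import List

INT32_MIN, INT32_MAX = -(1 << 31), (1 << 31) - 1


def fixed_mul_widening(a: int, b: int) -> int:
    """Scalar reference from the text: (int64(a) * int64(b)) >> 31.

    Python ints never overflow, and >> on negative ints floors,
    matching an arithmetic shift of the 64-bit product.
    """
    return (a * b) >> 31


def vqdmulh_lane(a: int, b: int) -> int:
    """Model of one lane of vqdmulhq_s32: saturate((2 * a * b) >> 32) to int32."""
    return max(INT32_MIN, min(INT32_MAX, (2 * a * b) >> 32))


def vqdmulhq_s32_model(a: List[int], b: List[int]) -> List[int]:
    """Model of the vector form: four int32 lanes per 128-bit register,
    versus two lanes when each product is widened to int64 first."""
    assert len(a) == len(b) == 4
    return [vqdmulh_lane(x, y) for x, y in zip(a, b)]


# (2ab) >> 32 == (ab) >> 31 for all integers, so the two agree except where the
# instruction saturates: a == b == INT32_MIN, whose exact result 2**31 is not an int32.
half = 1 << 30  # 0.5 in Q0.31
print(vqdmulhq_s32_model([half] * 4, [half, -half, 0, INT32_MAX]))
# [536870912, -536870912, 0, 1073741823]
print(fixed_mul_widening(INT32_MIN, INT32_MIN), vqdmulh_lane(INT32_MIN, INT32_MIN))
# 2147483648 2147483647
```

If the fixed-point code also rounds before shifting, the rounding counterpart of this instruction is vqrdmulhq_s32.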

The TOSA code here is exactly as in this LIT test, which confirms that it's a Softmax as imported from TFLite to TOSA: https://github.com/tensorflow/tensorflow/blob/f8e703259b83f23ef2115edcc2bc2eff0ead1300/tensorflow/compiler/mlir/tosa/tests/tfl-to-tosa-pipeline.mlir#L1385-L1437

If there are gather ops (maybe from tosa.table?), the kernel won't get vectorized.

Do you have IR dumps for the kernel? We can extract the IR at Linalg level, see https://github.com/google/iree/blob/main/docs/developers/debugging/integration_correctness_issue_breakdown.md#narrow-down-the-repro

In short, we can get the IR after OutlineDispatchRegions. That would give us a better idea of what the input to codegen is.

The above tfl-to-tosa test shows that the tosa.table ops are part of the lowering of the softmax. The actual TOSA code linked above shows that the lookup tables have length 513, which is the length TOSA mandates for lookup tables implementing math functions, so they must be the implementation of the exponential inside the softmax.

Just understood: the degree-15 polynomial here is the approximation of 1/(1+x); it's how this code performs the division in the softmax. That's why the input to this polynomial, %1136, is itself the result of a sum reduction: the denominator is the sum of the exponentials previously computed (by LUT).
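As a sanity check on the degree, here is a sketch in exact rational arithmetic (not the fixed-point IR) that expands the standard Newton-Raphson reciprocal iteration into a polynomial in x. It assumes the 48/17 and 32/17 initial estimate that comes up later in this thread, and three iterations: each iteration maps degree n to 2n + 1, so 1, 3, 7, 15.

```python
from fractions import Fraction

# Polynomials as coefficient lists, lowest degree first: [c0, c1, ...] = c0 + c1*x + ...


def poly_add(p, q):
    n = max(len(p), len(q))
    p = p + [Fraction(0)] * (n - len(p))
    q = q + [Fraction(0)] * (n - len(q))
    return [a + b for a, b in zip(p, q)]


def poly_mul(p, q):
    out = [Fraction(0)] * (len(p) + len(q) - 1)
    for i, a in enumerate(p):
        for j, b in enumerate(q):
            out[i + j] += a * b
    return out


def poly_scale(p, c):
    return [c * a for a in p]


def poly_eval(p, x):
    return sum(c * x**k for k, c in enumerate(p))


def reciprocal_poly(iterations=3):
    """Expand Newton-Raphson for 1/(1+x), x in [0, 1], into a polynomial in x.

    d = (1+x)/2 lies in [1/2, 1]; y0 = 48/17 - 32/17*d; y <- y + y*(1 - d*y).
    y converges to 1/d = 2/(1+x), so the approximation of 1/(1+x) is y/2.
    """
    d = [Fraction(1, 2), Fraction(1, 2)]
    y = poly_add([Fraction(48, 17)], poly_scale(d, Fraction(-32, 17)))
    for _ in range(iterations):
        residual = poly_add([Fraction(1)], poly_scale(poly_mul(d, y), Fraction(-1)))
        y = poly_add(y, poly_mul(y, residual))
    return poly_scale(y, Fraction(1, 2))


p = reciprocal_poly()
print("degree:", len(p) - 1)  # degree: 15
worst = max(abs(float(poly_eval(p, Fraction(k, 64))) - 1 / (1 + k / 64)) for k in range(65))
print("max error on [0, 1]:", worst)
```

Each iteration squares the relative error 1 - d*y, which starts at no more than 1/17 in magnitude, so three iterations leave an error below 1e-9 in this exact model; the int32 fixed-point version is limited by its own precision instead.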

The TFLite->TOSA lowering of Softmax is generated by this code.

Small world: this bit references gemmlowp code I wrote years ago. Above, I had deciphered the constants 180/255 and 120/255 but hadn't bothered to simplify the fractions. Since 255 = 15*17, dividing out the factor of 15 and then rescaling by 4 (a change of fixed-point format) leaves the traditional 48/17 and 32/17 coefficients of Newton-Raphson division. It all makes sense now :-)
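The constant arithmetic can be checked mechanically. In this sketch, to_raw_int32 is a hypothetical helper that encodes a value as a raw int32 with a given number of integer bits (round to nearest); the raw values are computed here, not copied from the IR.

```python
from fractions import Fraction


def to_raw_int32(value: Fraction, integer_bits: int) -> int:
    """Hypothetical helper: raw int32 encoding of `value` with `integer_bits`
    integer bits and 31 - integer_bits fractional bits, rounded to nearest."""
    return round(value * 2 ** (31 - integer_bits))


a, b = Fraction(180, 255), Fraction(120, 255)
print(a, b)          # 12/17 8/17   (255 = 15 * 17; the 15 cancels)
print(4 * a, 4 * b)  # 48/17 32/17

# The same raw integer reads as 12/17 with 0 integer bits and as 48/17 with
# 2 integer bits: rescaling by 4 is only a change of fixed-point format.
print(to_raw_int32(a, 0), to_raw_int32(4 * a, 2))    # 1515870810 1515870810
print(to_raw_int32(-b, 0), to_raw_int32(-4 * b, 2))  # -1010580540 -1010580540
```

Multiplying by 4 and moving the binary point 2 bits to the right cancel out, which is why the rescaling costs nothing at runtime.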

If there are gather ops (maybe from tosa.table?), the kernel won't get vectorized.

That inspired me to try dropping the tosa.table ops, just to see if they were the only thing preventing vectorization here. But even without the tosa.table ops, the fixed-point multiplications here are still not vectorized, presumably because they involve multiplying int32 values into an int64 result before reducing that to the final int32 result of the fixed-point multiplication, and no pattern currently kicks in to perform that optimization.

So there are 2 separate problems here --- the exp computation with tosa.table and the Newton-Raphson division with non-vectorized int64 arithmetic.

@bjacob I am assigning this to you for now since you are working on this AFAIK, but feel free to move it to me if you aren't.

(I initially wrote this as a reply on #9170, but it belongs better here. Context: we noticed in #9170 that Softmax just became much faster than initially reported here. This is a study of the current (faster) code and of what next steps could be taken from here. While faster than before, it is still 20% slower than just dequantizing Softmax, so that's still not a high bar to beat.)

Thanks @hanhanW for the investigation in #9170 (comment). The IR here is the implementation of Softmax in TOSA, generated during TFLite-to-TOSA here.

Latencies for this benchmark on Moto Edge X30:

  • Before: 472 ms
  • Now: 258 ms
  • With dequantization of Softmax: 204 ms

So this is now much better than before, but just dequantizing Softmax (i.e., making it use the float Softmax implementation) would still give another 20% speedup over this.
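For reference, the arithmetic behind these comparisons, using only the latencies listed above:

```latex
\frac{472\,\text{ms}}{258\,\text{ms}} \approx 1.83\times \text{ faster than before},
\qquad
\frac{258\,\text{ms} - 204\,\text{ms}}{258\,\text{ms}} \approx 20.9\%\ \text{less time with dequantization}
\;\left(\tfrac{258}{204} \approx 1.26\times\right).
```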

Here's how this currently looks in Tracy on Moto Edge X30:
[Tracy timeline screenshot not reproduced]

  • main_dispatch_14 is unrelated

  • main_dispatch_59 is the computation of the int32 exponential values. It is implemented as two table lookups producing an int16 value each, plus bit operations combining those into the resulting int32 value. The above Tracy timeline shows it accounting for the majority of the cost of Softmax, now that the subsequent arithmetic in main_dispatch_61 (see below) is less bad than it used to be. Zooming in on main_dispatch_59 (MLIR TOSA source on the left, AArch64 asm on the right): [screenshot not reproduced]

  • main_dispatch_60 is the computation of the sum of those exponentials (with a suitable right-shift applied), which becomes the denominator in the division implementing the softmax in main_dispatch_61 below. It's only a small fraction of overall time according to the above timeline, so maybe not worth discussing. Here's the view anyway:
    [screenshot not reproduced]

  • main_dispatch_61 is the fixed-point division, where the numerator is the exponential value (from main_dispatch_59) and the denominator is the sum of exponentials (from main_dispatch_60). It is what just became much faster. Before, this was un-vectorized code. Now it is vectorized, but in a naive way: the 32-bit fixed-point multiplications are still naively done by multiplying into int64 temporaries, so SIMD here really only has 2 lanes. It's still more than a 2x speedup, due to how bad the non-SIMD code was (see disasm in #8974). But properly using fixed-point instructions would still give more than a further 2x speedup on top of that: a 2x increase in the number of lanes, plus condensing multiple instructions (multiplication, right-shift...) into one fixed-point multiplication. That is the topic of Issue #9109. Here is the view into main_dispatch_61 at the moment:
    [screenshot not reproduced]

In summary:

  • Now that main_dispatch_61 is getting vectorized, we could take the steps outlined in #9109 to make it use the right fixed-point multiplication instructions, and that would make main_dispatch_61 ~ 2x faster.
  • However, main_dispatch_61 is not the most expensive part -- main_dispatch_59 is.
  • There is no easy way to speed up main_dispatch_59 because table lookups are inherently not cheap. We can't meaningfully vectorize this (ARM NEON vector table lookup is limited to tables fitting in at most 64 bytes, and trying to combine multiple such instructions would result in worse performance).
  • It may be worth trying an arithmetic implementation of exponential, instead of table-lookup. But that is not cheap to try (high eng complexity, although one could reuse the impl from gemmlowp/fixedpoint).
  • Since #9109 is something that we want to do anyway because it affects much more (all rescales), maybe that's the first thing to do. However, from a narrow Softmax perspective, dequantizing still seems like the most bang for your buck: low eng effort, drop lots of complex IR and reliance on smart patterns applying, get 20% e2e speedup on top of the speedup observed here.
  • WDYT?

@rsuderman @bjacob do we need to lower exp as a table lookup? If we lowered it to math.exp, there are polynomial expansions of this operation (here). Doing a table lookup for exp all the way up at the TOSA level again seems like a premature optimization that should be avoided there (basically, any arithmetic smarts at the TOSA level are not phase-ordered correctly).

We should defer the decision about arithmetic smarts to later stages. We don't want to do it at TOSA -> Linalg stage. If we really want it in IREE, it can be done before fusion.

We're able to vectorize arith ops, but not table lookups. Can we lower the exp op to math.exp?

@rsuderman @bjacob do we need to lower exp as a table lookup? If we lowered it to math.exp, there are polynomial expansions of this operation (here).

That link is to an implementation of exp for the f32 data type. So that code path is what we're using when dequantizing Softmax ops as suggested above (and as said above, that dequantized-Softmax path, plus the dequantize and re-quantize around it, is still much faster than the current quantized Softmax even after the recent speedup discussed in #9170).

So yes, we can just use that by dequantizing softmax, or we could have something even faster with an arithmetic impl of quantized exp.

The other dimension here is that these scalar table lookups are (even) more inefficient than they need to be, by a factor of > 2x for the following reason. The code means to look up 32bit exponential values, but apparently due to a TOSA limitation it can only perform table lookups of 16bits each -- so it ends up performing 2 lookups of 16bit each, plus bit ops to combine these two 16bit values into a 32bit value. See above disasm for main_dispatch_59.

That should be an easy fix: even if the TOSA spec only wants lookups of 16-bit values, we could extend the MLIR TOSA dialect to support 32-bit lookups.
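To illustrate the cost difference, here is a simplified Python sketch assuming the pair of 16-bit lookups holds the high and low halves of each 32-bit entry. The table contents (exp(-i/32) in Q0.31), its length, and the split are hypothetical, and the sketch ignores the int16-indexed interpolation discussed later in the thread.

```python
import math

TABLE_LEN = 256  # hypothetical table length, for illustration only


def exp_entry(i: int) -> int:
    """Hypothetical non-negative 32-bit entry: exp(-i/32) in Q0.31, clamped to int32."""
    return min((1 << 31) - 1, round(math.exp(-i / 32) * (1 << 31)))


table32 = [exp_entry(i) for i in range(TABLE_LEN)]
# What a 16-bit element restriction forces: split every entry across two 16-bit tables.
table_hi = [(v >> 16) & 0xFFFF for v in table32]
table_lo = [v & 0xFFFF for v in table32]


def lookup_split(i: int) -> int:
    # Two loads plus shift/or to rebuild the 32-bit entry (entries assumed non-negative).
    return (table_hi[i] << 16) | table_lo[i]


def lookup_32(i: int) -> int:
    # A single load, if 32-bit table elements were allowed.
    return table32[i]
```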

Would it be possible to keep the exp as is, and when lowering to LLVM convert the exp alone to use the fp32 path (and the polynomial expansion) when it is in scalar form?

RE 16-bit tosa lookup: that's a restriction of older/smaller ISAs that could only handle 16-bit element lookups (neon, riscv which only has vrgatherei16, etc). It's a layering violation that then when 32-bit lookups are available (avx2/sve) we are still forced down that path early. Would be nice to remove that restriction in TOSA and our lowerings: we should be able to do any arbitrary width lookup and use analysis to narrow the indices when possible.

I do want to make sure we aren't establishing that code containing gathers can't be vectorized, though. On avx2 and sve we have native instructions we can use, but on older ISAs, when those instructions aren't available, vectorizing everything except the unrolled non-contiguous loads used to compose vectors is still viable and much better than vectorizing nothing at all (especially when fused with truly vectorizable work). It nets out to essentially just software-level implementations of the missing instructions; the need for non-contiguous loads was not introduced with ML, and even though the instructions haven't existed, people have been doing it effectively for decades at this point. Whether in the exp case the lookups are worth it is orthogonal. I mostly just don't want "we can't vectorize gathers" to be assumed true because today we can't generate them, and going int to float feels like a giant hack we should make sure does not become the normal approach in any case where we happen to fuse a gather with other stuff :)

We should defer the decision about arithmetic smarts to later stages. We don't want to do it at TOSA -> Linalg stage. If we really want it in IREE, it can be done before fusion.

To add a concrete example to this, there are fusion opportunities for Softmax/Logistic and the ops that process their output, potentially bypassing most of the dequantizing or exp calculation.

In vision models, Softmax is usually followed by some thresholding e.g. discard results where value (probability) < 0.5. This logic can be fused with the Softmax so that only values > 0.5 (for quantized models, the quantized equivalent of 0.5) are included in the exp calculation. In practice, this eliminates the need to run exp on >80% of the values.

The logic would change from:
select(softmax(N) > quantized(0.5))
or
select(dequantize(softmax(N)) > 0.5)
or
select(softmax(dequantize(N)) > 0.5)

to:

softmax(select(N > quantized(0.5))) where the output of select() is much lower than N.
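Softmax's normalizing sum couples the elements, so this sketch uses Logistic, the other op named above, where moving the threshold in front of the math is exact: sigmoid is monotonic, so sigmoid(x) > t holds exactly when x > logit(t). Plain Python floats stand in for the quantized values; the function names are illustrative.

```python
import math


def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))


def threshold_after(xs, t):
    """Unfused: exp runs on every element, then the threshold is applied."""
    probs = [sigmoid(x) for x in xs]
    return [(i, p) for i, p in enumerate(probs) if p > t]


def threshold_before(xs, t):
    """Fused: filter on x > logit(t) first, so exp only runs on the survivors.

    In a quantized model the cutoff would be the quantized equivalent of logit(t).
    Values within floating-point rounding of the cutoff could be classified
    differently by the two versions.
    """
    cutoff = math.log(t / (1.0 - t))  # logit(t); 0.0 for t = 0.5
    return [(i, sigmoid(x)) for i, x in enumerate(xs) if x > cutoff]


logits = [-3.0, -0.2, 0.4, 2.5, -1.0, 5.0]
print([i for i, _ in threshold_before(logits, 0.5)])  # [2, 3, 5]
```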

This is great! This kind of cross-op optimization is fair game (and probably worth doing!). I was talking more about specific handling of exp using lookups.

RE 16-bit tosa lookup: that's a restriction of older/smaller ISAs that could only handle 16-bit element lookups (neon,

As far as I know, NEON's table-lookup instruction (TBL) only supports loading bytes, not 16-bit elements.

There's a much deeper reason, though, why tosa.table is not vectorizable on NEON: NEON TBL requires the table to fit in registers, in at most 64 bytes. tosa.table ops use bigger tables than that, and there's no reasonably efficient way to implement them as multiple TBL instructions.

So when the target is NEON, all tosa.table ops will be simple scalar code, even the 8-bit-element ones.
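To put numbers on the 64-byte limit (four 16-byte registers), here are the table sizes in question. The 513-entry figure is from this thread; the 256-entry figure for the int8 variant of TABLE is my understanding of the TOSA spec.

```latex
\underbrace{513 \times 2\,\text{B}}_{\text{int16-indexed table}} = 1026\,\text{B} > 16 \times 64\,\text{B},
\qquad
\underbrace{256 \times 1\,\text{B}}_{\text{int8 table}} = 256\,\text{B} = 4 \times 64\,\text{B}.
```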

Given that the 8-bit/16-bit restriction of tosa.table does not actually mirror ARM NEON features, one could ask: why not also support 32-bit elements?

But looking at the TOSA spec for TABLE, it's getting clearer what the real friction will be if we try to generalize it: TABLE is not always the simple table-lookup operation that I thought it was. It is when in_t is int8, but it is something quite different when in_t is int16. That makes it nontrivial to predict what int32-element extension, if any, would be palatable.

That complication with the int16-indexed case also helps explain why the disasm we're seeing here (main_dispatch_59) is so inefficient. Before, I could not explain all the bit shifts that I was seeing. Now I get that each 16-bit index first needs to be right-shifted by 7 bits to obtain the actual array index, while the low 7 bits are preserved (look for and 0x7f in the disasm) to act as the interpolation parameter, and the actual interpolation involves a multiplication. Now the disasm all makes sense.
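The index/fraction split just described can be sketched as follows. This is my reading of the int16 variant of TOSA TABLE, written in Python for illustration; treat it as a sketch rather than the normative spec pseudocode. The test table is a plain ramp, not an exponential.

```python
def tosa_table_int16(table, value: int) -> int:
    """Sketch of int16-indexed TABLE: a 513-entry int16 table, linear
    interpolation, and a result with 7 extra fractional bits (16 + 7 = 23 bits)."""
    assert len(table) == 513
    v = max(-32768, min(32767, value))   # clip to the int16 range
    index = (v + 32768) >> 7             # high 9 bits pick a segment, 0..511
    fraction = v & 0x7F                  # low 7 bits: interpolation parameter
    base = table[index]
    nxt = table[index + 1]               # index + 1 can be 512, hence 513 entries
    return (base << 7) + (nxt - base) * fraction


ramp = list(range(513))
print(tosa_table_int16(ramp, -1))  # 32767
```

Around the two loads, each element needs a clip, an add, two shifts, an and, a subtract, a multiply and a final add. That is the extra arithmetic that gather-load instructions would not remove.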

And so now that I understand that, my own take-away from this is that anything that involves int16-indexed 513-sized lookup tables will be slow.

Gather-load instructions won't fix that part. Gather-loads help the memory-access part but the problem with int16-indexed TABLE is all the additional arithmetic that it requires.

To add a concrete example to this, there are fusion opportunities for Softmax/Logistic and the ops that process their output, potentially bypassing most of the dequantizing or exp calculation.

Interesting! That kind of tips the scales in favor of dequantizing softmax. Tempting, as the routes to making quantized softmax less slow are not looking as easy as I had hoped (see previous comment).

Sorry, I am going to resurface the question, why do we need to implement a table look up at the TOSA level? All the effort it takes to implement the table lookup at TOSA level can be done at arith/math dialect level. (We could also do a table lookup there if we need to, and not be hampered by TOSA spec...)

Sorry @MaheshRavishankar - I'm not 100% sure to correctly grasp your suggestions, let's discuss this in our 1:1 today.

Summary of chat with Mahesh:

  • Why not preserve the quantized exponentials as just math.exp so the decision of whether and how to use lookup tables can be deferred to a later lowering?
    • Can't be quite just that because math.exp on integers would be just computing integral powers of e - a quantized exponential is computing fractional powers of e, with the mapping of int8 values to fractional exponents given by the input's quantization parameters (zero_point and scale).
      • OK then why not preserve exponentials as a few arith/math ops (implementing that affine quantization function and then the math.exp on that?)
        • Yes, we could do that.
  • Why does the tflite-to-tosa lowering also implement the division by the denominator as a specific polynomial approximation, why not preserve that as a div op to allow such details to be taken care of later?
    • Representing the division as just a div would lose crucial range information on the numerator and denominator that is what allows using certain polynomial approximations. In order to allow this to remain a div op, a new attribute would need to be introduced on the div op to allow encoding that range information.
  • Ideally a new tosa.softmax op would be added, removing the need to resolve any such detail at tosa import.
    • Then TosaToLinalg could lower tosa.softmax to 3 linalg.generics, implementing the 3 passes of naive Softmax.
      • The 2-pass Softmax algorithm is safe to pass over here because (1) in the compiler approach, some of the 3 passes may get fused, so minimizing the number of passes is not as interesting as it was in the library approach; and (2) even in the library approach, 2-pass softmax is typically only a win in very large scenarios that thrash all levels of data caches.
      • That means that baking in 3-pass Softmax is not a significant loss of generality.
    • Arithmetic details would be confined to the basic-blocks inside the linalg.generics, not surfaced in the linalg.generic's themselves, so the arithmetic details (how to implement exponentials, how to implement the division) would be left to later stages. Just, the arith.div in the 3rd linalg.generic's basic-block would need that new attribute discussed above, for later lowerings to be able to pick the best implementation for the given operand ranges.
  • This is adding up to a LOT of work to fix quantized softmax without digging the premature-lowering hole we're in any deeper; meanwhile, just dequantizing softmax looks like the better short-term effort compromise.
  • Mostly recording this for future discussion of TOSA spec.
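A plain-Python sketch (not MLIR, and not IREE's implementation) of the structure discussed above: three passes, each of which could map to one linalg.generic, with the arithmetic (exp, division) confined to the per-element bodies. Wrapping it with affine dequantization and requantization gives the short-term dequantized route. The output quantization parameters (scale 1/256, zero point -128) are an assumption about the usual int8 Softmax output format, not taken from this thread.

```python
import math


def softmax_3pass(xs):
    """Naive 3-pass Softmax; one common way to split the passes."""
    m = max(xs)                            # pass 1: max reduction
    exps = [math.exp(x - m) for x in xs]   # pass 2: exponentials (arguments <= 0) ...
    total = sum(exps)                      # ... and their sum
    return [e / total for e in exps]       # pass 3: normalize (the division)


def softmax_int8_via_float(q, in_scale, in_zero_point,
                           out_scale=1 / 256, out_zero_point=-128):
    """Dequantize -> float Softmax -> requantize (assumed output quantization).

    Python's round() rounds ties to even; a real kernel may round differently.
    """
    xs = [in_scale * (v - in_zero_point) for v in q]  # affine dequantization
    probs = softmax_3pass(xs)
    return [max(-128, min(127, round(p / out_scale) + out_zero_point)) for p in probs]
```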

Regarding preserving the div with range info, the way this is sometimes done is to use a quantized type for the element type. Not saying this is a great idea to preserve deeply into the compiler (we really want such types normalized early). IMO, if you want to go down this path, consider having dedicated arithmetic ops for such things.

Regarding a new tosa.softmax -- I've said it before: entire wars have been fought in compilers that prematurely lower fixed-point ratio aggregates like softmax to primitives. I don't have a principled reason to say so, but just based on experience and instinct, I think that if you care about fixed point softmax, then your opset should have a softmax op and let the compiler make decisions about it as late as possible.

If I were to attach a principle to this, I would say that fixed point softmax is more of an algorithm family than a composition. For compositions, we break them down to primitives. For algorithms, we preserve them as-is. You can then have an optional pass which lowers your algorithm ops to a default implementation and preserve generality. I would +1 an enhancement to TOSA like this.

Very interesting conversation here! We've been actively following it since it offers valuable low-level implementation feedback around TOSA. While the feedback mechanism into the spec has its own processes (@stellaraccident can guide, as she has contributed), I'll attempt to offer some broad thoughts here that summarize some of our experiences. A suggestion from @stellaraccident was to summarize minutes of such decision-making for context in current and future conversations. We've also received this request from internal stakeholders and are working on it. Meanwhile, here is some insight into prior thinking:

  • "Should tosa.xyz be an op ?"
    We tried to summarize this in the selection principles (https://www.mlplatform.org/tosa/tosa_spec.html#_operator_selection). Yes, they could be applied to contend against some operators recursively. Why is there a tosa.matmul when it could be a bunch of tosa.dot_mul ? And why dot_muls when it could be tosa.mul + tosa.reduce_sum (+ transpose) ? The backstop here is the 'can we avoid hardware duplication' principle view that enabled rationalization - custom hardware often implements a matmul array which can specialize matvecs, dots...
  • "Should tosa.softmax be an op ?"
    We've had this discussion as well (circa late 2020); it was hard to view softmax as a single construct like matmul above. It could be a kernel running on some programmable activation unit of custom SI. It could be a naive scalar construct that doesn't SIMD-ize well, as we saw here and in the great f2f with @bjacob and @rsuderman recently. It could be entirely a tosa.table. From recollection, we saw softmax and batch_norm as similar issues: compound ops with multiple moving parts to address, the latter with the additional connection to training mode.
  • "Is a single legalization always ok ?"
    The normative basis of the existing implementation was alignment to TFLite as we mentioned during our call with @bjacob. Internally we've debated whether different architectures warrant different conversions. In fact they're already present - fp and quantized integers are handled differently in the existing code. But that's only the spec-driven granularity. Some sub-questions here that were considered:
    a) Do architectures warrant different conversion sequences? For this issue, it might be an option. However, this means moving arch-level choices into TOSA while simultaneously claiming arch independence. That may be ok; it's just one of several paths. But it weighs upon the compiler infrastructure and presents development and maintainability problems.
    b) Does the existence of multiple architectural forms of a compound op justify the insertion of that op (a notional tosa.softmax here) into spec ? We didn't see that as a strong enough principle at the time, as the argument could be broadly applied without a reasonable backstop principle accompanying it.

Hope this helps in terms of insight into choices here. We'd be happy to continue this conversation - as we've stated, IREE is at the vanguard of TOSA codegen on CPU and offers tremendously useful feedback!

Dequantization of softmax has landed (#9337) and the expected performance improvement has materialized on the dashboard (#9337 (comment)).

@sjarus thanks for the insights, I'll follow up with you separately on these ramifications. With #9337 we have merely "unblocked" ourselves, in that these topics don't block IREE getting good e2e performance anymore, but they're still important topics in their own right, and eventually they'll unlock getting softmax to perform even better than a decent dequantized impl.

Sounds great, @bjacob!