[CPU][ARM] FP16 slower than FP32 on Android
mariecwhite opened this issue
What happened?
With both the default flags and the DT+UK (data-tiling + microkernels) flags, FP16 is slower than FP32 on a Pixel 8 Pro:
Default flags:
name | threads | FP32 latency (ms) | FP16 latency (ms) |
---|---|---|---|
BERT_BASE_FP32_TFLITE_I32_SEQLEN8 | 1 | 98 | 121 |
BERT_BASE_FP32_TFLITE_I32_SEQLEN32 | 1 | 157.0 | 198.0 |
BERT_BASE_FP32_TFLITE_I32_SEQLEN64 | 1 | 268.0 | 305.0 |
DT+UK Flags:
name | threads | FP32 latency (ms) | FP16 latency (ms) |
---|---|---|---|
BERT_BASE_FP32_TFLITE_I32_SEQLEN8 | 1 | 20.1 | 76.3 |
BERT_BASE_FP32_TFLITE_I32_SEQLEN32 | 1 | 81.8 | 132 |
BERT_BASE_FP32_TFLITE_I32_SEQLEN64 | 1 | 162 | 210 |
Steps to reproduce your issue
ARM_CPU_FEATURES="+v9a,+fullfp16,fp-armv8,+neon,+aes,+sha2,+crc,+lse,+rdm,+complxnum,+rcpc,+sha3,+sm4,+dotprod,+fp16fml,+dit,+flagm,+ssbs,+sb,+altnzcv,+fptoint,+bf16,+i8mm,+bti"
# For the DT+UK configuration, additionally pass:
#   --iree-opt-data-tiling --iree-llvmcpu-enable-microkernels
iree-compile tosa.mlirbc \
--iree-hal-target-backends="llvm-cpu" \
--iree-input-type="tosa" \
--iree-llvmcpu-link-embedded=false \
--iree-input-demote-f64-to-f32=false \
--iree-input-demote-i64-to-i32=false \
--iree-llvmcpu-target-cpu-features="${ARM_CPU_FEATURES}" \
--iree-llvmcpu-target-triple="aarch64-none-linux-android34" \
-o module.vmfb
Run on device:
iree-benchmark-module --module=module.vmfb --task_topology_cpu_ids=0 --device=local-task --function=main --input=1x8xi32=0 --input=1x8xi32=0
What component(s) does this issue relate to?
Compiler
I don't see this pattern when using the StableHLO path, i.e. JAX -> FP32 to FP16 -> StableHLO MLIR -> IREE, versus JAX -> TF -> TFLite -> FP32 to FP16 -> TOSA MLIR -> IREE.
Default flags:
name | threads | FP32 latency (ms) | FP16 latency (ms) |
---|---|---|---|
BERT_BASE_FP32_JAX_I32_SEQLEN8 | 1 | 92.8 | 51.0 |
BERT_BASE_FP32_JAX_I32_SEQLEN32 | 1 | 161.0 | 111.0 |
BERT_BASE_FP32_JAX_I32_SEQLEN64 | 1 | 270.0 | 201.0 |
Hi @mariecwhite,
Just to make sure I understand this statement correctly: for the same input model, if you export it through StableHLO we get the expected performance, but if we export it through TFLite -> TOSA we get the perf discrepancy. Did I get this right?
Yes that's right. Exporting through StableHLO looks fine but not TOSA.
@mariecwhite Do you have the FP16 IR handy for the StableHLO input?
I want to double check that we're not missing something obvious that may be happening on this path and not on the tosa path.
TL;DR As far as I can tell, the overhead for the tosa input is indeed expected and there's not much we can do about it.
To get the perf improvements, we'll need to do the math on the smaller type.
E.g.:
// f32 version
%a = tosa.add %b, %c : (tensor<1x8x768xf32>, tensor<1x8x768xf32>) -> tensor<1x8x768xf32>
// f16 version
%a = tosa.add %b, %c : (tensor<1x8x768xf16>, tensor<1x8x768xf16>) -> tensor<1x8x768xf16>
I.e., same operation, different type.
Instead what we get is:
// f16 version
%b32 = tosa.cast %b : (tensor<1x8x768xf16>) -> tensor<1x8x768xf32>
%c32 = tosa.cast %c : (tensor<1x8x768xf16>) -> tensor<1x8x768xf32>
%a32 = tosa.add %b32, %c32 : (tensor<1x8x768xf32>, tensor<1x8x768xf32>) -> tensor<1x8x768xf32>
I.e., all the operands are upcast and the math runs at the higher precision. So effectively we're running the same math as the f32 version, plus we pay the overhead of casting all the values.
At this point, an optimization cannot safely undo that (by pushing the cast down the def-use chain for instance) because that could have huge numerical implications.
I guess we could add an option to force that, but that sounds like a recipe for disaster.
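The numerical concern is easy to see with a small numpy sketch (illustrative values only, not taken from the model): a long accumulation done entirely in f16 drifts badly, while upcasting to f32 first, which is what the TOSA path effectively does, stays accurate.

```python
import numpy as np

# Summing 10,000 copies of ~0.1: accumulating purely in float16 stalls once
# the running sum is large enough that 0.1 is below half an ulp, while
# upcasting to float32 first keeps the result close to the true ~1000.
vals16 = np.full(10000, 0.1, dtype=np.float16)

sum_f16 = np.float16(0)
for v in vals16:
    sum_f16 = np.float16(sum_f16 + v)  # math entirely in f16

sum_f32 = vals16.astype(np.float32).sum()  # upcast, then accumulate in f32

print(f"f16 accumulation: {float(sum_f16):.2f}")
print(f"f32 accumulation: {float(sum_f32):.2f}")
```

This is why pushing the casts down the def-use chain is not a semantics-preserving rewrite: the f16-only result can differ from the upcast result by far more than rounding noise.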
Is there a way we could fix the front-end to stick to the math on f16 operations?
@mariecwhite, do you know who could look at this from the frontend perspective?
The JAX FP16 version is here: https://storage.googleapis.com/iree-model-artifacts/jax/jax_models_0.4.20_1699872537/BERT_BASE_FP16_JAX_I32_SEQLEN8/stablehlo.mlirbc
@rsuderman @NatashaKnk what are your thoughts on supporting these kinds of TFLite graphs? This does seem out of scope. It appears TFLite itself doesn't do any special handling for these, since it tends to be slower than IREE:
Prefix length | IREE latency (ms) | TFLite latency (ms) |
---|---|---|
8 | 42.27 | 17.89 |
32 | 76.78 | 98.91 |
64 | 88.51 | 148.22 |
128 | 145.27 | 223.03 |
256 | 274.58 | 361.16 |
512 | 458.20 | 687.50 |
Thanks for the StableHLO input @mariecwhite!
I confirmed that this path does more of its computations directly in f16, as opposed to always promoting to f32.
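A rough way to sanity-check which precision the math runs in is to count tensor element types in a textual IR dump (e.g. from `iree-compile --compile-to=input`, if your build supports it). A minimal sketch, using an inline IR string as a stand-in for a real dump:

```python
# Rough heuristic: count f16 vs f32 tensor types in a textual MLIR dump.
# The string below is a hypothetical stand-in; in practice, read the dump
# produced by the compiler instead.
ir = """
%b32 = tosa.cast %b : (tensor<1x8x768xf16>) -> tensor<1x8x768xf32>
%a32 = tosa.add %b32, %b32 : (tensor<1x8x768xf32>, tensor<1x8x768xf32>) -> tensor<1x8x768xf32>
"""

f16_count = ir.count("xf16>")
f32_count = ir.count("xf32>")
print(f"f16 tensors: {f16_count}, f32 tensors: {f32_count}")
```

A graph that genuinely computes in f16 should show mostly f16 tensor types, with casts only at the boundaries; the upcast pattern above shows the opposite.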
I'll let @rsuderman and @NatashaKnk comment on the frontend aspect, but from the backend perspective we do what we're told, i.e., there is nothing to fix here.
Assigning @NatashaKnk to comment/decide on what to do on the front-end side.
Sorry, I'm a bit backlogged at the moment so I didn't get the chance to look at this yet.
I'm curious why these casts happen in the first place, from naively looking at the tosa spec fp16 should be supported. Is there any obvious reason this is happening that I'm missing?
Otherwise I can investigate this deeper sometime next week.
I don't know; that's really a question for someone on the front-end side, and I was hoping you could shed some light on it, to be fair.
I did a bit more investigation here, and it is a deliberate design decision to include the FP16-to-FP32 dequantization ops in the TFLite flatbuffer. Additional metadata marks whether the graph is FP16-compatible; if it is, XNNPack removes those dequantization ops.
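The pattern can be sketched in numpy (illustrative only, not actual TFLite code): the flatbuffer stores weights in f16 and inserts DEQUANTIZE (f16 -> f32) ops, so a backend that takes the graph literally computes in f32; an FP16-capable delegate drops the dequantize and computes directly in f16.

```python
import numpy as np

# Hypothetical stand-ins for a weight tensor and activation stored as f16.
rng = np.random.default_rng(0)
w16 = rng.standard_normal((4, 4)).astype(np.float16)
x16 = rng.standard_normal((4, 4)).astype(np.float16)

# Literal interpretation of the graph: DEQUANTIZE to f32, then matmul in f32.
y_f32_path = x16.astype(np.float32) @ w16.astype(np.float32)

# FP16-capable path: skip the dequantize ops and do the matmul in f16.
y_f16_path = (x16 @ w16).astype(np.float32)

# The two paths agree to within f16 rounding; only the second saves the
# casts and the wider arithmetic.
print(np.max(np.abs(y_f32_path - y_f16_path)))
```

So the slowdown on the TOSA path is the cost of taking those dequantize ops at face value rather than eliding them the way XNNPack does.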
Let's mark this out of scope.