iree-org / iree

A retargetable MLIR-based machine learning compiler and runtime toolkit.

Home Page: http://iree.dev/


[CPU][ARM] FP16 slower than FP32 on Android

mariecwhite opened this issue

What happened?

With both the default flags and the DT+UK (data-tiling + microkernel) flags, FP16 is slower than FP32 on a Pixel 8 Pro:

Default flags:

| name | threads | FP32 latency (ms) | FP16 latency (ms) |
| --- | --- | --- | --- |
| BERT_BASE_FP32_TFLITE_I32_SEQLEN8 | 1 | 98 | 121 |
| BERT_BASE_FP32_TFLITE_I32_SEQLEN32 | 1 | 157.0 | 198.0 |
| BERT_BASE_FP32_TFLITE_I32_SEQLEN64 | 1 | 268.0 | 305.0 |

DT+UK flags:

| name | threads | FP32 latency (ms) | FP16 latency (ms) |
| --- | --- | --- | --- |
| BERT_BASE_FP32_TFLITE_I32_SEQLEN8 | 1 | 20.1 | 76.3 |
| BERT_BASE_FP32_TFLITE_I32_SEQLEN32 | 1 | 81.8 | 132 |
| BERT_BASE_FP32_TFLITE_I32_SEQLEN64 | 1 | 162 | 210 |

Steps to reproduce your issue

  1. Download MLIR Artifacts: FP32, FP16

  2. Compile both:

ARM_CPU_FEATURES="+v9a,+fullfp16,+fp-armv8,+neon,+aes,+sha2,+crc,+lse,+rdm,+complxnum,+rcpc,+sha3,+sm4,+dotprod,+fp16fml,+dit,+flagm,+ssbs,+sb,+altnzcv,+fptoint,+bf16,+i8mm,+bti"

# The last two flags below are the optional DT+UK flags; drop
# --iree-opt-data-tiling and --iree-llvmcpu-enable-microkernels to build the
# default-flags variant.
iree-compile tosa.mlirbc \
    --iree-hal-target-backends="llvm-cpu" \
    --iree-input-type="tosa" \
    --iree-llvmcpu-link-embedded=false \
    --iree-input-demote-f64-to-f32=false \
    --iree-input-demote-i64-to-i32=false \
    --iree-llvmcpu-target-cpu-features="${ARM_CPU_FEATURES}" \
    --iree-llvmcpu-target-triple="aarch64-none-linux-android34" \
    --iree-opt-data-tiling \
    --iree-llvmcpu-enable-microkernels \
    -o module.vmfb
  3. Run on device:
iree-benchmark-module --module=module.vmfb --task_topology_cpu_ids=0 --device=local-task --function=main --input=1x8xi32=0 --input=1x8xi32=0

What component(s) does this issue relate to?

Compiler

Version information

675aafb

Additional context

No response

I don't see this pattern when using the StableHLO path, i.e. JAX -> FP32 to FP16 -> StableHLO MLIR -> IREE, versus JAX -> TF -> TFLite -> FP32 to FP16 -> TOSA MLIR -> IREE.

Default flags:

| name | threads | FP32 latency (ms) | FP16 latency (ms) |
| --- | --- | --- | --- |
| BERT_BASE_FP32_JAX_I32_SEQLEN8 | 1 | 92.8 | 51.0 |
| BERT_BASE_FP32_JAX_I32_SEQLEN32 | 1 | 161.0 | 111.0 |
| BERT_BASE_FP32_JAX_I32_SEQLEN64 | 1 | 270.0 | 201.0 |

The TFLite graph has a bunch of Dequantize ops, and I'm not sure whether they should be compiled/optimized away or whether the graph is bad to begin with. The JAX graph is much cleaner.

TFLite: (graph screenshot)

JAX: (graph screenshot)
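
For reference, one way to confirm which ops ended up in the .tflite flatbuffer (a sketch assuming TF 2.9+; the model filename below is a placeholder, not one of the artifacts in this issue):

```python
# Dumps the op-level structure of a TFLite flatbuffer so the FP16 -> FP32
# DEQUANTIZE ops described above are visible. The path is a placeholder.
import tensorflow as tf

tf.lite.experimental.Analyzer.analyze(model_path="bert_base_fp16.tflite")
```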

> I don't see this pattern when using the StableHLO path, i.e. JAX -> FP32 to FP16 -> StableHLO MLIR -> IREE, versus JAX -> TF -> TFLite -> FP32 to FP16 -> TOSA MLIR -> IREE.

Hi @mariecwhite,

Just to make sure I understand this statement correctly: what you're saying is that for the same input model, if you "export" it through StableHLO we get the expected performance, but if we export it through TFLite -> TOSA we get the performance discrepancy.

Did I get this right?

Yes, that's right. Exporting through StableHLO looks fine, but not through TOSA.

@mariecwhite Do you have the FP16 IR handy for the StableHLO input?
I want to double-check that we're not missing something obvious that may be happening on this path and not on the TOSA path.

TL;DR: As far as I can tell, the overhead for the TOSA input is indeed expected, and there's not much we can do about it.

To get the perf improvements, we'll need to do the math on the smaller type.
E.g.:

// f32 version
%a = tosa.add %b, %c : (tensor<1x8x768xf32>, tensor<1x8x768xf32>) -> tensor<1x8x768xf32>

// f16 version
%a = tosa.add %b, %c : (tensor<1x8x768xf16>, tensor<1x8x768xf16>) -> tensor<1x8x768xf16>

I.e., same operation, different type.

Instead what we get is:

// f16 version
%b32 = tosa.cast %b : (tensor<1x8x768xf16>) -> tensor<1x8x768xf32>
%c32 = tosa.cast %c : (tensor<1x8x768xf16>) -> tensor<1x8x768xf32>
%a32 = tosa.add %b32, %c32 : (tensor<1x8x768xf32>, tensor<1x8x768xf32>) -> tensor<1x8x768xf32>

I.e., we upcast all the operands and run the math at the more precise type. So effectively we're doing the same math as the f32 version, plus we pay the overhead of upcasting all the values.

At this point, an optimization cannot safely undo that (by pushing the cast down the def-use chain for instance) because that could have huge numerical implications.
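
As a quick, purely illustrative example of those numerical implications (assuming IEEE f16/f32 semantics, not taken from the model above): f16 carries only 11 significand bits, so sums that are exact in f32 can round away in f16.

```python
# Illustrative only: the same addition performed in f32 vs. f16.
import numpy as np

print(np.float32(2048.0) + np.float32(1.0))  # 2049.0 -- exact in f32
print(np.float16(2048.0) + np.float16(1.0))  # 2048.0 -- the +1 is rounded away in f16
```

So demoting the math from f32 to f16 can silently change results, which is why the compiler won't do it on its own.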

I guess we could add an option to force that, but that sounds like a recipe for disaster.

Is there a way we could fix the front-end to stick to the math on f16 operations?

@mariecwhite, do you know who could look at this from the frontend perspective?

The JAX FP16 version is here: https://storage.googleapis.com/iree-model-artifacts/jax/jax_models_0.4.20_1699872537/BERT_BASE_FP16_JAX_I32_SEQLEN8/stablehlo.mlirbc

@rsuderman @NatashaKnk what are your thoughts on supporting these kinds of TFLite graphs? This does seem out of scope. It appears TFLite itself doesn't do any special handling for these since it tends to be slower than IREE:

| Prefix Length | IREE | TFLite |
| --- | --- | --- |
| 8 | 42.27 | 17.89 |
| 32 | 76.78 | 98.91 |
| 64 | 88.51 | 148.22 |
| 128 | 145.27 | 223.03 |
| 256 | 274.58 | 361.16 |
| 512 | 458.20 | 687.50 |

Thanks for the StableHLO input @mariecwhite!

I was able to confirm that this path does more of its computations directly in f16, as opposed to always promoting to f32.

I'll let @rsuderman and @NatashaKnk comment on the frontend aspect, but from the backend perspective we do what we're told to do, i.e., there is nothing to fix here.

Assigning @NatashaKnk to comment/decide on what to do on the front-end side.

Sorry, I'm a bit backlogged at the moment so I didn't get the chance to look at this yet.
I'm curious why these casts happen in the first place; from naively looking at the TOSA spec, fp16 should be supported. Is there any obvious reason this is happening that I'm missing?

Otherwise I can investigate this deeper sometime next week.

> I'm curious why these casts happen in the first place; from naively looking at the TOSA spec, fp16 should be supported. Is there any obvious reason this is happening that I'm missing?

I don't know; that's a question for someone on the frontend side, and to be fair I thought you could shed some light on it.

I did a bit more investigation here, and it is a deliberate design decision to include the FP16-to-FP32 dequantization ops in the TFLite flatbuffer. Additional metadata marks whether the model is FP16 compatible; if it is, XNNPACK will remove those dequantization ops.
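
For context, this matches the standard TFLite float16 post-training quantization flow; presumably something along these lines produced the FP16 flatbuffer (a sketch with a hypothetical saved-model path, not the actual conversion script used for these artifacts):

```python
# TFLite float16 quantization: weights are stored as f16 and DEQUANTIZE ops are
# inserted so execution can stay in f32 unless the delegate (e.g. XNNPACK)
# elects to strip them and run natively in f16. Paths are placeholders.
import tensorflow as tf

converter = tf.lite.TFLiteConverter.from_saved_model("/path/to/bert_saved_model")
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.target_spec.supported_types = [tf.float16]
tflite_model = converter.convert()

with open("bert_base_fp16.tflite", "wb") as f:
    f.write(tflite_model)
```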

Let's mark this out of scope.