iree-org / iree

A retargetable MLIR-based machine learning compiler and runtime toolkit.

Home page: http://iree.dev/


[SDXL][GFX942] Nondeterministic results and numerical discrepancy with PyTorch

Max191 opened this issue.

At commit 71a9945, the SDXL unet shows numerical errors compared with PyTorch, with a maximum error of 1.0 to 3.0. In addition, iree-run-module returns different results on each invocation with identical inputs.

Steps to Reproduce

Check out commit 71a9945.

Compile the unet (this is compile-unet.sh from https://github.com/monorimet/sdxl-scripts):

iree-compile $PWD/base_ir/stable_diffusion_xl_base_1_0_64_1024x1024_fp16_unet.mlir \
    --iree-hal-target-backends=rocm \
    --iree-rocm-target-chip=gfx942 \
    --iree-rocm-bc-dir=$PWD/bitcode-2024-03-07 \
    --iree-global-opt-propagate-transposes=true \
    --iree-opt-outer-dim-concat=true \
    --iree-opt-data-tiling=false \
    --iree-opt-const-eval=false \
    --iree-codegen-llvmgpu-use-vector-distribution \
    --iree-llvmgpu-enable-prefetch \
    --iree-codegen-gpu-native-math-precision=true \
    --iree-rocm-waves-per-eu=2 \
    --iree-flow-enable-aggressive-fusion \
    --iree-global-opt-enable-fuse-horizontal-contractions=true \
    --iree-opt-aggressively-propagate-transposes=true \
    --iree-execution-model=async-external \
    --iree-hal-dump-executable-configurations-to=configurations/unet \
    --iree-hal-dump-executable-sources-to=sources/unet \
    --iree-hal-dump-executable-binaries-to=binaries/unet \
    --iree-hal-dump-executable-benchmarks-to=benchmarks/unet \
    --iree-preprocessing-pass-pipeline="builtin.module(iree-preprocessing-transpose-convolution-pipeline, util.func(iree-preprocessing-pad-to-intrinsics))" \
    --iree-codegen-transform-dialect-library=$PWD/specs/attention_and_matmul_spec.mlir \
    -o $PWD/tmp/unet.vmfb

Run the unet with iree-run-module:

iree-run-module --device=rocm://$1 --device_allocator=caching --module=tmp/unet.vmfb --parameters=model=/path_to/scheduled_unet.irpa --function=main --input=1x4x128x128xf16=0 --input=1xi64=0 --input=2x64x2048xf16=0 --input=2x1280xf16=0 --input=2x6xf16=0 --input=1xf16=0 > unet-results-run-module.txt
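Each `--input` flag above has the form `<dims>x<dtype>=<value>`, with a single value that IREE splats across the whole tensor, so every input in this reproduction is all zeros. The following minimal Python sketch rebuilds those flags from the shapes in the command, which makes it easy to rerun with a different splat value. The helper names `splat_input_flag` and `unet_input_flags` are hypothetical and not part of IREE.

```python
# Hypothetical helper: builds iree-run-module --input flags in the
# "<dims>x<dtype>=<value>" form used in the reproduction command.
def splat_input_flag(shape, dtype, value):
    dims = "x".join(str(d) for d in shape)
    return f"--input={dims}x{dtype}={value}"


# Shapes and element types copied, in order, from the command above.
UNET_INPUTS = [
    ((1, 4, 128, 128), "f16"),
    ((1,), "i64"),
    ((2, 64, 2048), "f16"),
    ((2, 1280), "f16"),
    ((2, 6), "f16"),
    ((1,), "f16"),
]


def unet_input_flags(value=0):
    # Note: the i64 input only accepts integer values, so keep `value` integral.
    return [splat_input_flag(shape, dtype, value) for shape, dtype in UNET_INPUTS]


if __name__ == "__main__":
    # Prints the same six --input flags as the command above.
    print(" ".join(unet_input_flags()))
```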

Running iree-run-module multiple times with these same inputs yields different results each time.
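One way to confirm the run-to-run variation is to execute the same command several times and hash what it prints. The sketch below assumes you pass the full iree-run-module command as arguments, for example `python check_determinism.py iree-run-module --device=rocm://0 ...`. The script name and the device index are placeholders. If the tool truncates large tensors when printing, the digest only covers the printed portion, so identical digests are necessary but not sufficient for bitwise-identical outputs.

```python
import hashlib
import subprocess
import sys


def output_digests(cmd, runs=3):
    """Run `cmd` `runs` times and return a SHA-256 digest of each run's stdout."""
    digests = []
    for _ in range(runs):
        result = subprocess.run(cmd, capture_output=True, check=True)
        digests.append(hashlib.sha256(result.stdout).hexdigest())
    return digests


def is_deterministic(digests):
    """True when every run printed byte-identical output."""
    return len(set(digests)) == 1


if __name__ == "__main__" and len(sys.argv) > 1:
    # Everything after the script name is treated as the command to repeat.
    digests = output_digests(sys.argv[1:], runs=3)
    for i, digest in enumerate(digests):
        print(f"run {i}: {digest}")
    print("deterministic" if is_deterministic(digests) else "NON-deterministic")
```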

Compare with PyTorch

Set up SHARK-Turbine from the ean-sdxl-fixes branch: https://github.com/nod-ai/SHARK-Turbine/tree/ean-sdxl-fixes

Run the unet runner:

python3 SHARK-Turbine/models/turbine_models/custom_models/sdxl_inference/unet_runner.py \
    --compare_vs_torch --precision=fp16 --device=rocm \
    --external_weight_path=/path_to/scheduled_unet.irpa \
    --max_length=64 --vmfb_path=tmp/unet.vmfb
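The runner's `--compare_vs_torch` option reports the discrepancy, but the exact metric it uses is not shown here. As an illustrative sketch, assume both outputs have already been flattened into Python lists of floats (for example, after loading them from saved files). Under that assumption, you could compute the maximum absolute error quoted above (1.0 to 3.0) and an allclose-style count of out-of-tolerance elements as follows. The helper names and the 1e-2 default tolerances are hypothetical choices, not values taken from unet_runner.py.

```python
import math


def max_abs_error(expected, actual):
    """Largest |expected[i] - actual[i]|; NaN if any difference is NaN."""
    if len(expected) != len(actual):
        raise ValueError("length mismatch")
    diffs = [abs(e - a) for e, a in zip(expected, actual)]
    if any(math.isnan(d) for d in diffs):
        return math.nan
    return max(diffs, default=0.0)


def count_mismatches(expected, actual, rtol=1e-2, atol=1e-2):
    """Count elements failing |a - e| <= atol + rtol * |e| (numpy.allclose-style).

    NaN comparisons are False, so NaN elements are counted as mismatches.
    """
    return sum(
        1
        for e, a in zip(expected, actual)
        if not abs(a - e) <= atol + rtol * abs(e)
    )
```

Reporting both numbers helps distinguish a few large outliers from a broad drift across the whole fp16 output.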