[SDXL][GFX942] Numerical randomness and discrepancy with pytorch
Max191 opened this issue
The unet outputs at commit 71a9945 show numerical errors of 1.0-3.0 max absolute error compared with PyTorch. In addition, iree-run-module
produces different results on each invocation.
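The max-error figure above comes from an elementwise comparison of the two outputs; a minimal sketch of such a comparison, assuming the IREE and PyTorch outputs have already been loaded as numpy arrays (the array names here are illustrative, not from the repro scripts):

```python
import numpy as np

def max_abs_error(iree_out: np.ndarray, torch_out: np.ndarray) -> float:
    # Upcast to f32 before subtracting so fp16 rounding does not mask the
    # difference; the largest absolute entry is the reported max error.
    diff = np.abs(iree_out.astype(np.float32) - torch_out.astype(np.float32))
    return float(np.max(diff))

# Synthetic data standing in for the real unet outputs:
a = np.array([1.0, 2.0, 3.0], dtype=np.float16)
b = np.array([1.5, 2.0, 0.5], dtype=np.float16)
print(max_abs_error(a, b))  # 2.5
```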
Steps to Reproduce
Check out 71a9945
Compile unet (using compile-unet.sh from https://github.com/monorimet/sdxl-scripts):
iree-compile $PWD/base_ir/stable_diffusion_xl_base_1_0_64_1024x1024_fp16_unet.mlir \
--iree-hal-target-backends=rocm \
--iree-rocm-target-chip=gfx942 \
--iree-rocm-bc-dir=$PWD/bitcode-2024-03-07 \
--iree-global-opt-propagate-transposes=true \
--iree-opt-outer-dim-concat=true \
--iree-opt-data-tiling=false \
--iree-opt-const-eval=false \
--iree-codegen-llvmgpu-use-vector-distribution \
--iree-llvmgpu-enable-prefetch \
--iree-codegen-gpu-native-math-precision=true \
--iree-rocm-waves-per-eu=2 \
--iree-flow-enable-aggressive-fusion \
--iree-global-opt-enable-fuse-horizontal-contractions=true \
--iree-opt-aggressively-propagate-transposes=true \
--iree-execution-model=async-external \
--iree-hal-dump-executable-configurations-to=configurations/unet \
--iree-hal-dump-executable-sources-to=sources/unet \
--iree-hal-dump-executable-binaries-to=binaries/unet \
--iree-hal-dump-executable-benchmarks-to=benchmarks/unet \
--iree-preprocessing-pass-pipeline="builtin.module(iree-preprocessing-transpose-convolution-pipeline, util.func(iree-preprocessing-pad-to-intrinsics))" \
--iree-codegen-transform-dialect-library=$PWD/specs/attention_and_matmul_spec.mlir \
-o $PWD/tmp/unet.vmfb
Run unet with iree-run-module:
iree-run-module --device=rocm://$1 --device_allocator=caching \
--module=tmp/unet.vmfb \
--parameters=model=/path_to/scheduled_unet.irpa \
--function=main \
--input=1x4x128x128xf16=0 --input=1xi64=0 --input=2x64x2048xf16=0 \
--input=2x1280xf16=0 --input=2x6xf16=0 --input=1xf16=0 \
> unet-results-run-module.txt
Running multiple times yields different results each time.
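One way to confirm the run-to-run discrepancy is to redirect two invocations to separate files (e.g. run1.txt and run2.txt; the file names are illustrative) and locate where the dumps diverge; a minimal sketch, assuming both dumps have the same number of lines:

```python
def first_difference(path_a: str, path_b: str):
    """Return (line_number, line_a, line_b) for the first differing line,
    or None if the two dumps are identical."""
    with open(path_a) as fa, open(path_b) as fb:
        # zip stops at the shorter file, so this assumes equal-length dumps.
        for i, (la, lb) in enumerate(zip(fa, fb), start=1):
            if la != lb:
                return (i, la.rstrip("\n"), lb.rstrip("\n"))
    return None
```

If the results were deterministic, first_difference("run1.txt", "run2.txt") would return None on every pair of runs; here it returns a differing line instead.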
Compare with PyTorch
Set up SHARK-Turbine from https://github.com/nod-ai/SHARK-Turbine/tree/ean-sdxl-fixes
Run the unet runner:
python3 SHARK-Turbine/models/turbine_models/custom_models/sdxl_inference/unet_runner.py \
--compare_vs_torch --precision=fp16 --device=rocm \
--external_weight_path=/path_to/scheduled_unet.irpa \
--max_length=64 --vmfb_path=tmp/unet.vmfb