iree-org / iree

A retargetable MLIR-based machine learning compiler and runtime toolkit.

Home Page: http://iree.dev/

(gfx1103/Windows) Numerics issues on HIP driver for SDXL Unet

monorimet opened this issue

What happened?

The same .vmfb gives different results on the ROCM and HIP HAL drivers. The caching allocator is used on both; disabling it does not seem to make a difference.

Good numerics: ROCM with inlined weights gives correct output.

Bad numerics #1: ROCM with external weights gives an all-zeros output.

Bad numerics #2: HIP with inlined weights gives wrong numbers.

Bad numerics #3: HIP with external weights gives wrong numbers.

I am filing this issue specifically for this target and IR because other targets and models do not reproduce the same success/failure cases (see #17033).

I am including the ROCM HAL results only because they contain the one success mode. We should focus on fixing the HIP HAL issues.

Full log output from the turbine-models scripts is below; these runs use fixed random inputs. I will provide IREE CLI reproducers as well:

(shark.venv) PS C:\Users\eagarvey\SHARK\numerics_debug_hip> python C:\Users\eagarvey\SHARK\SHARK-Turbine\models\turbine_models\custom_models\sdxl_inference\unet_runner.py --precision=fp16 --device=hip --external_weights=safetensors --num_inference_steps=1 --scheduler_id=EulerDiscrete --compile_to=vmfb --iree_target_triple=gfx1103 --vmfb_path=stable_diffusion_xl_base_1_0_bs1_64_1024x1024_fp16_unet_rocm.vmfb                                                  
TURBINE OUTPUT: [[[[ 0.002733  0.002733  0.002733 ...  0.002733  0.002733  0.002733]
   [ 0.002733  0.002733  0.002733 ...  0.002733  0.002733  0.002733]
   [ 0.002733  0.002733  0.002733 ...  0.002733  0.002733  0.002733]
   ...
   [ 0.002733  0.002733  0.002733 ...  0.002733  0.002733  0.002733]
   [ 0.002733  0.002733  0.002733 ...  0.002733  0.002733  0.002733]
   [ 0.002733  0.002733  0.002733 ...  0.002733  0.002733  0.002733]]

  [[-0.001411 -0.001411 -0.001411 ... -0.001411 -0.001411 -0.001411]
   [-0.001411 -0.001411 -0.001411 ... -0.001411 -0.001411 -0.001411]
   [-0.001411 -0.001411 -0.001411 ... -0.001411 -0.001411 -0.001411]
   ...
   [-0.001411 -0.001411 -0.001411 ... -0.001411 -0.001411 -0.001411]
   [-0.001411 -0.001411 -0.001411 ... -0.001411 -0.001411 -0.001411]
   [-0.001411 -0.001411 -0.001411 ... -0.001411 -0.001411 -0.001411]]

  [[ 0.001999  0.001999  0.001999 ...  0.001999  0.001999  0.001999]
   [ 0.001999  0.001999  0.001999 ...  0.001999  0.001999  0.001999]
   [ 0.001999  0.001999  0.001999 ...  0.001999  0.001999  0.001999]
   ...
   [ 0.001999  0.001999  0.001999 ...  0.001999  0.001999  0.001999]
   [ 0.001999  0.001999  0.001999 ...  0.001999  0.001999  0.001999]
   [ 0.001999  0.001999  0.001999 ...  0.001999  0.001999  0.001999]]

  [[-0.00314  -0.00314  -0.00314  ... -0.00314  -0.00314  -0.00314 ]
   [-0.00314  -0.00314  -0.00314  ... -0.00314  -0.00314  -0.00314 ]
   [-0.00314  -0.00314  -0.00314  ... -0.00314  -0.00314  -0.00314 ]
   ...
   [-0.00314  -0.00314  -0.00314  ... -0.00314  -0.00314  -0.00314 ]
   [-0.00314  -0.00314  -0.00314  ... -0.00314  -0.00314  -0.00314 ]
   [-0.00314  -0.00314  -0.00314  ... -0.00314  -0.00314  -0.00314 ]]]] (1, 4, 128, 128) float16
(shark.venv) PS C:\Users\eagarvey\SHARK\numerics_debug_hip> python C:\Users\eagarvey\SHARK\SHARK-Turbine\models\turbine_models\custom_models\sdxl_inference\unet_runner.py --precision=fp16 --device=rocm --external_weights=safetensors --num_inference_steps=1 --scheduler_id=EulerDiscrete --compile_to=vmfb --iree_target_triple=gfx1103 --vmfb_path=stable_diffusion_xl_base_1_0_bs1_64_1024x1024_fp16_unet_rocm.vmfb
TURBINE OUTPUT: [[[[-0.6895  -0.3262   0.9443  ... -0.9014   1.218    0.2766 ]
   [-0.1412   0.2     -0.4927  ...  0.627   -1.163    0.4128 ]
   [-0.01245  0.633    0.1671  ... -0.05722 -0.0687  -0.10736]
   ...
   [-0.1569  -0.1216   0.325   ...  0.3892   0.7476   0.06064]
   [-0.4585   0.2944  -0.9595  ...  0.797    0.2452   0.1302 ]
   [-0.02382  1.318   -0.2832  ... -0.4692   1.057   -1.516  ]]

  [[ 0.03044 -0.4458   0.836   ...  0.2281   0.4502  -0.0377 ]
   [ 0.0486  -0.4158  -0.2251  ... -0.4724   0.4004   1.592  ]
   [ 0.92    -0.573    0.2286  ...  0.81     0.01987  0.398  ]
   ...
   [-0.5947  -1.238   -0.05618 ...  0.1353   0.0868  -0.2744 ]
   [-0.1533  -0.291    0.1362  ...  0.1338   0.1406   0.9385 ]
   [ 0.03732  1.064    1.513   ...  0.3914  -0.6694   0.699  ]]

  [[ 0.881   -0.3994   0.763   ...  0.339    0.7397  -0.295  ]
   [ 0.615    0.203   -0.7407  ... -0.1326  -0.0328   0.147  ]
   [ 0.733   -0.1461   0.1094  ...  0.44    -1.463    0.8037 ]
   ...
   [ 0.2075   0.4565   0.7773  ... -0.3655   0.1267  -0.02698]
   [ 0.4185   0.218   -0.297   ...  0.478   -1.067    1.498  ]
   [ 1.348   -0.2026  -0.1068  ...  0.044   -0.05292 -0.163  ]]

  [[ 0.7183   0.2141  -0.743   ...  1.86    -2.348   -0.1821 ]
   [-0.165   -0.10284 -0.02016 ... -0.3496   0.9595   1.24   ]
   [ 0.1755  -0.0325   1.589   ...  0.3892  -1.476    0.8857 ]
   ...
   [ 0.2554   0.6816   0.2898  ... -0.2029   0.3306   0.3394 ]
   [-1.031   -0.042    0.5566  ...  0.4116   1.478    0.01047]
   [ 0.575    1.234   -1.045   ... -0.8857   0.8745  -0.1274 ]]]] (1, 4, 128, 128) float16    
(shark.venv) PS C:\Users\eagarvey\SHARK\numerics_debug_hip> python C:\Users\eagarvey\SHARK\SHARK-Turbine\models\turbine_models\custom_models\sdxl_inference\unet.py --precision=fp16 --device=rocm --external_weights=safetensors --num_inference_steps=1 --scheduler_id=EulerDiscrete --compile_to=vmfb --iree_target_triple=gfx1103                                                   
C:\Users\eagarvey\SHARK\SHARK\shark.venv\Lib\site-packages\diffusers\utils\outputs.py:63: UserWarning: torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.
  torch.utils._pytree._register_pytree_node(
C:\Users\eagarvey\SHARK\SHARK\shark.venv\Lib\site-packages\diffusers\utils\outputs.py:63: UserWarning: torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.
  torch.utils._pytree._register_pytree_node(
C:\Users\eagarvey\SHARK\SHARK\shark.venv\Lib\site-packages\huggingface_hub\file_download.py:1132: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.
  warnings.warn(
Compiling to rocm with flags: ['--iree-hal-target-backends=rocm', '--iree-rocm-target-chip=gfx1103', '--iree-vm-bytecode-module-output-format=flatbuffer-binary', '--iree-global-opt-propagate-transposes=true', '--iree-opt-outer-dim-concat=true', '--iree-vm-target-truncate-unsupported-floats', '--iree-llvmgpu-enable-prefetch=true', '--iree-opt-data-tiling=false', '--iree-opt-const-eval=false', '--iree-opt-aggressively-propagate-transposes=true', '--iree-flow-enable-aggressive-fusion', '--iree-global-opt-enable-fuse-horizontal-contractions=true', '--iree-codegen-gpu-native-math-precision=true', '--iree-codegen-llvmgpu-use-vector-distribution=true', '--iree-codegen-llvmgpu-enable-transform-dialect-jit=false', '--iree-preprocessing-pass-pipeline=builtin.module(iree-preprocessing-transpose-convolution-pipeline, iree-global-opt-raise-special-ops, util.func(iree-preprocessing-pad-to-intrinsics))', '--iree-codegen-transform-dialect-library=attention_and_matmul_spec_wmma.mlir']
Saved to stable_diffusion_xl_base_1_0_bs1_64_1024x1024_fp16_unet_rocm.mlir
Saved to stable_diffusion_xl_base_1_0_bs1_64_1024x1024_fp16_unet_rocm.vmfb
(shark.venv) PS C:\Users\eagarvey\SHARK\numerics_debug_hip> python C:\Users\eagarvey\SHARK\SHARK-Turbine\models\turbine_models\custom_models\sdxl_inference\unet_runner.py --precision=fp16 --device=hip --external_weights=safetensors --num_inference_steps=1 --scheduler_id=EulerDiscrete --compile_to=vmfb --iree_target_triple=gfx1103 --vmfb_path=stable_diffusion_xl_base_1_0_bs1_64_1024x1024_fp16_unet_rocm.vmfb --external_weight_path=scheduled_unet.safetensors
TURBINE OUTPUT: [[[[0.334    0.273    0.807    ... 0.2734   0.649    0.7837  ]
   [0.0825   0.8296   0.7188   ... 0.659    0.229    0.1401  ]
   [0.2563   0.9097   0.84     ... 0.166    0.2485   0.251   ]
   ...
   [0.6636   0.4502   0.718    ... 0.986    0.9766   0.5557  ]
   [0.2847   0.2104   0.1401   ... 0.68     0.5996   0.5386  ]
   [0.5967   0.753    0.3506   ... 0.3213   0.743    0.0962  ]]

  [[0.853    0.274    0.7505   ... 0.463    0.627    0.7515  ]
   [0.1572   0.626    0.6274   ... 0.845    0.9517   0.774   ]
   [0.555    0.3218   0.3975   ... 0.8486   0.3613   0.393   ]
   ...
   [0.169    0.05664  0.477    ... 0.3018   0.3994   0.5474  ]
   [0.3916   0.2373   0.4214   ... 0.5234   0.836    0.9893  ]
   [0.6196   0.802    0.8315   ... 0.708    0.3242   0.898   ]]

  [[0.513    0.002441 0.2339   ... 0.0674   0.9834   0.6255  ]
   [0.6973   0.6025   0.3115   ... 0.7783   0.9077   0.09814 ]
   [0.4058   0.8477   0.658    ... 0.2798   0.01807  0.04834 ]
   ...
   [0.6553   0.3813   0.8765   ... 0.536    0.4678   0.02344 ]
   [0.7695   0.764    0.8647   ... 0.313    0.007324 0.921   ]
   [0.8804   0.05664  0.4668   ... 0.2632   0.1309   0.1758  ]]

  [[0.5986   0.0801   0.31     ... 0.6123   0.1484   0.0947  ]
   [0.11816  0.9585   0.796    ... 0.169    0.69     0.1992  ]
   [0.1831   0.552    0.9834   ... 0.0801   0.02051  0.4287  ]
   ...
   [0.9175   0.773    0.9463   ... 0.9404   0.835    0.1401  ]
   [0.5796   0.1968   0.5195   ... 0.3672   0.9507   0.57    ]
   [0.2344   0.531    0.0547   ... 0.0703   0.9297   0.8394  ]]]] (1, 4, 128, 128) float16 
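The logs above show two broken patterns that can be checked mechanically: in the first HIP run every value within a channel is the same constant, and in the last HIP run every value lies between 0 and 1, whereas the ROCM run has both signs and magnitudes above 1. The helper below (`characterize_output`, a hypothetical name) is a minimal sketch for flagging these patterns on a saved output array; the checks are heuristics chosen for this report, not a diagnosis.

```python
import numpy as np


def characterize_output(x):
    """Summarize an (N, C, H, W) output array to flag obviously broken results."""
    # Upcast fp16 so the statistics themselves do not lose precision.
    x = np.asarray(x, dtype=np.float32)
    # Peak-to-peak range per (batch, channel); 0 means the channel is one constant.
    per_channel_range = np.ptp(x.reshape(x.shape[0], x.shape[1], -1), axis=-1)
    return {
        "finite": bool(np.isfinite(x).all()),
        "all_zero": bool((x == 0).all()),
        "constant_channels": int((per_channel_range == 0).sum()),
        # True when nothing is negative and nothing reaches 1.
        "in_unit_interval": bool(x.min() >= 0.0 and x.max() < 1.0),
        "mean": float(x.mean()),
        "std": float(x.std()),
    }
```

On the first HIP output this would report 4 constant channels; on the last it would report `in_unit_interval=True`; on the ROCM output neither flag should be set.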

Steps to reproduce your issue

Artifacts:

MLIR (FP16): https://sharkpublic.blob.core.windows.net/sharkpublic/ean/hip_numerics/gfx1103_unet/numerics_debug_hip/stable_diffusion_xl_base_1_0_bs1_64_1024x1024_fp16_unet_rocm.mlir
MLIR (FP32): https://sharkpublic.blob.core.windows.net/sharkpublic/ean/hip_numerics/gfx1103_unet/numerics_debug_hip/stable_diffusion_xl_base_1_0_bs1_64_1024x1024_fp32_unet_cpu.mlir
WMMA spec: https://sharkpublic.blob.core.windows.net/sharkpublic/ean/hip_numerics/gfx1103_unet/numerics_debug_hip/attention_and_matmul_spec_wmma.mlir
MLIR (inlined, fp16): https://sharkpublic.blob.core.windows.net/sharkpublic/ean/hip_numerics/gfx1103_unet/numerics_debug_hip/stable_diffusion_xl_base_1_0_bs1_64_1024x1024_fp16_unet_inline.mlir

Inputs:
https://sharkpublic.blob.core.windows.net/sharkpublic/ean/hip_numerics/gfx1103_unet/numerics_debug_hip/input1.npy
https://sharkpublic.blob.core.windows.net/sharkpublic/ean/hip_numerics/gfx1103_unet/numerics_debug_hip/input2.npy
https://sharkpublic.blob.core.windows.net/sharkpublic/ean/hip_numerics/gfx1103_unet/numerics_debug_hip/input3.npy
https://sharkpublic.blob.core.windows.net/sharkpublic/ean/hip_numerics/gfx1103_unet/numerics_debug_hip/input4.npy
https://sharkpublic.blob.core.windows.net/sharkpublic/ean/hip_numerics/gfx1103_unet/numerics_debug_hip/input5.npy
https://sharkpublic.blob.core.windows.net/sharkpublic/ean/hip_numerics/gfx1103_unet/numerics_debug_hip/input6.npy
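Because the six inputs are passed positionally (`--input=@input1.npy` through `--input=@input6.npy`), it is worth confirming their order, shapes, and dtypes before comparing drivers. A minimal sketch, assuming the files are downloaded into one directory; `load_inputs` is a hypothetical helper, not part of the turbine scripts:

```python
from pathlib import Path

import numpy as np


def load_inputs(directory, count=6):
    """Load input1.npy .. input{count}.npy in the order iree-run-module receives them."""
    arrays = []
    for i in range(1, count + 1):
        path = Path(directory) / f"input{i}.npy"
        arr = np.load(path)
        # Non-finite values in an input would corrupt every backend equally.
        if np.issubdtype(arr.dtype, np.floating):
            finite = bool(np.isfinite(arr).all())
        else:
            finite = True
        print(f"{path.name}: shape={arr.shape} dtype={arr.dtype} finite={finite}")
        arrays.append(arr)
    return arrays
```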

Weights:
https://sharkpublic.blob.core.windows.net/sharkpublic/SDXL/SDXL_weights_fp16/scheduled_unet.irpa

Compile:

iree-compile --iree-hal-target-backends=rocm --iree-rocm-target-chip=gfx1103 --iree-vm-bytecode-module-output-format=flatbuffer-binary --iree-global-opt-propagate-transposes=true --iree-opt-outer-dim-concat=true --iree-vm-target-truncate-unsupported-floats --iree-llvmgpu-enable-prefetch=true --iree-opt-data-tiling=false --iree-opt-const-eval=false --iree-opt-aggressively-propagate-transposes=true --iree-flow-enable-aggressive-fusion --iree-global-opt-enable-fuse-horizontal-contractions=true --iree-codegen-gpu-native-math-precision=true --iree-codegen-llvmgpu-use-vector-distribution=true --iree-codegen-llvmgpu-enable-transform-dialect-jit=false --iree-preprocessing-pass-pipeline='builtin.module(iree-preprocessing-transpose-convolution-pipeline, iree-global-opt-raise-special-ops, util.func(iree-preprocessing-pad-to-intrinsics))' --iree-codegen-transform-dialect-library=attention_and_matmul_spec_wmma.mlir stable_diffusion_xl_base_1_0_bs1_64_1024x1024_fp16_unet_rocm.mlir -o stable_diffusion_xl_base_1_0_bs1_64_1024x1024_fp16_unet_rocm.vmfb
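When scripting this compile step (for example to A/B a single flag), passing an argument list to `subprocess` avoids shell quoting, which matters for the pass-pipeline value because it contains spaces and parentheses; the single quotes in the command above exist only for the shell. A minimal sketch with the flags copied from that command; `compile_cmd` and `compile_module` are hypothetical helper names:

```python
import subprocess

# Flags copied from the iree-compile command above (quotes around the
# pass-pipeline value dropped, since no shell is involved).
GFX1103_FLAGS = (
    "--iree-hal-target-backends=rocm",
    "--iree-rocm-target-chip=gfx1103",
    "--iree-vm-bytecode-module-output-format=flatbuffer-binary",
    "--iree-global-opt-propagate-transposes=true",
    "--iree-opt-outer-dim-concat=true",
    "--iree-vm-target-truncate-unsupported-floats",
    "--iree-llvmgpu-enable-prefetch=true",
    "--iree-opt-data-tiling=false",
    "--iree-opt-const-eval=false",
    "--iree-opt-aggressively-propagate-transposes=true",
    "--iree-flow-enable-aggressive-fusion",
    "--iree-global-opt-enable-fuse-horizontal-contractions=true",
    "--iree-codegen-gpu-native-math-precision=true",
    "--iree-codegen-llvmgpu-use-vector-distribution=true",
    "--iree-codegen-llvmgpu-enable-transform-dialect-jit=false",
    "--iree-preprocessing-pass-pipeline=builtin.module("
    "iree-preprocessing-transpose-convolution-pipeline, "
    "iree-global-opt-raise-special-ops, "
    "util.func(iree-preprocessing-pad-to-intrinsics))",
    "--iree-codegen-transform-dialect-library=attention_and_matmul_spec_wmma.mlir",
)


def compile_cmd(mlir_path, vmfb_path, extra_flags=()):
    """Build the iree-compile argv; extra_flags lets one option be varied."""
    return ["iree-compile", *GFX1103_FLAGS, *extra_flags, mlir_path, "-o", vmfb_path]


def compile_module(mlir_path, vmfb_path, extra_flags=()):
    # check=True raises CalledProcessError if the compiler fails.
    subprocess.run(compile_cmd(mlir_path, vmfb_path, extra_flags), check=True)
```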

Run:

iree-run-module --module=stable_diffusion_xl_base_1_0_bs1_64_1024x1024_fp16_unet_rocm.vmfb --device_allocator=caching --parameters=model=scheduled_unet.irpa --input=@input1.npy --input=@input2.npy --input=@input3.npy --input=@input4.npy --input=@input5.npy --input=@input6.npy --device=hip
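To compare drivers directly with the CLI, each run can write its result to a .npy file and be diffed against the ROCM inlined-weights run, which the report identifies as the only correct configuration. This sketch assumes the installed iree-run-module accepts `--output=@file.npy` (check `iree-run-module --help`); the helper names and the fp16 tolerances are assumptions, not values from the issue.

```python
import subprocess

import numpy as np


def run_cmd(vmfb, device, inputs, output, params=None):
    """Build an iree-run-module argv that writes the result to a .npy file."""
    cmd = [
        "iree-run-module",
        f"--module={vmfb}",
        f"--device={device}",
        "--device_allocator=caching",
    ]
    # Externalized weights need the parameter archive; inlined weights do not.
    if params is not None:
        cmd.append(f"--parameters=model={params}")
    cmd += [f"--input=@{path}" for path in inputs]
    cmd.append(f"--output=@{output}")
    return cmd


def run(cmd):
    # check=True raises CalledProcessError if the tool exits with an error.
    subprocess.run(cmd, check=True)


def compare(reference, candidate, atol=1e-2, rtol=1e-2):
    """Compare two outputs in float32; the tolerances are guesses for fp16."""
    ref = np.asarray(reference, dtype=np.float32)
    cand = np.asarray(candidate, dtype=np.float32)
    return {
        "max_abs_diff": float(np.abs(ref - cand).max()),
        "allclose": bool(np.allclose(cand, ref, atol=atol, rtol=rtol)),
    }
```

Running the same .vmfb once with `device="rocm"` and once with `device="hip"`, then loading both .npy files with `np.load` and passing them to `compare`, isolates the driver as the only variable.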

What component(s) does this issue relate to?

No response

Version information

The IREE branch used is shared/tresleches-united, but these issues have historically reproduced on the main branch as well, although not all of the compile options here may carry over.

Exact commit: c66ae19.

Additional context

No response

I tried a few different configurations, and found a potentially useful runtime error when using inlined weights with SDXL:

Assertion failed: !!(iree_hal_resource_is(base_value, &iree_hal_rocm_buffer_vtable)), file C:\V\iree\experimental\rocm\rocm_buffer.c, line 25

That is what I get when running with the HIP HAL driver; the ROCM driver works and gives the same numerics as with externalized weights.

Could this be related to our using --iree-stream-resource-memory-model=unified by default? I am now trying with this flag set to discrete.

Maybe you have it the other way around? The error you have says C:\V\iree\experimental\rocm\rocm_buffer.c, which is ROCM, not HIP.

That kind of error happens if the driver casts a buffer pointer directly instead of casting the iree_hal_allocated_buffer() result.
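To make the failure mode concrete, here is a Python model (not IREE's C code) of a vtable-checked downcast. All names are hypothetical; it only assumes what the replies describe: the driver checks a resource's vtable before casting, and passing a wrapper buffer instead of its allocated buffer trips that check. The `checks_enabled` switch models a build with assertions compiled out, the scenario raised later in the thread.

```python
class Vtable:
    """Stand-in for a C vtable; the checked cast compares identity."""

    def __init__(self, name):
        self.name = name


DRIVER_BUFFER_VTABLE = Vtable("driver_buffer")    # hypothetical driver-owned buffer type
WRAPPER_BUFFER_VTABLE = Vtable("wrapper_buffer")  # hypothetical wrapper/view type


class Buffer:
    def __init__(self, vtable, allocated=None):
        self.vtable = vtable
        # A wrapper records the buffer that actually owns the device memory;
        # a driver-allocated buffer is its own allocated buffer.
        self.allocated = self if allocated is None else allocated


def allocated_buffer(buf):
    """Model of unwrapping to the underlying allocation."""
    return buf.allocated


def driver_buffer_cast(buf, checks_enabled=True):
    """Model of a vtable-checked downcast like the one asserting in rocm_buffer.c."""
    if checks_enabled and buf.vtable is not DRIVER_BUFFER_VTABLE:
        raise AssertionError(f"resource is a {buf.vtable.name}, not a driver_buffer")
    # With checks disabled the wrong object is returned silently; in C this
    # would reinterpret memory using the wrong struct layout.
    return buf
```

Casting `allocated_buffer(view)` succeeds, casting `view` itself raises, and with `checks_enabled=False` the bad cast goes through unnoticed, which is consistent with a release build misbehaving instead of asserting.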

I was a bit confused by this as well, but this was definitely run with the HIP driver. I will validate with the CLI.

If you're using a release LTO build, it's possible the two functions are identical and were folded together, though assertions and similar code usually prevent that. Either way, it's worth testing with a breakpoint or a printf.

OK, so if I switch from my local build, configured with:

cmake -GNinja -B ../iree-build --log-level=VERBOSE -DIREE_BUILD_PYTHON_BINDINGS=ON -DIREE_BUILD_COMPILER=ON -DPython3_EXECUTABLE=C:\\V\SHARK-Turbine\turb.env\Scripts\python.exe -DCMAKE_BUILD_TYPE=Release -DIREE_HAL_DRIVER_VULKAN=ON -DIREE_HAL_DRIVER_CUDA=OFF -DIREE_EXTERNAL_HAL_DRIVERS="rocm" -DIREE_ENABLE_CPUINFO=ON -DIREE_HAL_DRIVER_ROCM=ON -DIREE_ENABLE_LLD=ON -DIREE_ENABLE_RUNTIME_TRACING=OFF -DIREE_ENABLE_ASSERTIONS=ON -DIREE_ENABLE_SPLIT_DWARF=ON

to a recent pip install of iree-runtime, then instead of hitting an assertion on the HIP HAL driver, the run freezes my entire system for minutes at a time. I will try ResNet again to see if it completes. This happens with both --iree-stream-resource-memory-model=unified and --iree-stream-resource-memory-model=discrete, but I have only tried externalized weights so far. I will try inlined weights next.
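At this point three things are being varied: weights (externalized vs inlined), `--iree-stream-resource-memory-model` (unified vs discrete), and HAL driver (HIP vs ROCM). A small sketch that enumerates every combination so each one gets run and recorded; the dictionary keys and `config_matrix` are hypothetical, while the MLIR file names are the artifacts listed in the issue. The memory model is a compile-time flag, so each value needs its own .vmfb, whereas the driver is chosen at run time.

```python
from itertools import product

# Artifact names from the issue: externalized-weights and inlined-weights MLIR.
MLIR_BY_WEIGHTS = {
    "external": "stable_diffusion_xl_base_1_0_bs1_64_1024x1024_fp16_unet_rocm.mlir",
    "inlined": "stable_diffusion_xl_base_1_0_bs1_64_1024x1024_fp16_unet_inline.mlir",
}
MEMORY_MODELS = ("unified", "discrete")
DEVICES = ("hip", "rocm")


def config_matrix():
    """Yield one dict per (weights, memory model, device) combination."""
    for weights, model, device in product(MLIR_BY_WEIGHTS, MEMORY_MODELS, DEVICES):
        yield {
            "weights": weights,
            "mlir": MLIR_BY_WEIGHTS[weights],
            # Compile-time option: one .vmfb per memory model.
            "extra_compile_flags": [f"--iree-stream-resource-memory-model={model}"],
            # Runtime option: the same .vmfb can run on either driver.
            "device": device,
            # Only externalized weights need --parameters at run time.
            "needs_parameters": weights == "external",
        }
```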

Are the pip releases built with assertions disabled? That could explain this if the driver is still casting the wrong pointer.