iree-org / iree

A retargetable MLIR-based machine learning compiler and runtime toolkit.

Home Page: http://iree.dev/

[Vulkan] Conv2d fails with VK_ERROR_OUT_OF_HOST_MEMORY on 7900XTX

harishanand95 opened this issue

What happened?

Hi, I'm getting an IREE failure from iree-benchmark-module with Vulkan for a Conv2d, while the same model works fine on CPU. The IREE build used is the SRT version. Here is the example code:

# iree-compiler      20231104.572
# iree-runtime       20231104.572
# torch              2.2.0.dev20231101+cpu
# torch-mlir         20231102.1010
# nodai-SHARK        20231104.1013

import torch
import torch.nn as nn
import torch_mlir
# from shark.shark_importer import import_with_fx
# from shark.shark_inference import SharkInference


conv = nn.Conv2d(3, 4, kernel_size=(3, 1), stride=(1, 1), padding=(1, 0))
inp = torch.rand(1, 3, 32, 32)

# ts_g = import_with_fx(conv, [inp,], is_f16=True, f16_input_mask=[True,], mlir_type="torchscript")
# Trace the module and lower it to Linalg-on-tensors MLIR.
linalg_on_tensors_mlir = torch_mlir.compile(conv, [inp.cpu(),],
                                            output_type=torch_mlir.OutputType.LINALG_ON_TENSORS,
                                            use_tracing=True,
                                            ignore_traced_shapes=False)
# Save the IR as text so it can be passed to iree-compile.
with open("conv.mlir", "w") as f:
    f.write(linalg_on_tensors_mlir.operation.get_asm())

# WORKS
# iree-compile --iree-hal-target-backends=llvm-cpu --iree-stream-resource-index-bits=64 --iree-vm-target-truncate-unsupported-floats --iree-vm-target-index-bits=64 --iree-input-type=none --iree-stream-resource-max-allocation-size=4294967295 conv.mlir -o=conv.vmfb
# iree-benchmark-module --device=local-sync --module=conv.vmfb --function=forward --input=1x3x32x32xf32

# FAILS
# iree-compile --iree-hal-target-backends=vulkan-spirv --iree-stream-resource-index-bits=64 --iree-vm-target-truncate-unsupported-floats --iree-vm-target-index-bits=64 --iree-input-type=none --iree-vulkan-target-triple=rdna3-7900-windows --iree-stream-resource-max-allocation-size=4294967295 conv.mlir -o=conv.vmfb
# iree-benchmark-module --device=vulkan --module=conv.vmfb --function=forward --input=1x3x32x32xf32

# $ iree-benchmark-module --device=vulkan --module=conv.vmfb  --function=forward --input=1x3x32x32xf32
# 2023-11-07T10:10:01-05:00
# Running C:\Users\haranand\AppData\Local\hatch\env\virtual\whisper-shark\lCn2R_CD\whisper-shark\Lib\site-packages\iree\_runtime_libs\iree-benchmark-module
# Run on (24 X 4766.34 MHz CPU s)
# CPU Caches:
#   L1 Data 32 KiB (x12)
#   L1 Instruction 32 KiB (x12)
#   L2 Unified 1024 KiB (x12)
#   L3 Unified 32768 KiB (x2)
# C:\actions-runner\w\SRT\SRT\c\runtime\src\iree\hal\drivers\vulkan\native_allocator.cc:315: RESOURCE_EXHAUSTED; VK_ERROR_OUT_OF_HOST_MEMORY; vkAllocateMemory; while invoking native function hal.device.queue.alloca; while calling import; 
# [ 1]   native hal.device.queue.alloca:0 -
# [ 0] bytecode module.forward:372 conv.mlir:20:10
#       at conv.mlir:5:3

Any ideas on what I'm doing wrong here? Thanks!
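A small Python driver (a sketch, not something from the thread) can run both compile/benchmark pairs back to back so the CPU and Vulkan results sit side by side. The flag values are copied from the commands in the report; the helper names (`build_commands`, `BACKENDS`) and the per-backend `.vmfb` file names are hypothetical. Any tool not found on `PATH` is skipped rather than treated as an error.

```python
import shlex
import shutil
import subprocess

# Compile flags shared by both runs, copied from the report above.
COMMON_COMPILE_FLAGS = [
    "--iree-stream-resource-index-bits=64",
    "--iree-vm-target-truncate-unsupported-floats",
    "--iree-vm-target-index-bits=64",
    "--iree-input-type=none",
    "--iree-stream-resource-max-allocation-size=4294967295",
]

# Hypothetical table: backend -> (extra compile flags, runtime --device value).
BACKENDS = {
    "llvm-cpu": ([], "local-sync"),
    "vulkan-spirv": (["--iree-vulkan-target-triple=rdna3-7900-windows"], "vulkan"),
}


def build_commands(backend, mlir="conv.mlir", inp="1x3x32x32xf32"):
    """Return (compile_cmd, benchmark_cmd) argument lists for one backend."""
    extra_flags, device = BACKENDS[backend]
    vmfb = f"conv-{backend}.vmfb"  # separate outputs so the runs don't overwrite each other
    compile_cmd = (["iree-compile", f"--iree-hal-target-backends={backend}"]
                   + COMMON_COMPILE_FLAGS + extra_flags + [mlir, f"-o={vmfb}"])
    bench_cmd = ["iree-benchmark-module", f"--device={device}", f"--module={vmfb}",
                 "--function=forward", f"--input={inp}"]
    return compile_cmd, bench_cmd


def main():
    for backend in BACKENDS:
        for cmd in build_commands(backend):
            if shutil.which(cmd[0]) is None:
                print(f"{cmd[0]} not found on PATH; skipping {backend}")
                break
            print("$", shlex.join(cmd))
            # Don't raise on failure; report it and move on to the next backend.
            returncode = subprocess.run(cmd).returncode
            if returncode != 0:
                print(f"{backend}: {cmd[0]} exited with code {returncode}")
                break


if __name__ == "__main__":
    main()
```

Keeping the commands as argument lists (rather than shell strings) avoids quoting issues on Windows and makes the only per-backend differences, the target triple and `--device`, explicit.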

Steps to reproduce your issue

conv.zip

What component(s) does this issue relate to?

Runtime

Version information

iree-compiler 20231104.572
iree-runtime 20231104.572
torch 2.2.0.dev20231101+cpu
torch-mlir 20231102.1010
nodai-SHARK 20231104.1013

Additional context

This is used in a SHARK environment.

I just got around to this, but I'm having trouble reproducing it on Linux. Before I set up on Windows, can you verify that the Linalg IR I have is the same as yours?

#map = affine_map<(d0, d1, d2, d3) -> (d1)>
#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
module attributes {torch.debug_module_name = "Conv2d"} {
  ml_program.global private mutable @global_seed(dense<0> : tensor<i64>) : tensor<i64>
  func.func @forward(%arg0: tensor<1x3x32x32xf32>) -> tensor<1x4x32x32xf32> {
    %cst = arith.constant dense<[[[[-0.194130942], [-0.002053231], [-0.231181473]], [[-3.029260e-02], [-0.00984880328], [0.17870906]], [[-0.056172967], [-4.834640e-02], [-0.269018412]]], [[[-0.0401979685], [0.310443789], [-0.0480653942]], [[-0.0493329465], [-0.118616432], [-0.215419739]], [[0.0980359315], [-8.195710e-02], [-0.229281545]]], [[[0.128511041], [-0.159413904], [-0.234531641]], [[-0.0471769273], [0.241814166], [-0.0965808629]], [[-0.311713904], [0.293067485], [1.79409981E-4]]], [[[-0.245140284], [-0.311403513], [-0.218301624]], [[0.139625669], [9.060490e-02], [0.118851066]], [[-0.224677771], [-0.112303421], [-0.0881195962]]]]> : tensor<4x3x3x1xf32>
    %cst_0 = arith.constant dense<[-0.311590075, 0.249434024, 0.314933091, -5.043140e-02]> : tensor<4xf32>
    %cst_1 = arith.constant 0.000000e+00 : f32 
    %c1 = arith.constant 1 : index
    %c0 = arith.constant 0 : index
    %padded = tensor.pad %arg0 low[0, 0, 1, 0] high[0, 0, 1, 0] {
    ^bb0(%arg1: index, %arg2: index, %arg3: index, %arg4: index):
      tensor.yield %cst_1 : f32 
    } : tensor<1x3x32x32xf32> to tensor<1x3x34x32xf32>
    %0 = tensor.empty() : tensor<1x4x32x32xf32>
    %1 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%cst_0 : tensor<4xf32>) outs(%0 : tensor<1x4x32x32xf32>) {
    ^bb0(%in: f32, %out: f32):
      linalg.yield %in : f32 
    } -> tensor<1x4x32x32xf32>
    %2 = linalg.conv_2d_nchw_fchw {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>} ins(%padded, %cst : tensor<1x3x34x32xf32>, tensor<4x3x3x1xf32>) outs(%1 : tensor<1x4x32x32xf32>) -> tensor<1x4x32x32xf32>
    return %2 : tensor<1x4x32x32xf32>
  }
}
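To have something to check the Vulkan output against once it runs, here is a minimal NumPy sketch of what the IR above computes: `tensor.pad` with zeros, a `linalg.generic` that broadcasts the bias along the channel dimension, then `linalg.conv_2d_nchw_fchw` accumulating into that bias-initialized output. It assumes stride 1 and dilation 1 (as in the op's attributes); the function name `conv2d_nchw_fchw_reference` is hypothetical and not part of IREE or torch-mlir. It also shows where the padded 1x3x34x32 shape comes from.

```python
import numpy as np


def conv2d_nchw_fchw_reference(x, w, b, pad_h=1, pad_w=0):
    """Mirror the Linalg IR: zero pad, bias broadcast, conv (stride 1, dilation 1)."""
    # tensor.pad: zero-pad H by pad_h on both sides and W by pad_w.
    xp = np.pad(x, ((0, 0), (0, 0), (pad_h, pad_h), (pad_w, pad_w)))
    n, c, hp, wp = xp.shape
    f, c_w, kh, kw = w.shape
    assert c == c_w, "input and filter channel counts must match"
    oh, ow = hp - kh + 1, wp - kw + 1
    # linalg.generic with map (d0, d1, d2, d3) -> (d1): broadcast bias over N, H, W.
    out = np.broadcast_to(b[None, :, None, None], (n, f, oh, ow)).astype(x.dtype)
    # linalg.conv_2d_nchw_fchw: accumulate one kernel tap at a time into the output.
    for i in range(kh):
        for j in range(kw):
            patch = xp[:, :, i:i + oh, j:j + ow]  # shape (n, c, oh, ow)
            out += np.einsum("nchw,fc->nfhw", patch, w[:, :, i, j])
    return xp, out


rng = np.random.default_rng(0)
x = rng.random((1, 3, 32, 32), dtype=np.float32)
w = rng.standard_normal((4, 3, 3, 1)).astype(np.float32)
b = rng.standard_normal(4).astype(np.float32)
xp, y = conv2d_nchw_fchw_reference(x, w, b)
print(xp.shape, y.shape)  # (1, 3, 34, 32) (1, 4, 32, 32)
```

With a (3, 1) kernel and padding (1, 0), the height grows from 32 to 34 before the convolution and shrinks back to 32 after it, so the output keeps the input's 32x32 spatial size.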

Hi Quinn, thanks for taking a look! Yes, they look the same, except for the tensor values.
I can give you access to the machine where this is failing.

Nice. Can you also try removing the --iree-stream-resource-max-allocation-size=4294967295 flag? And to confirm: based on the issue above, this reproduces with iree-benchmark-module and/or iree-run-module?

I am only seeing two allocations and they aren't particularly large:

%transient_buffer = hal.device.queue.alloca<%device : !hal.device> affinity(%c-1_i64) wait(%0) signal(%fence) pool(%c0_i64) type("DeviceVisible|DeviceLocal") usage("TransferSource|TransferTarget|Transfer|DispatchStorageRead|DispatchStorageWrite|DispatchStorage") : !hal.buffer{%c16384}
%transient_buffer_1 = hal.device.queue.alloca<%device : !hal.device> affinity(%c-1_i64) wait(%0) signal(%fence_0) pool(%c0_i64) type("DeviceVisible|DeviceLocal") usage("TransferSource|TransferTarget|Transfer|DispatchStorageRead|DispatchStorageWrite|DispatchStorage") : !hal.buffer{%c13056}
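Both sizes match f32 tensors from the IR above. A quick check, assuming (this is not stated in the thread) that `%c16384` backs the 1x4x32x32 output and `%c13056` the 1x3x34x32 padded input; `nbytes` is a hypothetical helper:

```python
from math import prod

F32_BYTES = 4


def nbytes(shape, elem_bytes=F32_BYTES):
    """Size in bytes of a dense tensor with the given shape."""
    return prod(shape) * elem_bytes


output = nbytes((1, 4, 32, 32))        # tensor<1x4x32x32xf32>
padded_input = nbytes((1, 3, 34, 32))  # tensor<1x3x34x32xf32>
max_alloc = 4294967295                 # --iree-stream-resource-max-allocation-size from the report

print(output, padded_input)            # 16384 13056
print(max_alloc == 2**32 - 1)          # True
```

Together the two buffers are under 32 KiB, far below the roughly 4 GiB cap set by the flag, which fits the observation that these allocations are not particularly large.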

@benvanik or @antiagainst might have more insights. The generated code doesn't look good (it uses a lot of function memory), but that doesn't seem to be where it's failing.

Thanks!

$ iree-compile --iree-hal-target-backends=vulkan-spirv --iree-stream-resource-index-bits=64 --iree-vm-target-truncate-unsupported-floats --iree-vm-target-index-bits=64 --iree-input-type=none --iree-vulkan-target-triple=rdna3-7900-windows conv.mlir -o=conv.vmfb     

$ iree-benchmark-module --device=vulkan --module=conv.vmfb --function=forward --input=1x3x32x32xf32
2023-11-13T10:12:30-05:00
Running C:\Users\haranand\AppData\Local\hatch\env\virtual\whisper-shark\lCn2R_CD\whisper-shark\Lib\site-packages\iree\_runtime_libs\iree-benchmark-module
Run on (24 X 4756 MHz CPU s)
CPU Caches:
  L1 Data 32 KiB (x12)
  L1 Instruction 32 KiB (x12)
  L2 Unified 1024 KiB (x12)
  L3 Unified 32768 KiB (x2)
C:\actions-runner\w\SRT\SRT\c\runtime\src\iree\hal\drivers\vulkan\native_allocator.cc:315: RESOURCE_EXHAUSTED; VK_ERROR_OUT_OF_HOST_MEMORY; vkAllocateMemory; while invoking native function hal.device.queue.alloca; while calling import;
[ 1]   native hal.device.queue.alloca:0 -
[ 0] bytecode module.forward:372 conv.mlir:20:10
      at conv.mlir:5:3

$ iree-run-module --device=vulkan --module=conv.vmfb --function=forward --input=1x3x32x32xf32
EXEC @forward
C:\actions-runner\w\SRT\SRT\c\runtime\src\iree\hal\drivers\vulkan\native_allocator.cc:315: RESOURCE_EXHAUSTED; VK_ERROR_OUT_OF_HOST_MEMORY; vkAllocateMemory; while invoking native function hal.device.queue.alloca; while calling import;
[ 1]   native hal.device.queue.alloca:0 -
[ 0] bytecode module.forward:372 conv.mlir:20:10
      at conv.mlir:5:3; invoking function 'forward'

$ iree-compile --iree-hal-target-backends=llvm-cpu --iree-stream-resource-index-bits=64 --iree-vm-target-truncate-unsupported-floats --iree-vm-target-index-bits=64 --iree-input-type=none --iree-vulkan-target-triple=rdna3-7900-windows conv.mlir -o=conv.vmfb

$ iree-benchmark-module --device=local-sync --module=conv.vmfb --function=forward --input=1x3x32x32xf32
2023-11-13T10:13:28-05:00
Running C:\Users\haranand\AppData\Local\hatch\env\virtual\whisper-shark\lCn2R_CD\whisper-shark\Lib\site-packages\iree\_runtime_libs\iree-benchmark-module
Run on (24 X 4717.96 MHz CPU s)
CPU Caches:
  L1 Data 32 KiB (x12)
  L1 Instruction 32 KiB (x12)
  L2 Unified 1024 KiB (x12)
  L3 Unified 32768 KiB (x2)
--------------------------------------------------------------------------------------------
Benchmark                                  Time             CPU   Iterations UserCounters...
--------------------------------------------------------------------------------------------
BM_forward/process_time/real_time      0.005 ms        0.006 ms       126965 items_per_second=184.584k/s

iree-compile-outputs-and-vmfb.zip @qedawkins

OK, I'm able to run your vmfb on my side with the same command, and the VM IR looks largely the same as what I'm seeing. I guess I'll have to set up on Windows, so this might take a bit longer to reproduce.

Closing this; it's now working on the remote PC I have. I removed and reseated the GPU in its PCI slot and connected a monitor to it. Not really sure what happened :/, but it's working now.
Thanks all! cc: @qedawkins