iree-org / iree

A retargetable MLIR-based machine learning compiler and runtime toolkit.

Home Page: http://iree.dev/


Vulkan compile errors for llama model from sharktank

ScottTodd opened this issue

What happened?

I'm following the examples in sharktank to export the f16 GGUF file from https://huggingface.co/SlyEcho/open_llama_3b_v2_gguf. When I try to compile for CPU (--iree-hal-target-backends=llvm-cpu) I hit #17244; for Vulkan (--iree-hal-target-backends=vulkan-spirv) I hit the errors below.

Steps to reproduce your issue

  1. Download open_llama_3b_v2_f16.mlir: https://sharkpublic.blob.core.windows.net/sharkpublic/scotttodd/issue_reports/open_llama_3b_v2_f16.mlir (or re-export with https://github.com/nod-ai/sharktank/tree/main/sharktank#examples)
  2. Run iree-compile open_llama_3b_v2_f16.mlir --iree-hal-target-backends=vulkan-spirv -o /tmp/open_llama_3b_v2_f16_vulkan.vmfb
  3. Observe errors:
error: 'spirv.IAdd' op operand #0 must be 8/16/32/64-bit integer or
vector of 8/16/32/64-bit integer values of length 2/3/4/8/16 or
Cooperative Matrix of 8/16/32/64-bit integer values, but got 'i1'
              %21 = arith.addi %17, %20 : i1
             ^
open_llama_3b_v2_f16_vulkan\configured_module_prefill_bs4$async_dispatch_1.mlir:9:6:
error: 'func.func' op uses -127270912 bytes of shared memory; exceeded the limit of 16384 bytes
      func.func @prefill_bs4$async_dispatch_1_generic_4xDx3200_i64xf32() attributes {translation_info = #iree_codegen.translation_info<SPIRVBaseDistribute workgroup_size = [64, 1, 1]>} {
     ^
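A note on the first error: spirv.IAdd has no i1 form, so an i1 arith.addi cannot lower directly to SPIR-V. Since i1 addition is addition modulo 2 (boolean exclusive or), a hypothetical semantics-preserving rewrite the SPIR-V conversion would need looks like this (a sketch, not actual converter output):

```mlir
// Rejected: 'spirv.IAdd' does not accept i1 operands.
%sum = arith.addi %a, %b : i1
// Equivalent boolean form: addition mod 2 is exclusive or,
// which SPIR-V expresses as a logical not-equal on booleans.
%sum_xor = spirv.LogicalNotEqual %a, %b : i1
```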

Full errors:



module @module {
^
failed to translate executables
failed to translate executables
D:\dev\projects\iree-tmp\2024_05_llms\open_llama_3b_v2_f16_vulkan\configured_module_prefill_bs4$async_dispatch_0.mlir:36:10: error: 'spirv.IAdd' op operand #0 must be 8/16/32/64-bit integer or vector of 8/16/32/64-bit integer values of length 2/3/4/8/16 or Cooperative Matrix of 8/16/32/64-bit integer values, but got 'i1'
          %21 = arith.addi %17, %20 : i1
         ^
D:\dev\projects\iree-tmp\2024_05_llms\open_llama_3b_v2_f16_vulkan\configured_module_prefill_bs4$async_dispatch_0.mlir:36:10: note: see current operation: %90 = "spirv.IAdd"(%78, %89) : (i1, i1) -> i1
D:\dev\projects\iree-tmp\2024_05_llms\open_llama_3b_v2_f16_vulkan\configured_module_prefill_bs4$async_dispatch_0.mlir:2:2: error: failed to run translation of source executable to target executable for backend #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb", {spirv.target_env = #spirv.target_env<#spirv.vce, api=Vulkan, #spirv.resource_limits>}>
  hal.executable.variant public @vulkan_spirv_fb target(<"vulkan-spirv", "vulkan-spirv-fb", {spirv.target_env = #spirv.target_env<#spirv.vce, api=Vulkan, #spirv.resource_limits>}>) {
 ^
D:\dev\projects\iree-tmp\2024_05_llms\open_llama_3b_v2_f16_vulkan\configured_module_prefill_bs4$async_dispatch_0.mlir:2:2: note: see current operation:
"hal.executable.variant"() ({
  "hal.executable.export"() ({
  ^bb0(%arg0: !hal.device, %arg1: index):
    %0 = "arith.constant"() <{value = 1 : index}> : () -> index
    %1 = "affine.apply"(%arg1) <{map = affine_map<()[s0] -> (s0 ceildiv 64)>}> : (index) -> index
    "hal.return"(%1, %arg1, %0) : (index, index, index) -> ()
  }) {hal.interface.bindings = [#hal.interface.binding<0, 0>, #hal.interface.binding<0, 1>, #hal.interface.binding<0, 2>, #hal.interface.binding<0, 3>], layout = #hal.pipeline.layout, <1, storage_buffer, ReadOnly>, <2, storage_buffer, ReadOnly>, <3, storage_buffer>]>]>, ordinal = 0 : index, sym_name = "prefill_bs4$async_dispatch_0_conv_4xDxD_i1xi64xf32xf32", workgroup_size = [64 : index, 1 : index, 1 : index]} : () -> ()
  "builtin.module"() ({
    "spirv.module"() <{addressing_model = #spirv.addressing_model, memory_model = #spirv.memory_model}> ({
      "spirv.GlobalVariable"() <{sym_name = "__builtin__LocalInvocationId__", type = !spirv.ptr, Input>}> {built_in = "LocalInvocationId"} : () -> ()
      "spirv.GlobalVariable"() <{sym_name = "__builtin__WorkgroupId__", type = !spirv.ptr, Input>}> {built_in = "WorkgroupId"} : () -> ()
      "spirv.GlobalVariable"() <{sym_name = "__push_constant_var__", type = !spirv.ptr [0])>, PushConstant>}> : () -> ()
      "spirv.GlobalVariable"() <{binding = 1 : i32, descriptor_set = 0 : i32, sym_name = "__resource_var_0_1_", type = !spirv.ptr, stride=8> [0])>, StorageBuffer>}> : () -> ()
      "spirv.GlobalVariable"() <{binding = 2 : i32, descriptor_set = 0 : i32, sym_name = "__resource_var_0_2_", type = !spirv.ptr [0])>, StorageBuffer>}> : () -> ()
      "spirv.GlobalVariable"() <{binding = 0 : i32, descriptor_set = 0 : i32, sym_name = "__resource_var_0_0_", type = !spirv.ptr [0])>, StorageBuffer>}> : () -> ()
      "spirv.GlobalVariable"() <{binding = 3 : i32, descriptor_set = 0 : i32, sym_name = "__resource_var_0_3_", type = !spirv.ptr [0])>, StorageBuffer>}> : () -> ()
      "spirv.func"() <{function_control = #spirv.function_control, function_type = () -> (), sym_name = "prefill_bs4$async_dispatch_0_conv_4xDxD_i1xi64xf32xf32"}> ({
        %0 = "spirv.Constant"() <{value = 0 : i32}> : () -> i32
        %1 = "spirv.Constant"() <{value = -64 : i32}> : () -> i32
        %2 = "spirv.Constant"() <{value = 64 : i32}> : () -> i32
        %3 = "spirv.Constant"() <{value = 819200 : i32}> : () -> i32
        %4 = "spirv.Constant"() <{value = 2048 : i32}> : () -> i32
        %5 = "spirv.Constant"() <{value = 4 : i32}> : () -> i32
        %6 = "spirv.Constant"() <{value = 0 : i32}> : () -> i32
        %7 = "spirv.Constant"() <{value = 0xFF800000 : f32}> : () -> f32
        %8 = "spirv.Constant"() <{value = 0 : i32}> : () -> i32
        %9 = "spirv.Constant"() <{value = 0 : i32}> : () -> i32
        %10 = "spirv.mlir.addressof"() <{variable = @__push_constant_var__}> : () -> !spirv.ptr [0])>, PushConstant>
        %11 = "spirv.AccessChain"(%10, %8, %9) : (!spirv.ptr [0])>, PushConstant>, i32, i32) -> !spirv.ptr
        %12 = "spirv.Load"(%11) : (!spirv.ptr) -> i32
        %13 = "spirv.mlir.addressof"() <{variable = @__resource_var_0_1_}> : () -> !spirv.ptr, stride=8> [0])>, StorageBuffer>
        %14 = "spirv.mlir.addressof"() <{variable = @__resource_var_0_2_}> : () -> !spirv.ptr [0])>, StorageBuffer>
        %15 = "spirv.IMul"(%12, %4) : (i32, i32) -> i32
        %16 = "spirv.IAdd"(%15, %3) : (i32, i32) -> i32
        %17 = "spirv.mlir.addressof"() <{variable = @__resource_var_0_0_}> : () -> !spirv.ptr [0])>, StorageBuffer>
        %18 = "spirv.IMul"(%12, %12) : (i32, i32) -> i32
        %19 = "spirv.IMul"(%18, %5) : (i32, i32) -> i32
        %20 = "spirv.mlir.addressof"() <{variable = @__resource_var_0_3_}> : () -> !spirv.ptr [0])>, StorageBuffer>
        %21 = "spirv.mlir.addressof"() <{variable = @__builtin__WorkgroupId__}> : () -> !spirv.ptr, Input>
        %22 = "spirv.Load"(%21) : (!spirv.ptr, Input>) -> vector<3xi32>
        %23 = "spirv.CompositeExtract"(%22) <{indices = [1 : i32]}> : (vector<3xi32>) -> i32
        %24 = "spirv.mlir.addressof"() <{variable = @__builtin__WorkgroupId__}> : () -> !spirv.ptr, Input>
        %25 = "spirv.Load"(%24) : (!spirv.ptr, Input>) -> vector<3xi32>
        %26 = "spirv.CompositeExtract"(%25) <{indices = [0 : i32]}> : (vector<3xi32>) -> i32
        %27 = "spirv.IMul"(%26, %1) : (i32, i32) -> i32
        %28 = "spirv.IAdd"(%12, %27) : (i32, i32) -> i32
        %29 = "spirv.GL.SMin"(%28, %2) : (i32, i32) -> i32
        %30 = "spirv.mlir.addressof"() <{variable = @__builtin__LocalInvocationId__}> : () -> !spirv.ptr, Input>
        %31 = "spirv.Load"(%30) : (!spirv.ptr, Input>) -> vector<3xi32>
        %32 = "spirv.CompositeExtract"(%31) <{indices = [0 : i32]}> : (vector<3xi32>) -> i32
        %33 = "spirv.Constant"() <{value = 64 : i32}> : () -> i32
        %34 = "spirv.mlir.addressof"() <{variable = @__builtin__LocalInvocationId__}> : () -> !spirv.ptr, Input>
        %35 = "spirv.Load"(%34) : (!spirv.ptr, Input>) -> vector<3xi32>
        %36 = "spirv.CompositeExtract"(%35) <{indices = [1 : i32]}> : (vector<3xi32>) -> i32
        %37 = "spirv.Constant"() <{value = 1 : i32}> : () -> i32
        "spirv.mlir.loop"() <{loop_control = #spirv.loop_control}> ({
          "spirv.Branch"(%36)[^bb1] : (i32) -> ()
        ^bb1(%38: i32):  // 2 preds: ^bb0, ^bb2
          %39 = "spirv.SLessThan"(%38, %5) : (i32, i32) -> i1
          "spirv.BranchConditional"(%39)[^bb2, ^bb3] <{operandSegmentSizes = array}> : (i1) -> ()
        ^bb2:  // pred: ^bb1
          "spirv.mlir.loop"() <{loop_control = #spirv.loop_control}> ({
            "spirv.Branch"(%32)[^bb1] : (i32) -> ()
          ^bb1(%41: i32):  // 2 preds: ^bb0, ^bb2
            %42 = "spirv.SLessThan"(%41, %29) : (i32, i32) -> i1
            "spirv.BranchConditional"(%42)[^bb2, ^bb3] <{operandSegmentSizes = array}> : (i1) -> ()
          ^bb2:  // pred: ^bb1
            %43 = "spirv.IMul"(%23, %4) : (i32, i32) -> i32
            %44 = "spirv.IAdd"(%43, %41) : (i32, i32) -> i32
            %45 = "spirv.IMul"(%26, %2) : (i32, i32) -> i32
            %46 = "spirv.IAdd"(%44, %45) : (i32, i32) -> i32
            %47 = "spirv.IAdd"(%46, %3) : (i32, i32) -> i32
            %48 = "spirv.Constant"() <{value = 0 : i32}> : () -> i32
            %49 = "spirv.Constant"() <{value = 0 : i32}> : () -> i32
            %50 = "spirv.Constant"() <{value = 1 : i32}> : () -> i32
            %51 = "spirv.Constant"() <{value = 4 : i32}> : () -> i32
            %52 = "spirv.SDiv"(%47, %51) : (i32, i32) -> i32
            %53 = "spirv.AccessChain"(%17, %48, %52) : (!spirv.ptr [0])>, StorageBuffer>, i32, i32) -> !spirv.ptr
            %54 = "spirv.Load"(%53) : (!spirv.ptr) -> i32
            %55 = "spirv.Constant"() <{value = 4 : i32}> : () -> i32
            %56 = "spirv.Constant"() <{value = 8 : i32}> : () -> i32
            %57 = "spirv.UMod"(%47, %55) : (i32, i32) -> i32
            %58 = "spirv.IMul"(%57, %56) : (i32, i32) -> i32
            %59 = "spirv.ShiftRightArithmetic"(%54, %58) : (i32, i32) -> i32
            %60 = "spirv.Constant"() <{value = 255 : i32}> : () -> i32
            %61 = "spirv.BitwiseAnd"(%59, %60) : (i32, i32) -> i32
            %62 = "spirv.Constant"() <{value = 24 : i32}> : () -> i32
            %63 = "spirv.ShiftLeftLogical"(%61, %62) : (i32, i32) -> i32
            %64 = "spirv.ShiftRightArithmetic"(%63, %62) : (i32, i32) -> i32
            %65 = "spirv.Constant"() <{value = 0 : i32}> : () -> i32
            %66 = "spirv.Constant"() <{value = 0 : i32}> : () -> i32
            %67 = "spirv.Constant"() <{value = 1 : i32}> : () -> i32
            %68 = "spirv.AccessChain"(%13, %65, %38) : (!spirv.ptr, stride=8> [0])>, StorageBuffer>, i32, i32) -> !spirv.ptr, StorageBuffer>
            %69 = "spirv.Load"(%68) : (!spirv.ptr, StorageBuffer>) -> vector<2xi32>
            %70 = "spirv.Constant"() <{value = 0 : i32}> : () -> i32
            %71 = "spirv.AccessChain"(%14, %70, %70) : (!spirv.ptr [0])>, StorageBuffer>, i32, i32) -> !spirv.ptr
            %72 = "spirv.Load"(%71) : (!spirv.ptr) -> f32
            %73 = "spirv.Constant"() <{value = 1 : i32}> : () -> i32
            %74 = "spirv.BitwiseAnd"(%64, %73) : (i32, i32) -> i32
            %75 = "spirv.IEqual"(%74, %73) : (i32, i32) -> i1
            %76 = "spirv.Constant"() <{value = false}> : () -> i1
            %77 = "spirv.Constant"() <{value = true}> : () -> i1
            %78 = "spirv.Select"(%75, %77, %76) : (i1, i1, i1) -> i1
            %79 = "spirv.IAdd"(%41, %45) : (i32, i32) -> i32
            %80 = "spirv.SLessThan"(%79, %0) : (i32, i32) -> i1
            %81 = "spirv.Constant"() <{value = -1 : i32}> : () -> i32
            %82 = "spirv.Constant"() <{value = 0 : i32}> : () -> i32
            %83 = "spirv.Select"(%80, %81, %82) : (i1, i32, i32) -> i32
            %84 = "spirv.CompositeExtract"(%69) <{indices = [0 : i32]}> : (vector<2xi32>) -> i32
            %85 = "spirv.CompositeExtract"(%69) <{indices = [1 : i32]}> : (vector<2xi32>) -> i32
            %86 = "spirv.UGreaterThanEqual"(%79, %84) : (i32, i32) -> i1
            %87 = "spirv.SGreaterThanEqual"(%83, %85) : (i32, i32) -> i1
            %88 = "spirv.IEqual"(%83, %85) : (i32, i32) -> i1
            %89 = "spirv.Select"(%88, %86, %87) : (i1, i1, i1) -> i1
            %90 = "spirv.IAdd"(%78, %89) : (i1, i1) -> i1
            %91 = "spirv.Select"(%90, %7, %72) : (i1, f32, f32) -> f32
            %92 = "spirv.IMul"(%38, %12) : (i32, i32) -> i32
            %93 = "spirv.IAdd"(%23, %92) : (i32, i32) -> i32
            %94 = "spirv.IMul"(%93, %12) : (i32, i32) -> i32
            %95 = "spirv.IAdd"(%79, %94) : (i32, i32) -> i32
            %96 = "spirv.Constant"() <{value = 0 : i32}> : () -> i32
            %97 = "spirv.Constant"() <{value = 0 : i32}> : () -> i32
            %98 = "spirv.Constant"() <{value = 1 : i32}> : () -> i32
            %99 = "spirv.AccessChain"(%20, %96, %95) : (!spirv.ptr [0])>, StorageBuffer>, i32, i32) -> !spirv.ptr
            "spirv.Store"(%99, %91) : (!spirv.ptr, f32) -> ()
            %100 = "spirv.IAdd"(%41, %33) : (i32, i32) -> i32
            "spirv.Branch"(%100)[^bb1] : (i32) -> ()
          ^bb3:  // pred: ^bb1
            "spirv.mlir.merge"() : () -> ()
          }) : () -> ()
          %40 = "spirv.IAdd"(%38, %37) : (i32, i32) -> i32
          "spirv.Branch"(%40)[^bb1] : (i32) -> ()
        ^bb3:  // pred: ^bb1
          "spirv.mlir.merge"() : () -> ()
        }) : () -> ()
        "spirv.Return"() : () -> ()
      }) {spirv.entry_point_abi = #spirv.entry_point_abi} : () -> ()
    }) : () -> ()
  }) {spirv.target_env = #spirv.target_env<#spirv.vce, api=Vulkan, #spirv.resource_limits>} : () -> ()
  "hal.executable.variant_end"() : () -> ()
}) {sym_name = "vulkan_spirv_fb", target = #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb", {spirv.target_env = #spirv.target_env<#spirv.vce, api=Vulkan, #spirv.resource_limits>}>} : () -> ()
D:\dev\projects\iree-tmp\2024_05_llms\open_llama_3b_v2_f16_vulkan\configured_module_prefill_bs4$async_dispatch_1.mlir:9:6: error: 'func.func' op uses -127270912 bytes of shared memory; exceeded the limit of 16384 bytes
      func.func @prefill_bs4$async_dispatch_1_generic_4xDx3200_i64xf32() attributes {translation_info = #iree_codegen.translation_info<SPIRVBaseDistribute workgroup_size = [64, 1, 1]>} {
     ^
D:\dev\projects\iree-tmp\2024_05_llms\open_llama_3b_v2_f16_vulkan\configured_module_prefill_bs4$async_dispatch_1.mlir:9:6: note: see current operation:
"func.func"() <{function_type = () -> (), sym_name = "prefill_bs4$async_dispatch_1_generic_4xDx3200_i64xf32"}> ({
  %0 = "arith.constant"() <{value = 3200 : index}> : () -> index
  %1 = "arith.constant"() <{value = 32000 : index}> : () -> index
  %2 = "arith.constant"() <{value = 0 : index}> : () -> index
  %3 = "arith.constant"() <{value = 32 : i64}> : () -> i64
  %4 = "hal.interface.constant.load"() {index = 0 : index} : () -> i32
  %5 = "hal.interface.constant.load"() {index = 1 : index} : () -> i32
  %6 = "hal.interface.constant.load"() {index = 2 : index} : () -> i32
  %7 = "hal.interface.constant.load"() {index = 3 : index} : () -> i32
  %8 = "arith.extui"(%5) : (i32) -> i64
  %9 = "arith.shli"(%8, %3) <{overflowFlags = #arith.overflow}> : (i64, i64) -> i64
  %10 = "arith.extui"(%4) : (i32) -> i64
  %11 = "arith.ori"(%10, %9) : (i64, i64) -> i64
  %12 = "arith.index_castui"(%11) {stream.alignment = 64 : index} : (i64) -> index
  %13 = "arith.extui"(%7) : (i32) -> i64
  %14 = "arith.shli"(%13, %3) <{overflowFlags = #arith.overflow}> : (i64, i64) -> i64
  %15 = "arith.extui"(%6) : (i32) -> i64
  %16 = "arith.ori"(%15, %14) : (i64, i64) -> i64
  %17 = "arith.index_castui"(%16) : (i64) -> index
  %18 = "hal.interface.binding.subspan"(%2) {alignment = 64 : index, binding = 0 : index, descriptor_flags = 1 : i32, descriptor_type = #hal.descriptor_type, operandSegmentSizes = array, set = 0 : index} : (index) -> memref<32000x3200xf16, #hal.descriptor_type>
  "memref.assume_alignment"(%18) <{alignment = 64 : i32}> : (memref<32000x3200xf16, #hal.descriptor_type>) -> ()
  %19 = "hal.interface.binding.subspan"(%2, %17) {alignment = 64 : index, binding = 1 : index, descriptor_flags = 1 : i32, descriptor_type = #hal.descriptor_type, operandSegmentSizes = array, set = 0 : index} : (index, index) -> memref<4x?xi64, #hal.descriptor_type>
  "memref.assume_alignment"(%19) <{alignment = 64 : i32}> : (memref<4x?xi64, #hal.descriptor_type>) -> ()
  %20 = "hal.interface.binding.subspan"(%12, %17) {alignment = 64 : index, binding = 2 : index, descriptor_type = #hal.descriptor_type, operandSegmentSizes = array, set = 0 : index} : (index, index) -> memref<4x?x3200xf32, strided<[?, 3200, 1], offset: ?>, #hal.descriptor_type>
  "memref.assume_alignment"(%20) <{alignment = 64 : i32}> : (memref<4x?x3200xf32, strided<[?, 3200, 1], offset: ?>, #hal.descriptor_type>) -> ()
  %21 = "hal.interface.workgroup.id"() {dimension = 2 : index} : () -> index
  %22 = "hal.interface.workgroup.id"() {dimension = 1 : index} : () -> index
  %23 = "hal.interface.workgroup.id"() {dimension = 0 : index} : () -> index
  %24 = "affine.apply"(%23) <{map = affine_map<()[s0] -> (s0 * 64)>}> : (index) -> index
  %25 = "memref.subview"(%20, %21, %22, %24) <{operandSegmentSizes = array, static_offsets = array, static_sizes = array, static_strides = array}> : (memref<4x?x3200xf32, strided<[?, 3200, 1], offset: ?>, #hal.descriptor_type>, index, index, index) -> memref<1x1x64xf32, strided<[?, 3200, 1], offset: ?>, #hal.descriptor_type>
  %26 = "memref.alloc"() <{alignment = 64 : i64, operandSegmentSizes = array}> : () -> memref<32000x3200xf32, #gpu.address_space>
  %27 = "gpu.thread_id"() <{dimension = #gpu}> : () -> index
  %28 = "gpu.block_dim"() <{dimension = #gpu}> : () -> index
  %29 = "gpu.thread_id"() <{dimension = #gpu}> : () -> index
  %30 = "gpu.block_dim"() <{dimension = #gpu}> : () -> index
  "scf.for"(%29, %1, %30) ({
  ^bb0(%arg0: index):
    "scf.for"(%27, %0, %28) ({
    ^bb0(%arg1: index):
      %37 = "memref.subview"(%18, %arg0, %arg1) <{operandSegmentSizes = array, static_offsets = array, static_sizes = array, static_strides = array}> : (memref<32000x3200xf16, #hal.descriptor_type>, index, index) -> memref<1x1xf16, strided<[3200, 1], offset: ?>, #hal.descriptor_type>
      %38 = "memref.subview"(%26, %arg0, %arg1) <{operandSegmentSizes = array, static_offsets = array, static_sizes = array, static_strides = array}> : (memref<32000x3200xf32, #gpu.address_space>, index, index) -> memref<1x1xf32, strided<[3200, 1], offset: ?>, #gpu.address_space>
      %39 = "memref.load"(%37, %2, %2) <{nontemporal = false}> : (memref<1x1xf16, strided<[3200, 1], offset: ?>, #hal.descriptor_type>, index, index) -> f16
      %40 = "arith.extf"(%39) : (f16) -> f32
      "memref.store"(%40, %38, %2, %2) <{nontemporal = false}> : (f32, memref<1x1xf32, strided<[3200, 1], offset: ?>, #gpu.address_space>, index, index) -> ()
      "scf.yield"() : () -> ()
    }) : (index, index, index) -> ()
    "scf.yield"() : () -> ()
  }) : (index, index, index) -> ()
  %31 = "memref.subview"(%19, %21, %22) <{operandSegmentSizes = array, static_offsets = array, static_sizes = array, static_strides = array}> : (memref<4x?xi64, #hal.descriptor_type>, index, index) -> memref<1x1xi64, strided<[?, 1], offset: ?>, #hal.descriptor_type>
  %32 = "memref.subview"(%25, %27) <{operandSegmentSizes = array, static_offsets = array, static_sizes = array, static_strides = array}> : (memref<1x1x64xf32, strided<[?, 3200, 1], offset: ?>, #hal.descriptor_type>, index) -> memref<1x1x1xf32, strided<[?, 3200, 1], offset: ?>, #hal.descriptor_type>
  %33 = "memref.load"(%31, %2, %2) <{nontemporal = false}> : (memref<1x1xi64, strided<[?, 1], offset: ?>, #hal.descriptor_type>, index, index) -> i64
  %34 = "arith.index_cast"(%33) : (i64) -> index
  %35 = "affine.apply"(%23, %27) <{map = affine_map<()[s0, s1] -> (s0 * 64 + s1)>}> : (index, index) -> index
  %36 = "memref.load"(%26, %34, %35) <{nontemporal = false}> : (memref<32000x3200xf32, #gpu.address_space>, index, index) -> f32
  "memref.store"(%36, %32, %2, %2, %2) <{nontemporal = false}> : (f32, memref<1x1x1xf32, strided<[?, 3200, 1], offset: ?>, #hal.descriptor_type>, index, index, index) -> ()
  "func.return"() : () -> ()
}) {translation_info = #iree_codegen.translation_info} : () -> ()
D:\dev\projects\iree-tmp\2024_05_llms\open_llama_3b_v2_f16_vulkan\configured_module_prefill_bs4$async_dispatch_1.mlir:2:2: error: failed to run translation of source executable to target executable for backend #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb", {spirv.target_env = #spirv.target_env<#spirv.vce, api=Vulkan, #spirv.resource_limits>}>
  hal.executable.variant public @vulkan_spirv_fb target(<"vulkan-spirv", "vulkan-spirv-fb", {spirv.target_env = #spirv.target_env<#spirv.vce, api=Vulkan, #spirv.resource_limits>}>) {
 ^
D:\dev\projects\iree-tmp\2024_05_llms\open_llama_3b_v2_f16_vulkan\configured_module_prefill_bs4$async_dispatch_1.mlir:2:2: note: see current operation:
"hal.executable.variant"() ({
  "hal.executable.export"() ({
  ^bb0(%arg0: !hal.device, %arg1: index):
    %0 = "arith.constant"() <{value = 4 : index}> : () -> index
    %1 = "arith.constant"() <{value = 50 : index}> : () -> index
    "hal.return"(%1, %arg1, %0) : (index, index, index) -> ()
  }) {hal.interface.bindings = [#hal.interface.binding<0, 0>, #hal.interface.binding<0, 1>, #hal.interface.binding<0, 2>], layout = #hal.pipeline.layout, <1, storage_buffer, ReadOnly>, <2, storage_buffer>]>]>, ordinal = 0 : index, sym_name = "prefill_bs4$async_dispatch_1_generic_4xDx3200_i64xf32"} : () -> ()
  "builtin.module"() ({
    "func.func"() <{function_type = () -> (), sym_name = "prefill_bs4$async_dispatch_1_generic_4xDx3200_i64xf32"}> ({
      %0 = "arith.constant"() <{value = 3200 : index}> : () -> index
      %1 = "arith.constant"() <{value = 32000 : index}> : () -> index
      %2 = "arith.constant"() <{value = 0 : index}> : () -> index
      %3 = "arith.constant"() <{value = 32 : i64}> : () -> i64
      %4 = "hal.interface.constant.load"() {index = 0 : index} : () -> i32
      %5 = "hal.interface.constant.load"() {index = 1 : index} : () -> i32
      %6 = "hal.interface.constant.load"() {index = 2 : index} : () -> i32
      %7 = "hal.interface.constant.load"() {index = 3 : index} : () -> i32
      %8 = "arith.extui"(%5) : (i32) -> i64
      %9 = "arith.shli"(%8, %3) <{overflowFlags = #arith.overflow}> : (i64, i64) -> i64
      %10 = "arith.extui"(%4) : (i32) -> i64
      %11 = "arith.ori"(%10, %9) : (i64, i64) -> i64
      %12 = "arith.index_castui"(%11) {stream.alignment = 64 : index} : (i64) -> index
      %13 = "arith.extui"(%7) : (i32) -> i64
      %14 = "arith.shli"(%13, %3) <{overflowFlags = #arith.overflow}> : (i64, i64) -> i64
      %15 = "arith.extui"(%6) : (i32) -> i64
      %16 = "arith.ori"(%15, %14) : (i64, i64) -> i64
      %17 = "arith.index_castui"(%16) : (i64) -> index
      %18 = "hal.interface.binding.subspan"(%2) {alignment = 64 : index, binding = 0 : index, descriptor_flags = 1 : i32, descriptor_type = #hal.descriptor_type, operandSegmentSizes = array, set = 0 : index} : (index) -> memref<32000x3200xf16, #hal.descriptor_type>
      "memref.assume_alignment"(%18) <{alignment = 64 : i32}> : (memref<32000x3200xf16, #hal.descriptor_type>) -> ()
      %19 = "hal.interface.binding.subspan"(%2, %17) {alignment = 64 : index, binding = 1 : index, descriptor_flags = 1 : i32, descriptor_type = #hal.descriptor_type, operandSegmentSizes = array, set = 0 : index} : (index, index) -> memref<4x?xi64, #hal.descriptor_type>
      "memref.assume_alignment"(%19) <{alignment = 64 : i32}> : (memref<4x?xi64, #hal.descriptor_type>) -> ()
      %20 = "hal.interface.binding.subspan"(%12, %17) {alignment = 64 : index, binding = 2 : index, descriptor_type = #hal.descriptor_type, operandSegmentSizes = array, set = 0 : index} : (index, index) -> memref<4x?x3200xf32, strided<[?, 3200, 1], offset: ?>, #hal.descriptor_type>
      "memref.assume_alignment"(%20) <{alignment = 64 : i32}> : (memref<4x?x3200xf32, strided<[?, 3200, 1], offset: ?>, #hal.descriptor_type>) -> ()
      %21 = "hal.interface.workgroup.id"() {dimension = 2 : index} : () -> index
      %22 = "hal.interface.workgroup.id"() {dimension = 1 : index} : () -> index
      %23 = "hal.interface.workgroup.id"() {dimension = 0 : index} : () -> index
      %24 = "affine.apply"(%23) <{map = affine_map<()[s0] -> (s0 * 64)>}> : (index) -> index
      %25 = "memref.subview"(%20, %21, %22, %24) <{operandSegmentSizes = array, static_offsets = array, static_sizes = array, static_strides = array}> : (memref<4x?x3200xf32, strided<[?, 3200, 1], offset: ?>, #hal.descriptor_type>, index, index, index) -> memref<1x1x64xf32, strided<[?, 3200, 1], offset: ?>, #hal.descriptor_type>
      %26 = "memref.alloc"() <{alignment = 64 : i64, operandSegmentSizes = array}> : () -> memref<32000x3200xf32, #gpu.address_space>
      %27 = "gpu.thread_id"() <{dimension = #gpu}> : () -> index
      %28 = "gpu.block_dim"() <{dimension = #gpu}> : () -> index
      %29 = "gpu.thread_id"() <{dimension = #gpu}> : () -> index
      %30 = "gpu.block_dim"() <{dimension = #gpu}> : () -> index
      "scf.for"(%29, %1, %30) ({
      ^bb0(%arg0: index):
        "scf.for"(%27, %0, %28) ({
        ^bb0(%arg1: index):
          %37 = "memref.subview"(%18, %arg0, %arg1) <{operandSegmentSizes = array, static_offsets = array, static_sizes = array, static_strides = array}> : (memref<32000x3200xf16, #hal.descriptor_type>, index, index) -> memref<1x1xf16, strided<[3200, 1], offset: ?>, #hal.descriptor_type>
          %38 = "memref.subview"(%26, %arg0, %arg1) <{operandSegmentSizes = array, static_offsets = array, static_sizes = array, static_strides = array}> : (memref<32000x3200xf32, #gpu.address_space>, index, index) -> memref<1x1xf32, strided<[3200, 1], offset: ?>, #gpu.address_space>
          %39 = "memref.load"(%37, %2, %2) <{nontemporal = false}> : (memref<1x1xf16, strided<[3200, 1], offset: ?>, #hal.descriptor_type>, index, index) -> f16
          %40 = "arith.extf"(%39) : (f16) -> f32
          "memref.store"(%40, %38, %2, %2) <{nontemporal = false}> : (f32, memref<1x1xf32, strided<[3200, 1], offset: ?>, #gpu.address_space>, index, index) -> ()
          "scf.yield"() : () -> ()
        }) : (index, index, index) -> ()
        "scf.yield"() : () -> ()
      }) : (index, index, index) -> ()
      %31 = "memref.subview"(%19, %21, %22) <{operandSegmentSizes = array, static_offsets = array, static_sizes = array, static_strides = array}> : (memref<4x?xi64, #hal.descriptor_type>, index, index) -> memref<1x1xi64, strided<[?, 1], offset: ?>, #hal.descriptor_type>
      %32 = "memref.subview"(%25, %27) <{operandSegmentSizes = array, static_offsets = array, static_sizes = array, static_strides = array}> : (memref<1x1x64xf32, strided<[?, 3200, 1], offset: ?>, #hal.descriptor_type>, index) -> memref<1x1x1xf32, strided<[?, 3200, 1], offset: ?>, #hal.descriptor_type>
      %33 = "memref.load"(%31, %2, %2) <{nontemporal = false}> : (memref<1x1xi64, strided<[?, 1], offset: ?>, #hal.descriptor_type>, index, index) -> i64
      %34 = "arith.index_cast"(%33) : (i64) -> index
      %35 = "affine.apply"(%23, %27) <{map = affine_map<()[s0, s1] -> (s0 * 64 + s1)>}> : (index, index) -> index
      %36 = "memref.load"(%26, %34, %35) <{nontemporal = false}> : (memref<32000x3200xf32, #gpu.address_space>, index, index) -> f32
      "memref.store"(%36, %32, %2, %2, %2) <{nontemporal = false}> : (f32, memref<1x1x1xf32, strided<[?, 3200, 1], offset: ?>, #hal.descriptor_type>, index, index, index) -> ()
      "func.return"() : () -> ()
    }) {translation_info = #iree_codegen.translation_info} : () -> ()
  }) : () -> ()
  "hal.executable.variant_end"() : () -> ()
}) {sym_name = "vulkan_spirv_fb", target = #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb", {spirv.target_env = #spirv.target_env<#spirv.vce, api=Vulkan, #spirv.resource_limits>}>} : () -> ()

What component(s) does this issue relate to?

Compiler

Version information

a075013

Additional context

I might be able to work around the shared memory issue by setting --iree-vulkan-target-triple to match my GPU, but we need a much better story with default flags. I haven't seen the spirv.IAdd issue with i1 before.
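For scale: the full dump shows the dispatch promoting the entire 32000x3200xf32 table into workgroup memory (the memref.alloc in #gpu.address_space above), which no shared-memory limit can hold, and the negative byte count in the diagnostic suggests the size computation itself overflows somewhere. A rough back-of-the-envelope check (assuming 4 bytes per f32 element):

```python
# Size of the workgroup allocation seen in the dump:
# memref<32000x3200xf32> promoted to shared memory.
rows, cols, bytes_per_f32 = 32000, 3200, 4
alloc_bytes = rows * cols * bytes_per_f32
print(alloc_bytes)                      # 409600000 bytes (~390 MiB)
print(alloc_bytes > 16384)              # True: far over the 16 KiB default limit
```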

Tried with --iree-vulkan-target-triple=turing-unknown-unknown and got a similar error:

open_llama_3b_v2_f16_vulkan\configured_module_prefill_bs4$async_dispatch_1.mlir:9:6: error:
'func.func' op uses -127270912 bytes of shared memory; exceeded the limit of 49152 bytes
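For reference, the full invocation for that attempt would look like this (assuming the same input file and output path as in step 2 above):

```shell
# Workaround attempt: pick a Vulkan target triple matching the GPU so the
# shared memory limit reflects the device (49152 bytes on Turing).
iree-compile open_llama_3b_v2_f16.mlir \
  --iree-hal-target-backends=vulkan-spirv \
  --iree-vulkan-target-triple=turing-unknown-unknown \
  -o /tmp/open_llama_3b_v2_f16_vulkan.vmfb
```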

The source for configured_module_prefill_bs4$async_dispatch_1.mlir is:

hal.executable public @prefill_bs4$async_dispatch_1 {
  hal.executable.variant public @vulkan_spirv_fb target(<"vulkan-spirv", "vulkan-spirv-fb", {spirv.target_env = #spirv.target_env<#spirv.vce<v1.6, [Shader, Float64, Float16, Int64, Int16, Int8, StorageBuffer16BitAccess, StorageUniform16, StoragePushConstant16, StorageBuffer8BitAccess, UniformAndStorageBuffer8BitAccess, StoragePushConstant8, GroupNonUniform, GroupNonUniformVote, GroupNonUniformArithmetic, GroupNonUniformBallot, GroupNonUniformShuffle, GroupNonUniformShuffleRelative, GroupNonUniformClustered, GroupNonUniformQuad, PhysicalStorageBufferAddresses, VariablePointers, VariablePointersStorageBuffer, DotProduct, DotProductInputAll, DotProductInput4x8BitPacked, DotProductInput4x8Bit, CooperativeMatrixKHR], [SPV_KHR_16bit_storage, SPV_KHR_8bit_storage, SPV_KHR_integer_dot_product, SPV_KHR_storage_buffer_storage_class, SPV_KHR_physical_storage_buffer, SPV_KHR_variable_pointers, SPV_KHR_cooperative_matrix]>, api=Vulkan, NVIDIA:DiscreteGPU, #spirv.resource_limits<max_compute_shared_memory_size = 49152, max_compute_workgroup_invocations = 1024, max_compute_workgroup_size = [1024, 1024, 64], min_subgroup_size = 32, max_subgroup_size = 32, cooperative_matrix_properties_khr = [#spirv.coop_matrix_props_khr<m_size = 8, n_size = 8, k_size = 32, a_type = i8, b_type = i8, c_type = i32, result_type = i32, acc_sat = false, scope = <Subgroup>>, #spirv.coop_matrix_props_khr<m_size = 16, n_size = 16, k_size = 16, a_type = f16, b_type = f16, c_type = f16, result_type = f16, acc_sat = false, scope = <Subgroup>>, #spirv.coop_matrix_props_khr<m_size = 16, n_size = 16, k_size = 16, a_type = f16, b_type = f16, c_type = f32, result_type = f32, acc_sat = false, scope = <Subgroup>>]>>}>) {
    hal.executable.export public @prefill_bs4$async_dispatch_1_generic_4xDx3200_i64xf32 ordinal(0) layout(#hal.pipeline.layout<push_constants = 4, sets = [<0, bindings = [<0, storage_buffer, ReadOnly>, <1, storage_buffer, ReadOnly>, <2, storage_buffer>]>]>) attributes {hal.interface.bindings = [#hal.interface.binding<0, 0>, #hal.interface.binding<0, 1>, #hal.interface.binding<0, 2>]} {
    ^bb0(%arg0: !hal.device, %arg1: index):
      %x, %y, %z = flow.dispatch.workgroup_count_from_slice %arg1
      hal.return %x, %y, %z : index, index, index
    }
    builtin.module {
      func.func @prefill_bs4$async_dispatch_1_generic_4xDx3200_i64xf32() attributes {translation_info = #iree_codegen.translation_info<SPIRVBaseDistribute workgroup_size = [32, 1, 1]>} {
        %c0 = arith.constant 0 : index
        %c32_i64 = arith.constant 32 : i64
        %0 = hal.interface.constant.load[0] : i32
        %1 = hal.interface.constant.load[1] : i32
        %2 = hal.interface.constant.load[2] : i32
        %3 = hal.interface.constant.load[3] : i32
        %4 = arith.extui %1 : i32 to i64
        %5 = arith.shli %4, %c32_i64 : i64
        %6 = arith.extui %0 : i32 to i64
        %7 = arith.ori %6, %5 : i64
        %8 = arith.index_castui %7 {stream.alignment = 64 : index} : i64 to index
        %9 = arith.extui %3 : i32 to i64
        %10 = arith.shli %9, %c32_i64 : i64
        %11 = arith.extui %2 : i32 to i64
        %12 = arith.ori %11, %10 : i64
        %13 = arith.index_castui %12 : i64 to index
        %14 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<32000x3200xf16>>
        %15 = flow.dispatch.workload.ordinal %13, 0 : index
        %16 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<4x?xi64>>{%15}
        %17 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%8) : !flow.dispatch.tensor<writeonly:tensor<4x?x3200xf32>>{%15}
        %18 = flow.dispatch.tensor.load %14, offsets = [0, 0], sizes = [32000, 3200], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<32000x3200xf16>> -> tensor<32000x3200xf16>
        %19 = flow.dispatch.tensor.load %16, offsets = [0, 0], sizes = [4, %15], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<4x?xi64>>{%15} -> tensor<4x?xi64>
        %20 = tensor.empty(%15) : tensor<4x?x3200xf32>
        %21 = tensor.empty() : tensor<32000x3200xf32>
        %22 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%18 : tensor<32000x3200xf16>) outs(%21 : tensor<32000x3200xf32>) {
        ^bb0(%in: f16, %out: f32):
          %24 = arith.extf %in : f16 to f32
          linalg.yield %24 : f32
        } -> tensor<32000x3200xf32>
        %23 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%19 : tensor<4x?xi64>) outs(%20 : tensor<4x?x3200xf32>) attrs =  {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[1, 1, 32], [1, 1, 1]]>} {
        ^bb0(%in: i64, %out: f32):
          %24 = arith.index_cast %in : i64 to index
          %25 = linalg.index 2 : index
          %extracted = tensor.extract %22[%24, %25] : tensor<32000x3200xf32>
          linalg.yield %extracted : f32
        } -> tensor<4x?x3200xf32>
        flow.dispatch.tensor.store %23, %17, offsets = [0, 0, 0], sizes = [4, %15, 3200], strides = [1, 1, 1] : tensor<4x?x3200xf32> -> !flow.dispatch.tensor<writeonly:tensor<4x?x3200xf32>>{%15}
        return
      }
    }
  }
}

The configured_module_prefill_bs4$async_dispatch_1.mlir issue (exceeding the shared memory limit) may go away with llvm/torch-mlir#3277.

I still see the spirv.IAdd issue with configured_module_prefill_bs4$async_dispatch_0.mlir:

hal.executable public @prefill_bs4$async_dispatch_0 {
  hal.executable.variant public @vulkan_spirv_fb target(<"vulkan-spirv", "vulkan-spirv-fb", {spirv.target_env = #spirv.target_env<#spirv.vce<v1.6, [Shader, Float64, Float16, Int64, Int16, Int8, StorageBuffer16BitAccess, StorageUniform16, StoragePushConstant16, StorageBuffer8BitAccess, UniformAndStorageBuffer8BitAccess, StoragePushConstant8, GroupNonUniform, GroupNonUniformVote, GroupNonUniformArithmetic, GroupNonUniformBallot, GroupNonUniformShuffle, GroupNonUniformShuffleRelative, GroupNonUniformClustered, GroupNonUniformQuad, PhysicalStorageBufferAddresses, VariablePointers, VariablePointersStorageBuffer, DotProduct, DotProductInputAll, DotProductInput4x8BitPacked, DotProductInput4x8Bit, CooperativeMatrixKHR], [SPV_KHR_16bit_storage, SPV_KHR_8bit_storage, SPV_KHR_integer_dot_product, SPV_KHR_storage_buffer_storage_class, SPV_KHR_physical_storage_buffer, SPV_KHR_variable_pointers, SPV_KHR_cooperative_matrix]>, api=Vulkan, NVIDIA:DiscreteGPU, #spirv.resource_limits<max_compute_shared_memory_size = 49152, max_compute_workgroup_invocations = 1024, max_compute_workgroup_size = [1024, 1024, 64], min_subgroup_size = 32, max_subgroup_size = 32, cooperative_matrix_properties_khr = [#spirv.coop_matrix_props_khr<m_size = 8, n_size = 8, k_size = 32, a_type = i8, b_type = i8, c_type = i32, result_type = i32, acc_sat = false, scope = <Subgroup>>, #spirv.coop_matrix_props_khr<m_size = 16, n_size = 16, k_size = 16, a_type = f16, b_type = f16, c_type = f16, result_type = f16, acc_sat = false, scope = <Subgroup>>, #spirv.coop_matrix_props_khr<m_size = 16, n_size = 16, k_size = 16, a_type = f16, b_type = f16, c_type = f32, result_type = f32, acc_sat = false, scope = <Subgroup>>]>>}>) {
    hal.executable.export public @prefill_bs4$async_dispatch_0_conv_4xDxD_i1xi64xf32xf32 ordinal(0) layout(#hal.pipeline.layout<push_constants = 2, sets = [<0, bindings = [<0, storage_buffer, ReadOnly>, <1, storage_buffer, ReadOnly>, <2, storage_buffer, ReadOnly>, <3, storage_buffer>]>]>) attributes {hal.interface.bindings = [#hal.interface.binding<0, 0>, #hal.interface.binding<0, 1>, #hal.interface.binding<0, 2>, #hal.interface.binding<0, 3>]} {
    ^bb0(%arg0: !hal.device, %arg1: index):
      %x, %y, %z = flow.dispatch.workgroup_count_from_slice %arg1
      hal.return %x, %y, %z : index, index, index
    }
    builtin.module {
      func.func @prefill_bs4$async_dispatch_0_conv_4xDxD_i1xi64xf32xf32() attributes {translation_info = #iree_codegen.translation_info<SPIRVBaseDistribute workgroup_size = [32, 1, 1]>} {
        %c0 = arith.constant 0 : index
        %c819200 = arith.constant 819200 : index
        %c32_i64 = arith.constant 32 : i64
        %cst = arith.constant 0xFF800000 : f32
        %0 = hal.interface.constant.load[0] : i32
        %1 = hal.interface.constant.load[1] : i32
        %2 = arith.extui %1 : i32 to i64
        %3 = arith.shli %2, %c32_i64 : i64
        %4 = arith.extui %0 : i32 to i64
        %5 = arith.ori %4, %3 : i64
        %6 = arith.index_castui %5 : i64 to index
        %7 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<4xi64>>
        %8 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<f32>>
        %9 = flow.dispatch.workload.ordinal %6, 0 : index
        %10 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c819200) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<1x1x?x2048xi8>>{%9}
        %11 = hal.interface.binding.subspan set(0) binding(3) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<4x?x?xf32>>{%9, %9}
        %12 = flow.dispatch.tensor.load %7, offsets = [0], sizes = [4], strides = [1] : !flow.dispatch.tensor<readonly:tensor<4xi64>> -> tensor<4xi64>
        %13 = flow.dispatch.tensor.load %8, offsets = [], sizes = [], strides = [] : !flow.dispatch.tensor<readonly:tensor<f32>> -> tensor<f32>
        %14 = tensor.empty(%9, %9) : tensor<4x?x?xf32>
        %15 = flow.dispatch.tensor.load %10, offsets = [0, 0, 0, 0], sizes = [1, 1, %9, %9], strides = [1, 1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<1x1x?x2048xi8>>{%9} -> tensor<?x?xi8>
        %16 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0)>, affine_map<(d0, d1, d2) -> ()>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%15, %12, %13 : tensor<?x?xi8>, tensor<4xi64>, tensor<f32>) outs(%14 : tensor<4x?x?xf32>) attrs =  {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[32, 1, 32], [1, 1, 1]]>} {
        ^bb0(%in: i8, %in_0: i64, %in_1: f32, %out: f32):
          %17 = arith.trunci %in : i8 to i1
          %18 = linalg.index 2 : index
          %19 = arith.index_cast %18 : index to i64
          %20 = arith.cmpi sge, %19, %in_0 : i64
          %21 = arith.addi %17, %20 : i1
          %22 = arith.select %21, %cst, %in_1 : f32
          linalg.yield %22 : f32
        } -> tensor<4x?x?xf32>
        flow.dispatch.tensor.store %16, %11, offsets = [0, 0, 0], sizes = [4, %9, %9], strides = [1, 1, 1] : tensor<4x?x?xf32> -> !flow.dispatch.tensor<writeonly:tensor<4x?x?xf32>>{%9, %9}
        return
      }
    }
  }
}

Copying from nod-ai/sharktank#22 (comment):

For the vulkan-spirv backend, here's a minimal repro:

func.func @torch_add(%arg0: !torch.vtensor<[1,1,?,?],i1>, %arg1: !torch.vtensor<[4,1,1,?],i1>) -> !torch.vtensor<[4,1,?,?],i1> {
  %int1 = torch.constant.int 1
  %2 = torch.aten.add.Tensor %arg0, %arg1, %int1 : !torch.vtensor<[1,1,?,?],i1>, !torch.vtensor<[4,1,1,?],i1>, !torch.int -> !torch.vtensor<[4,1,?,?],i1>
  return %2 : !torch.vtensor<[4,1,?,?],i1>
}

error: 'spirv.IAdd' op operand #0 must be 8/16/32/64-bit integer but got 'i1'
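For context, the failing op in the dispatch is `%21 = arith.addi %17, %20 : i1`, where both operands are boolean masks feeding an `arith.select`. A possible workaround (an untested sketch, not a confirmed fix): assuming the intended torch-level semantics of adding two boolean tensors is a logical OR (in PyTorch, `True + True == True` for bool tensors), the `i1` addition could be rewritten as `arith.ori`, which has a legal SPIR-V lowering, unlike `arith.addi` on `i1`:

```mlir
// Hypothetical rewrite of the failing pattern from async_dispatch_0.
// Note that arith.addi on i1 would be addition mod 2 (i.e. XOR), so
// legalizing it by widening to i32 and truncating back would NOT match
// the OR semantics of a PyTorch bool add; ori is the semantic match here.
%21 = arith.ori %17, %20 : i1
%22 = arith.select %21, %cst, %in_1 : f32
```

Whether this rewrite belongs in the torch-mlir lowering of `torch.aten.add.Tensor` on `i1` tensors or in an IREE legalization pass before SPIR-V conversion is an open question for this issue.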