iree-org / iree

A retargetable MLIR-based machine learning compiler and runtime toolkit.

Home page: http://iree.dev/


Sharding into multiple IREE logical devices or queues with peered memory

sogartar opened this issue

This is a proposal for distributing an MLIR program across multiple homogeneous physical devices that have peered memory. It is intended to improve memory usage under NUMA. Devices/queues have direct peer-to-peer memory access during kernel execution. We want to explicitly shard tensor data and operations and assign them to devices/queues.

This problem differs from tiling and distributing work across multiple compute units that are treated as a single device and queue. In that case it does not matter which compute unit a work item is assigned to: the compiler does not have to resolve the assignment explicitly and leaves it to the runtime. Likewise, memory allocations do not have to distinguish where they are performed.

Here we take as input a program with Mesh dialect sharding annotations. The original intent of the Mesh dialect is to use MPI-like collectives rather than placement operations such as extracting and inserting tensor slices. We would like to expand its scope to support the latter as well. The algorithm that produces the sharding annotations is out of scope here, and we assume that the program is fully annotated.

Below is an example of an annotated program that does matrix multiplication. It is sharded on a 1D process mesh of size 2.

    2x3              3x2             2x2       
+-----------+   +------------+   +-----------+
| process 0 |   | fully      |   | process 0 |
+-----------+ X | replicated | = +-----------+
| process 1 |   |            |   | process 1 |
+-----------+   +------------+   +-----------+
mesh.mesh @mesh(shape = 2, peered_memory_axes = [0])

func.func @main(
  %arg0 : tensor<2x3xf32>,
  %arg1 : tensor<3x2xf32>,
  %out_dps: tensor<2x2xf32>
) -> tensor<2x2xf32> {
  %arg0_sharded = mesh.shard %arg0 to <@mesh, [[0]]> : tensor<2x3xf32>
  %arg1_sharded = mesh.shard %arg1 to <@mesh, []> : tensor<3x2xf32>
  %out_dps_sharded = mesh.shard %out_dps to <@mesh, [[0]]> : tensor<2x2xf32>

  %res = linalg.matmul ins(%arg0_sharded, %arg1_sharded :
    tensor<2x3xf32>, tensor<3x2xf32>)
    outs(%out_dps_sharded : tensor<2x2xf32>) -> tensor<2x2xf32>

  %res_sharded = mesh.shard %res to <@mesh, [[0]]> : tensor<2x2xf32>

  return %res_sharded : tensor<2x2xf32>
}

Note that the mesh shape may be dynamic.

For example

mesh.mesh @mesh(shape = ?x?x?, peered_memory_axes = [1, 2])

Here we have a 3D mesh in which all dimensions are dynamic, with peered memory across mesh axes 1 and 2. The mesh is partitioned into disjoint subsets, and memory is peered only within a subset: processes (h, i, j) and (k, l, m) have peered memory if h = k and distinct memory if h != k. In the immediate future we will be concerned only with 1D meshes with peered memory.
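The partitioning can be sketched in Python as follows, assuming the dynamic mesh sizes have already been resolved to concrete values at runtime. The helpers `peered_memory_groups` and `have_peered_memory` are hypothetical and not part of the Mesh dialect. Under this reading, two processes share peered memory when their indices agree on every mesh axis outside `peered_memory_axes`.

```python
from itertools import product


def peered_memory_groups(mesh_shape, peered_axes):
    """Partition all mesh processes into disjoint peered-memory subsets.

    Processes land in the same subset when their multi-indices agree on
    every mesh axis that is NOT listed in peered_axes.
    """
    peered = set(peered_axes)
    groups = {}
    for idx in product(*(range(n) for n in mesh_shape)):
        # Key on the coordinates along the non-peered axes only.
        key = tuple(c for axis, c in enumerate(idx) if axis not in peered)
        groups.setdefault(key, []).append(idx)
    return list(groups.values())


def have_peered_memory(a, b, peered_axes):
    """True if processes a and b (multi-indices) share peered memory."""
    peered = set(peered_axes)
    return all(x == y for axis, (x, y) in enumerate(zip(a, b))
               if axis not in peered)


# A 2x2x3 instance of the ?x?x? mesh with peered_memory_axes = [1, 2]:
# two subsets of six processes each, one per value of the first index.
groups = peered_memory_groups((2, 2, 3), [1, 2])
```

For the 1D mesh used in the examples (`shape = 2, peered_memory_axes = [0]`), the sketch yields a single subset that contains both processes.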


As a first step we transform the program from annotations using mesh.shard operations to a distributed tensor form that utilizes the tensor encoding attribute.

// Shard tensor axis 0 on mesh axis 0.
#shard_0_on_0 = #mesh.shard<@mesh, [[0]]>
// The tensor is replicated on all shards.
#fully_replicate = #mesh.shard<@mesh, []>

mesh.mesh @mesh(shape = 2, peered_memory_axes = [0])

func.func @main(
  %arg0 : tensor<2x3xf32, #shard_0_on_0>,
  %arg1 : tensor<3x2xf32, #fully_replicate>,
  %out_dps: tensor<2x2xf32, #shard_0_on_0>
) -> tensor<2x2xf32, #shard_0_on_0> {
  %res = linalg.matmul ins(%arg0, %arg1 :
    tensor<2x3xf32, #shard_0_on_0>,
    tensor<3x2xf32, #fully_replicate>)
    outs(%out_dps : tensor<2x2xf32, #shard_0_on_0>)
    -> tensor<2x2xf32, #shard_0_on_0>

  return %res : tensor<2x2xf32, #shard_0_on_0>
}
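The following Python sketch makes the `#mesh.shard<@mesh, [[0]]>` notation concrete by showing how a sharding spec could map to the shape each process holds. It rests on three assumptions: `split_axes[d]` lists the mesh axes that tensor dimension d is split across, unlisted dimensions are replicated, and every split dimension divides evenly. `local_shard_shape` is a hypothetical helper. Under this reading, `[]` and `[[]]` both describe a fully replicated tensor.

```python
from math import prod


def local_shard_shape(tensor_shape, split_axes, mesh_shape):
    """Shape of the shard held by one process for a given sharding spec.

    split_axes[d] lists the mesh axes that tensor dimension d is split
    across; dimensions beyond len(split_axes) are replicated.
    """
    shape = list(tensor_shape)
    for dim, mesh_axes in enumerate(split_axes):
        parts = prod(mesh_shape[a] for a in mesh_axes)  # prod of nothing == 1
        if shape[dim] % parts != 0:
            raise ValueError(
                f"dim {dim} of size {shape[dim]} is not divisible by {parts}")
        shape[dim] //= parts
    return tuple(shape)


mesh_shape = (2,)
# %arg0 : tensor<2x3xf32, #shard_0_on_0>  ->  each process holds 1x3.
arg0_local = local_shard_shape((2, 3), [[0]], mesh_shape)
# %arg1 : tensor<3x2xf32, #fully_replicate>  ->  each process holds 3x2.
arg1_local = local_shard_shape((3, 2), [], mesh_shape)
```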

Using the encoding attribute may clash with other uses, such as sparsity encoding. We access the mesh sharding attribute through a type interface that hides the encoding-attribute detail, which makes future changes easier.

In this step we would also insert resharding where the annotations require it. In general, under peered memory with highly uniform access, we would rather avoid resharding and let the subsequent operation consume the original sharding directly.


After converting to distributed tensors, we shard the operations by describing what each process does. One approach is to use a mesh.for_all operation, in which all processes share the same code. It is similar to flow.dispatch.workgroups.

#shard_0_on_0 = #mesh.shard<@mesh, [[0]]>
#fully_replicate = #mesh.shard<@mesh, [[]]>

mesh.mesh @mesh(shape = 2, peered_memory_axes = [0])

func.func @main(
  %arg0 : tensor<2x3xf32, #shard_0_on_0>,
  %arg1 : tensor<3x2xf32, #fully_replicate>,
  %out_dps: tensor<2x2xf32, #shard_0_on_0>
) -> tensor<2x2xf32, #shard_0_on_0> {

  %res = mesh.for_all on @mesh ->
    tensor<2x2xf32, #shard_0_on_0> {
    // Each process owns one row: 2 rows / 2 processes.
    %shard_size = arith.constant 1 : index
    %proc_idx = mesh.process_multi_index on @mesh : index
    %slice_offset = arith.muli %proc_idx, %shard_size : index
    %arg0_slice = tensor.extract_slice %arg0[%slice_offset, 0][1, 3][1, 1] :
      tensor<2x3xf32, #shard_0_on_0> to tensor<1x3xf32>
    %out_dps_slice = tensor.extract_slice %out_dps[%slice_offset, 0][1, 2][1, 1] :
      tensor<2x2xf32, #shard_0_on_0> to tensor<1x2xf32>

    %res_slice = linalg.matmul ins(%arg0_slice, %arg1 :
      tensor<1x3xf32>,
      tensor<3x2xf32, #fully_replicate>)
      outs(%out_dps_slice : tensor<1x2xf32>)
      -> tensor<1x2xf32>
    // We return the shard for this process.
    return %res_slice : tensor<1x2xf32>
  }

  return %res : tensor<2x2xf32, #shard_0_on_0>
}
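The semantics of the mesh.for_all region can be simulated in plain Python. Every process runs the same body and differs only in its process index, which selects the row shard of `%arg0` at offset `proc_idx * shard_size`; `%arg1` is replicated. This is a sketch with hypothetical helper names, and it assumes that the row count divides evenly by the mesh size.

```python
def matmul(a, b):
    """Plain matmul on nested lists: (m x k) @ (k x n)."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)]
            for row in a]


def for_all_matmul(a, b, mesh_size):
    """Run the same for_all body once per process; return the list of shards."""
    rows = len(a)
    if rows % mesh_size:
        raise ValueError("row count must be divisible by the mesh size")
    shard_size = rows // mesh_size  # 1 in the IR example: 2 rows / 2 processes
    shards = []
    for proc_idx in range(mesh_size):
        # Only proc_idx differs between processes.
        offset = proc_idx * shard_size
        a_slice = a[offset:offset + shard_size]  # tensor.extract_slice of %arg0
        shards.append(matmul(a_slice, b))        # %arg1 is fully replicated
    return shards


a = [[1, 2, 3], [4, 5, 6]]
b = [[1, 0], [0, 1], [1, 1]]
shards = for_all_matmul(a, b, mesh_size=2)
# Concatenating the per-process shards reproduces the unsharded product.
assert [row for shard in shards for row in shard] == matmul(a, b)
```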

Instead of a new mesh.for_all operation, we can introduce mesh.for_one, which describes what a single process does. With this approach we combine it with scf.forall to describe the operations for the entire mesh.

#shard_0_on_0 = #mesh.shard<@mesh, [[0]]>
#fully_replicate = #mesh.shard<@mesh, [[]]>

mesh.mesh @mesh(shape = 2, peered_memory_axes = [0])

func.func @main(
  %arg0 : tensor<2x3xf32, #shard_0_on_0>,
  %arg1 : tensor<3x2xf32, #fully_replicate>,
  %out_dps: tensor<2x2xf32, #shard_0_on_0>
) -> tensor<2x2xf32, #shard_0_on_0> {
  %mesh_size = mesh.mesh_shape @mesh : index
  %res = scf.forall (%proc_idx) in (%mesh_size)
    shared_outs(%out = %out_dps) -> (tensor<2x2xf32, #shard_0_on_0>) {
    // Each process owns one row: 2 rows / 2 processes.
    %shard_size = arith.constant 1 : index
    %slice_offset = arith.muli %proc_idx, %shard_size : index
    // The work performed by a single process.
    %res_slice = mesh.for_one %proc_idx on @mesh -> tensor<1x2xf32> {
      %arg0_slice = tensor.extract_slice %arg0[%slice_offset, 0][1, 3][1, 1] :
        tensor<2x3xf32, #shard_0_on_0> to tensor<1x3xf32>
      %out_slice = tensor.extract_slice %out[%slice_offset, 0][1, 2][1, 1] :
        tensor<2x2xf32, #shard_0_on_0> to tensor<1x2xf32>

      %matmul = linalg.matmul ins(%arg0_slice, %arg1 :
        tensor<1x3xf32>,
        tensor<3x2xf32, #fully_replicate>)
        outs(%out_slice : tensor<1x2xf32>)
        -> tensor<1x2xf32>
      return %matmul : tensor<1x2xf32>
    }
    // The scf.forall.in_parallel terminator may only hold parallel insertions.
    scf.forall.in_parallel {
      tensor.parallel_insert_slice %res_slice into
        %out[%slice_offset, 0][1, 2][1, 1] :
        tensor<1x2xf32> into tensor<2x2xf32, #shard_0_on_0>
    }
  }

  return %res : tensor<2x2xf32, #shard_0_on_0>
}
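The scf.forall form differs in that all iterations write into one shared output through disjoint parallel insertions. The Python analogue below is a rough model under two illustrative assumptions that do not describe IREE behavior: a thread pool with one task per mesh process stands in for scf.forall, and per-row writes into a shared list stand in for tensor.parallel_insert_slice.

```python
from concurrent.futures import ThreadPoolExecutor


def forall_matmul(a, b, mesh_size):
    """Row-sharded matmul written in scf.forall style with a shared output."""
    rows, cols = len(a), len(b[0])
    if rows % mesh_size:
        raise ValueError("row count must be divisible by the mesh size")
    shard_size = rows // mesh_size
    out = [[0] * cols for _ in range(rows)]  # plays the role of shared_outs

    def for_one(proc_idx):
        # The work of a single mesh process.
        offset = proc_idx * shard_size
        a_slice = a[offset:offset + shard_size]
        res_slice = [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)]
                     for row in a_slice]
        # parallel_insert_slice: iterations write disjoint row ranges.
        for i, row in enumerate(res_slice):
            out[offset + i] = row

    with ThreadPoolExecutor(max_workers=mesh_size) as pool:
        list(pool.map(for_one, range(mesh_size)))  # list() re-raises errors
    return out
```

Because each iteration writes only its own row range, the iterations need no synchronization beyond waiting for all of them to finish.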

After all operations have been sharded, we need to convert each distributed tensor into a list of tensors.

For a dynamic mesh we need a dynamic list, for which we can use tensor<?x...>. This dynamism has to be propagated all the way to the runtime: the VM would need to handle dynamic lists of buffers.
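A small Python sketch of this conversion follows. It uses hypothetical helpers and assumes row sharding with even division. `to_shard_list` returns a list whose length is the mesh size, which for a dynamic mesh is only known at runtime. With a static mesh of size 2, the same result can be unpacked into two named values, which mirrors the unrolled form.

```python
def to_shard_list(rows, mesh_size):
    """Split a row-sharded tensor into one shard per process.

    The list length equals mesh_size, which for a dynamic mesh is only
    known at runtime (the tensor<?x...> case).
    """
    if len(rows) % mesh_size:
        raise ValueError("row count must be divisible by the mesh size")
    n = len(rows) // mesh_size
    return [rows[p * n:(p + 1) * n] for p in range(mesh_size)]


def from_shard_list(shards):
    """Concatenate the shards back into the full tensor."""
    return [row for shard in shards for row in shard]


t = [[1, 2], [3, 4], [5, 6], [7, 8]]
# Dynamic mesh: the number of shards is a runtime value.
mesh_size = 4
shards = to_shard_list(t, mesh_size)
# Static mesh of size 2: the list can be unrolled into named values.
shard0, shard1 = to_shard_list(t, 2)
```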

For a static mesh we can use the values that correspond to the shards directly. In this case we also have to unroll all loops that run operations on the mesh. Below is an example of doing that.

mesh.mesh @mesh(shape = 2, peered_memory_axes = [0])

func.func @main(
  %arg0_shard0 : tensor<1x3xf32>,
  %arg0_shard1 : tensor<1x3xf32>,
  %arg1_shard0 : tensor<3x2xf32>,
  %arg1_shard1 : tensor<3x2xf32>,
  %out_dps_shard0: tensor<1x2xf32>,
  %out_dps_shard1: tensor<1x2xf32>
) -> (tensor<1x2xf32>, tensor<1x2xf32>) {
  // Shard 0.
  %res_shard0 = mesh.for_one 0 on @mesh {
    %res = linalg.matmul ins(%arg0_shard0, %arg1_shard0 :
      tensor<1x3xf32>, tensor<3x2xf32>)
      outs(%out_dps_shard0 : tensor<1x2xf32>)
      -> tensor<1x2xf32>
    return %res : tensor<1x2xf32>
  }

  // Shard 1.
  %res_shard1 = mesh.for_one 1 on @mesh {
    %res = linalg.matmul ins(%arg0_shard1, %arg1_shard1 :
      tensor<1x3xf32>, tensor<3x2xf32>)
      outs(%out_dps_shard1 : tensor<1x2xf32>)
      -> tensor<1x2xf32>
    return %res : tensor<1x2xf32>
  }

  return %res_shard0, %res_shard1 : tensor<1x2xf32>, tensor<1x2xf32>
}

This example is simplified for illustration. In real-world programs a mesh process may have to ingest operands whose slices cross the boundaries of tensor shards. If we can reliably fuse mesh.for_one operations to avoid copying, operand resharding can be handled as a separate concern.
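When a process needs an operand slice that crosses shard boundaries, the slice has to be assembled from pieces of several shards. The Python sketch below shows that mapping for row sharding with a fixed `shard_size`. The helper `shards_for_slice` is hypothetical.

```python
def shards_for_slice(start, size, shard_size):
    """Map the global row range [start, start + size) onto row shards.

    Returns (shard_index, local_offset, length) for every shard the range
    touches; more than one piece means the slice crosses a shard boundary.
    """
    pieces = []
    pos, end = start, start + size
    while pos < end:
        shard = pos // shard_size
        local = pos % shard_size
        length = min(shard_size - local, end - pos)
        pieces.append((shard, local, length))
        pos += length
    return pieces


# Rows 3..6 with 4 rows per shard: last row of shard 0, first 3 rows of shard 1.
pieces = shards_for_slice(3, 4, 4)
```

A slice that maps to a single piece can be read in place. Any other slice needs either a copy or a fused consumer that reads from several shards.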

At this point we are ready to convert the remaining mesh.for_one ops into IREE-specific descriptions of device/queue affinity.


We could shard after dispatch region formation but before transforming into workgroups. The user may have provided some sharding annotations throughout the program. We then form the dispatch regions, complete the sharding annotations, and finally partition into mesh processes. However, the user-provided sharding annotations may hinder fusion of the dispatch regions. Can we place sharding annotations inside dispatch regions, and what would that even mean? This approach also requires sharding propagation to handle dispatch regions, which in turn requires solving the more general problem of propagating sharding through a composite operation.

The other option is to shard before forming the dispatch regions. In this scenario, dispatch region formation will have to handle many tensor placement (insert/extract) operations introduced by resharding, as well as device/queue placement.

We could expand the scope of flow.dispatch.workgroups to handle additional dimensions that are mapped to devices/queues.


We need to provide the user with functionality to shard/unshard function arguments/results. This should be coupled with function variants that preserve the original signature and handle all sharding internally.

func.func @main_shard_arg0(%arg0 : tensor<2x3xf32>)
  -> tensor<2x3xf32, #shard_0_on_0> {
  %mesh_size = mesh.mesh_shape @mesh : index
  %out_dps = tensor.empty() : tensor<2x3xf32, #shard_0_on_0>
  %res = scf.forall (%proc_idx) in (%mesh_size)
    shared_outs(%out = %out_dps) -> (tensor<2x3xf32, #shard_0_on_0>) {
    // Each process owns one row: 2 rows / 2 processes.
    %shard_size = arith.constant 1 : index
    %slice_offset = arith.muli %proc_idx, %shard_size : index
    %slice = mesh.for_one %proc_idx on @mesh -> tensor<1x3xf32> {
      %s = tensor.extract_slice %arg0[%slice_offset, 0][1, 3][1, 1] :
        tensor<2x3xf32> to tensor<1x3xf32>
      return %s : tensor<1x3xf32>
    }
    // The scf.forall.in_parallel terminator may only hold parallel insertions.
    scf.forall.in_parallel {
      tensor.parallel_insert_slice %slice into
        %out[%slice_offset, 0][1, 3][1, 1] :
        tensor<1x3xf32> into tensor<2x3xf32, #shard_0_on_0>
    }
  }

  return %res : tensor<2x3xf32, #shard_0_on_0>
}

func.func @main(%arg0 : tensor<2x3xf32>, ...) ... {
  %arg0_sharded = func.call @main_shard_arg0(%arg0) : (tensor<2x3xf32>) -> tensor<2x3xf32, #shard_0_on_0>
  ...
  %res = func.call @main_sharded(%arg0_sharded, ...) ...
  // unshard %res and return
}
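The wrapper pattern can be sketched in Python. The public `main` keeps the original unsharded signature, shards its argument, calls the sharded variant, and unshards the result. The function names mirror the IR above, but the row-sharding bodies and the even-division assumption are illustrative only.

```python
def main_shard_arg0(arg0, mesh_size):
    """Split arg0 row-wise into one shard per mesh process."""
    if len(arg0) % mesh_size:
        raise ValueError("row count must be divisible by the mesh size")
    n = len(arg0) // mesh_size
    return [arg0[p * n:(p + 1) * n] for p in range(mesh_size)]


def main_sharded(arg0_shards, arg1):
    """Sharded variant: each process multiplies its shard by the replicated arg1."""
    return [[[sum(x * y for x, y in zip(row, col)) for col in zip(*arg1)]
             for row in shard]
            for shard in arg0_shards]


def main(arg0, arg1, mesh_size=2):
    """Public entry point with the original, unsharded signature."""
    arg0_sharded = main_shard_arg0(arg0, mesh_size)
    res_shards = main_sharded(arg0_sharded, arg1)
    # Unshard the result before returning it.
    return [row for shard in res_shards for row in shard]
```

Callers that already hold sharded data can skip the wrapper and call `main_sharded` directly.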

The resulting IR can then be handled like the rest of the program, converting each distributed tensor into a list of tensors.

An open question: how can we stream the tensor shards from the host directly into the destination slices without allocating the original unsharded tensor?


There is a potential optimization for operations with high computational complexity, such as matrix multiplication. When such an operation is tiled, each tile can be copied closer to the compute unit so that the kernel does not read the same tile from a slow memory region multiple times. We don't want to do this copying before starting the computation, but rather to interleave it into the kernel.
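A back-of-the-envelope Python sketch of the saving follows. In a tiled matmul `C[i, j] += A[i, k] @ B[k, j]` with B in slow (remote) memory, every iteration re-reads a B tile unless that tile is copied locally on first use. The sketch only counts reads and assumes local memory can hold every copied tile. It is not a kernel.

```python
def remote_tile_reads(m_tiles, n_tiles, k_tiles, cache_tiles):
    """Count reads of B tiles from slow memory in a tiled matmul.

    Without caching, every (i, j, k) iteration reads B[k, j] remotely.
    With caching, each B tile is copied locally the first time it is
    needed, interleaved with the computation rather than done up front.
    """
    reads = 0
    local = set()
    for i in range(m_tiles):
        for j in range(n_tiles):
            for k in range(k_tiles):
                tile = (k, j)
                if cache_tiles and tile in local:
                    continue  # served from the local copy
                reads += 1
                if cache_tiles:
                    local.add(tile)  # copy on first use
    return reads


uncached = remote_tile_reads(4, 3, 2, cache_tiles=False)  # 4 * 3 * 2 reads
cached = remote_tile_reads(4, 3, 2, cache_tiles=True)     # 3 * 2 reads
```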

I would like some feedback on whether this makes sense, and which parts of it do.

@benvanik, @stellaraccident could you take a look?

Great examples - really helps ground things! I will have to take a longer read as there's lots to unpack - but I think there's a layering issue here I need to put my finger on.

The immediate feedback is that we don't want to use tensor encodings for placements - we need to keep those for data layout/encodings - putting mesh dialect-specific encoding behavior between frontends and the rest of the compiler makes it practically impossible to carry the information we need. We also need to be able to mutate layouts at several points in the lowering pipeline and cannot practically preserve mixed semantics outside of very specific slices of the pipeline (which then ends up with a lot of trickery to convert into/out-of without losing information/breaking). So any solution we have will need to use ops instead of attributes to indicate sharding/transfers.

Dynamic VM lists of distributed tensors and such also raises some flags for me as that's going to not be possible in flow/stream as they exist today. Nor will usage of scf.forall and other functionality be trivial (though it'd be cool for other purposes). This could be a year+ of work and a lot of refactoring/rewriting and my feeling is that we're going to want to find solutions that let us progress much earlier/more incrementally. That may mean not using the mesh dialect - I'm not quite sure what it's adding here relative to the cost to support it.

Thanks for the very explicit examples. I'm still digesting.

I think what I'm struggling with is at what level we break the spend abstraction. With this approach, it is very late, hanging on to the dtensor paradigm through most of the pipeline.

Whereas in frameworks like TF DTensor and PyTorch FSDP, it happens very early, at least for the parts running at the lowest level of the hierarchy on the same bus. Doing it late like this works for the homogeneous distributed case. But I wonder how it would interplay to just start with some higher level operators that indicated device placement and movement (since that is already needed for heterogeneous).

I don't think I'm saying this perfectly -- but I think there may be two concerns here, not just one and I'm trying to put my finger on it.

The immediate feedback is that we don't want to use tensor encodings for placements - we need to keep those for data layout/encodings - putting mesh dialect-specific encoding behavior between frontends and the rest of the compiler makes it practically impossible to carry the information we need.
@benvanik

I thought it natural to put it in the tensor type since it describes the data. Isn't it like a layout? It is a layout of the tensor across the mesh. Regarding conflicts with other encoding information, doesn't this argument apply to every other use of this attribute? By that reasoning no one should use it, since any use may conflict with others. At the syntactic level, one solution is to use a dictionary of encodings when there is a conflict. The more important and complicated question of how to handle interactions between specific encodings would have to be solved no matter where we put the sharding information.

Dynamic VM lists of distributed tensors and such also raises some flags for me as that's going to not be possible in flow/stream as they exist today. Nor will usage of scf.forall and other functionality be trivial (though it'd be cool for other purposes).
@benvanik

This case is only relevant for a dynamic number of devices. I don't know how to solve it without dynamic lists and for loops unless we opt for an SPMD form, which would have to describe references to other devices' shards and the use-def dependencies on them. Collectives provide this implicitly.
With a static number of devices, the loops can be unrolled and dynamic lists are not needed.

That may mean not using the mesh dialect - I'm not quite sure what it's adding here relative to the cost to support it.
@benvanik

I can't make that call either. I am not sure how else one could concisely describe sharding over homogeneous devices.

I think what I'm struggling with is at what level we break the spend abstraction. With this approach, it is very late, hanging on to the dtensor paradigm through most of the pipeline.
@stellaraccident

The intent is for the input IR to carry partial or complete sharding annotations via the mesh.shard operation, not encoded in the tensor type. I introduced the sharding encoding into the tensor type because it seemed more natural, as it describes the data. After sharding is complete there will be no more dtensors, just the shards as plain tensors.
On the runtime side, if we are not using a dynamic set of devices, there should be no new requirements, and there will be no runtime concept of a dtensor. Towards the end I proposed a way to let the compiler handle sharding/unsharding of the arguments/results of public module functions. Alternatively, you can call the sharded variant directly.

But I wonder how it would interplay to just start with some higher level operators that indicated device placement and movement (since that is already needed for heterogeneous).
@stellaraccident

For a heterogeneous system the device mesh representation would not be sufficient. In that case you would probably want a hierarchical representation, like a file system.
E.g.

    -------- node_0 ----------  ...  node_k
   /        /      \          \
cpu_0 ... cpu_m    gpu_0 ... gpu_n
                  /    \
         chiplet_0 ... chiplet_q

Then we can assign properties at each level, such as peered memory or communication bandwidth.
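The Python sketch below models such a hierarchy. It assumes that properties are attached per node and that two devices have peered memory when their lowest common ancestor is marked `peered_memory`. The `Node` class, the `system` root and the property values are all hypothetical. A bandwidth entry could be attached in the same way.

```python
from dataclasses import dataclass, field
from typing import Optional


@dataclass
class Node:
    """One level of the device hierarchy (node, CPU, GPU, chiplet, ...)."""
    name: str
    props: dict = field(default_factory=dict)
    children: list = field(default_factory=list, repr=False)
    # Excluded from repr/eq to avoid infinite recursion through the cycle.
    parent: Optional["Node"] = field(default=None, repr=False, compare=False)

    def add(self, child: "Node") -> "Node":
        child.parent = self
        self.children.append(child)
        return child

    def path(self) -> list:
        """File-system-like path from the root down to this node."""
        names, node = [], self
        while node is not None:
            names.append(node.name)
            node = node.parent
        return names[::-1]


def lowest_common_ancestor(a: Node, b: Node) -> Optional[Node]:
    seen = set()
    node = a
    while node is not None:
        seen.add(id(node))
        node = node.parent
    node = b
    while node is not None and id(node) not in seen:
        node = node.parent
    return node


def peered(a: Node, b: Node) -> bool:
    """Peered memory is decided by the closest level shared by both devices."""
    lca = lowest_common_ancestor(a, b)
    return lca is not None and lca.props.get("peered_memory", False)


# Hypothetical example loosely following the diagram above.
system = Node("system")
node_0 = system.add(Node("node_0", {"peered_memory": False}))
cpu_0 = node_0.add(Node("cpu_0"))
gpu_0 = node_0.add(Node("gpu_0", {"peered_memory": True}))
chiplet_0 = gpu_0.add(Node("chiplet_0"))
chiplet_1 = gpu_0.add(Node("chiplet_1"))
```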