iree-org / iree

A retargetable MLIR-based machine learning compiler and runtime toolkit.

Home Page: http://iree.dev/

[Codegen] TileAndDistributeToWorkgroups for operations with multiple results and related producers

qedawkins opened this issue

Tile + fuse via the tiling interface currently runs into an issue when considering fusion of operations that take operands with related producers. For example:

%empty = tensor.empty(...) : tensor<?xindex>
%0:2 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]}
    outs(%empty, %empty : tensor<?xindex>, tensor<?xindex>) {
  ^bb0(%out: index, %out1: index):
    %1 = linalg.index 0 : index
    %2 = arith.constant 128 : index
    %3 = arith.remsi %1, %2 : index
    linalg.yield %1, %3 : index, index
} -> (tensor<?xindex>, tensor<?xindex>)
%1 = linalg.add ins(%0#0, %0#1 : tensor<?xindex>, tensor<?xindex>) outs(%empty : tensor<?xindex>) -> tensor<?xindex>
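
As a reference point, here is a plain-Python model of what the example above computes. It assumes non-negative indices, for which `arith.remsi` agrees with Python's `%`. The function names `producer` and `consumer` are hypothetical and only mirror the two ops.

```python
def producer(n):
    """Model of the two-result linalg.generic: result 0 is the index, result 1 is index rem 128."""
    r0 = list(range(n))
    r1 = [i % 128 for i in range(n)]  # matches arith.remsi for non-negative i
    return r0, r1


def consumer(a, b):
    """Model of linalg.add: elementwise sum of its two operands."""
    return [x + y for x, y in zip(a, b)]


r0, r1 = producer(300)
out = consumer(r0, r1)
print(out[130])  # 132, i.e. 130 + (130 rem 128)
```

Both operands of the add come from the same producer. Any tiling therefore has to keep the two results' tiles in sync.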

Here, to do tile + fuse, we essentially have two options: tile the producer and fuse the consumer, or vice versa. If we do the former, we end up with something like the following:

%empty = tensor.empty(...) : tensor<?xindex>
%loop:2 = scf.forall ... shared_outs(%init0 = %empty, %init1 = %empty) -> (tensor<?xindex>, tensor<?xindex>) {
  %slice0 = tensor.extract_slice %init0 ... : tensor<?xindex> to tensor<?xindex>
  %slice1 = tensor.extract_slice %init1 ... : tensor<?xindex> to tensor<?xindex>
  %0:2 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]}
      outs(%slice0, %slice1 : tensor<?xindex>, tensor<?xindex>) {
    ^bb0(%out: index, %out1: index):
      %1 = linalg.index 0 : index
      %2 = arith.constant 128 : index
      %3 = arith.remsi %1, %2 : index
      linalg.yield %1, %3 : index, index
  } -> (tensor<?xindex>, tensor<?xindex>)
  scf.forall.in_parallel {
    tensor.parallel_insert_slice %0#0 into %init0 ... : tensor<?xindex> into tensor<?xindex>
    tensor.parallel_insert_slice %0#1 into %init1 ... : tensor<?xindex> into tensor<?xindex>
  }
}
%1 = linalg.add ins(%loop#0, %loop#1 : tensor<?xindex>, tensor<?xindex>) outs(%empty : tensor<?xindex>) -> tensor<?xindex>

With the above, to fuse %1 into the producer loop, we need to prove that each tensor.parallel_insert_slice is inserting the same slice of %1's iteration space. Given that insertion offsets are specified dynamically, this is quite difficult to do in general.
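
The legality check can be sketched in plain Python. The sketch assumes that each insertion offset has been reduced to a canonical affine form `coeff * iv + const` in the loop induction variable. The classes `Affine` and `InsertSlice` and the function `consumer_fusable` are hypothetical and are not part of the tiling interface.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Affine:
    """Hypothetical offset expression coeff * iv + const in the loop induction variable iv."""
    coeff: int
    const: int


@dataclass(frozen=True)
class InsertSlice:
    """Hypothetical summary of one tensor.parallel_insert_slice in the loop terminator."""
    result: int      # which loop result the slice is inserted into
    offset: Affine   # insertion offset as a function of iv
    size: int        # static tile size


def consumer_fusable(inserts):
    """True if all inserts write the same slice, so one consumer tile covers every result.

    With offsets in this canonical affine form, equal fields mean equal offsets for every iv.
    Arbitrary dynamic SSA offsets have no such canonical form, which is what makes the
    real check hard.
    """
    first = inserts[0]
    return all(s.offset == first.offset and s.size == first.size for s in inserts[1:])


same = [InsertSlice(0, Affine(64, 0), 64), InsertSlice(1, Affine(64, 0), 64)]
mismatched = [InsertSlice(0, Affine(64, 0), 64), InsertSlice(1, Affine(32, 0), 32)]
print(consumer_fusable(same))        # True
print(consumer_fusable(mismatched))  # False
```

When the offsets are arbitrary dynamic values, the comparison has nothing canonical to compare against. That gap is what the paragraph above describes as difficult.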

The other option of tiling the consumer first produces something like this:

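%empty = tensor.empty(...) : tensor<?xindex>
%0:2 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]}
    outs(%empty, %empty : tensor<?xindex>, tensor<?xindex>) {
  ^bb0(%out: index, %out1: index):
    %1 = linalg.index 0 : index
    %2 = arith.constant 128 : index
    %3 = arith.remsi %1, %2 : index
    linalg.yield %1, %3 : index, index
} -> (tensor<?xindex>, tensor<?xindex>)
%loop = scf.forall ... shared_outs(%init = %empty) -> (tensor<?xindex>) {
  %slice0 = tensor.extract_slice %0#0 ... : tensor<?xindex> to tensor<?xindex>
  %slice1 = tensor.extract_slice %0#1 ... : tensor<?xindex> to tensor<?xindex>
  %slice2 = tensor.extract_slice %init ... : tensor<?xindex> to tensor<?xindex>
  %1 = linalg.add ins(%slice0, %slice1 : tensor<?xindex>, tensor<?xindex>) outs(%slice2 : tensor<?xindex>) -> tensor<?xindex>
  scf.forall.in_parallel {
    tensor.parallel_insert_slice %1 into %init ... : tensor<?xindex> into tensor<?xindex>
  }
}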

To fuse in this case, we must either do the same as in the consumer-fusion case and relate the extracted slices of %0's results to its iteration space, or rematerialize %0 for each extracted slice. Rematerialization is problematic when %0 (a placeholder for a more interesting operation in the example above) is expensive. There is some hope of CSE removing the duplication after the fact. However, this remains a place where the tiling interface struggles today, because relating slices of operands to slices of iteration spaces is difficult.
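
The cost of the rematerialization option can be seen in a small plain-Python model. It tiles the consumer and then fuses the producer separately into each operand slice, so the producer runs twice per tile. The names `producer_tile`, `consumer_first`, and `stats` are hypothetical.

```python
stats = {"producer_calls": 0}


def producer_tile(off, size):
    """Hypothetical tiled producer: both results for indices [off, off + size)."""
    stats["producer_calls"] += 1
    idx = range(off, off + size)
    return [i for i in idx], [i % 128 for i in idx]


def consumer_first(n, tile):
    """Tile the consumer, then fuse the producer separately into each operand slice."""
    out = []
    for off in range(0, n, tile):
        size = min(tile, n - off)
        a = producer_tile(off, size)[0]  # producer fused for the %0#0 slice
        b = producer_tile(off, size)[1]  # producer fused again for the %0#1 slice
        out.extend(x + y for x, y in zip(a, b))
    return out


out = consumer_first(300, 64)
print(stats["producer_calls"])  # 10: 5 tiles, producer recomputed twice per tile
```

The results are correct, but the work on the producer doubles. This is harmless for the toy `remsi` body and costly for an expensive producer.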

For this particular case, because both results of %0 need essentially the same tile of the producer, the currently expected approach is to fuse both uses individually and let CSE remove the duplication. This was an explicit decision to reduce the "number of things to track during fusion".
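
To see why CSE can clean this up, here is a toy value-numbering CSE over a straight-line op list. It treats every op as side-effect free and compares only the op kind and operands, which is much simpler than MLIR's actual CSE pass. The op encoding and all SSA names are hypothetical.

```python
def _resolve(name, rename):
    """Map 'base#k' to 'canonical#k' if base was eliminated."""
    base, sep, idx = name.partition("#")
    return rename.get(base, base) + sep + idx


def cse(ops):
    """Toy CSE over a straight-line op list.

    Each op is (result, kind, operands). Ops are assumed side-effect free, so two ops
    with the same kind and the same (renamed) operands compute the same value.
    """
    seen = {}    # (kind, operands) -> canonical result name
    rename = {}  # eliminated result name -> canonical result name
    kept = []
    for result, kind, operands in ops:
        operands = tuple(_resolve(o, rename) for o in operands)
        key = (kind, operands)
        if key in seen:
            rename[result] = seen[key]
        else:
            seen[key] = result
            kept.append((result, kind, operands))
    return kept, rename


# Loop body after fusing the producer separately into each use (hypothetical names).
body = [
    ("%s0", "extract_slice", ("%empty", "%off")),
    ("%p0", "producer", ("%s0",)),
    ("%s1", "extract_slice", ("%empty", "%off")),
    ("%p1", "producer", ("%s1",)),
    ("%r", "add", ("%p0#0", "%p1#1")),
]
kept, rename = cse(body)
print(kept[-1])  # ('%r', 'add', ('%p0#0', '%p0#1'))
```

The two fused producer copies collapse into one only because both were fed identical slices. When the slices differ, the duplication survives.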

Just to note, there is another route we could take. The issues arise because we are doing fusion greedily. Instead, we could first run an analysis to determine how to tile, and then do a "one-shot" tiling, where we assume everything lines up and tile it all into a single scf.forall loop. This is how vector distribution does it, and it doesn't have problems with multiple results.
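
A rough plain-Python illustration of the "one-shot" idea follows. It is not how vector distribution is implemented. It assumes some earlier analysis has already chosen a single `tile` for every op, and the function name `one_shot_tiled` is hypothetical.

```python
def one_shot_tiled(n, tile):
    """Compute every op on one shared tile per loop iteration.

    Assumes an earlier analysis has already chosen `tile` for the whole dispatch.
    """
    out = [0] * n
    producer_evals = 0
    for off in range(0, n, tile):
        idx = range(off, min(off + tile, n))
        # The producer runs once per tile; both of its results are used directly.
        r0 = list(idx)
        r1 = [i % 128 for i in idx]
        producer_evals += 1
        # The consumer uses the same tile, so there are no slices to match up.
        out[off:off + len(idx)] = [a + b for a, b in zip(r0, r1)]
    return out, producer_evals


out, evals = one_shot_tiled(300, 64)
print(evals)  # 5
```

Because the tile is shared by construction, the multi-result producer is computed once per tile. The trade-off is that the tiling decision has to be made before any op is tiled.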

It's worth noting, then, that relying on CSE in this way is justification for requiring that tile + fuse roots with multiple results avoid having consumers.

I have strong reasons not to go the one-shot tiling route. The fusion is greedy because it has to "tile everything": we are distributing to workgroups, so everything has to be tiled to workgroups (even if that means duplicating computation).

I am not sure I follow why relying on CSE justifies requiring that multiple-result roots avoid having consumers.

I think this is something we can work out when we get to it. I was just trying to say that tile + fuse rooted on a producer op needs to be considered very carefully, and we might end up needing propagation of lowering configs anyway.