iree-org / iree

A retargetable MLIR-based machine learning compiler and runtime toolkit.

Home Page: http://iree.dev/


Missing propagation for `unpack -> collapse_shape` to `collapse_shape -> unpack`.

hanhanW opened this issue.

This blocks fusion of the unpack with its consumers. For example, we should be able to swap the unpack and the collapse_shape, because the collapse_shape only folds a unit dimension away.

  func.func @foo(%arg0: tensor<1x1024x1024x16x16xf32>) -> tensor<16384x16384xf32> {
    %0 = tensor.empty() : tensor<1x16384x16384xf32>
    %unpack = tensor.unpack %arg0 outer_dims_perm = [0, 1, 2] inner_dims_pos = [1, 2] inner_tiles = [16, 16] into %0 : tensor<1x1024x1024x16x16xf32> -> tensor<1x16384x16384xf32>
    %collapsed = tensor.collapse_shape %unpack [[0, 1], [2]] : tensor<1x16384x16384xf32> into tensor<16384x16384xf32>
    %1 = tensor.empty() : tensor<16384x16384xf32>
    %2 = linalg.softmax dimension(1) ins(%collapsed : tensor<16384x16384xf32>) outs(%1 : tensor<16384x16384xf32>) -> tensor<16384x16384xf32>
    return %2 : tensor<16384x16384xf32>
  }
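For reference, here is a hand-written sketch (not output produced by IREE) of what the IR could look like after the proposed swap. The unit outer dim is collapsed away on the packed source first, so the unpack's `inner_dims_pos` shift from `[1, 2]` to `[0, 1]` and its 2-D result feeds `linalg.softmax` directly. Value names are illustrative.

```mlir
func.func @foo(%arg0: tensor<1x1024x1024x16x16xf32>) -> tensor<16384x16384xf32> {
  // Fold the untiled unit outer dim (dim 0) into dim 1 on the packed tensor.
  %collapsed = tensor.collapse_shape %arg0 [[0, 1], [2], [3], [4]] : tensor<1x1024x1024x16x16xf32> into tensor<1024x1024x16x16xf32>
  %0 = tensor.empty() : tensor<16384x16384xf32>
  // Same 16x16 tiles; only the tiled dim positions are renumbered after the collapse.
  %unpack = tensor.unpack %collapsed outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [16, 16] into %0 : tensor<1024x1024x16x16xf32> -> tensor<16384x16384xf32>
  %1 = tensor.empty() : tensor<16384x16384xf32>
  // The unpack is now the direct producer of the softmax.
  %2 = linalg.softmax dimension(1) ins(%unpack : tensor<16384x16384xf32>) outs(%1 : tensor<16384x16384xf32>) -> tensor<16384x16384xf32>
  return %2 : tensor<16384x16384xf32>
}
```

The swap is legal here because the collapsed dim has size 1 and is not tiled, so the tiling of the remaining dims is unchanged.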

Where is this collapse_shape coming from? There might be a uniform way of handling this in the reshape propagation passes that run later.

I don't know. It is already there after set encoding. A sequence of linalg ops is raised to a softmax op in the GlobalOptimization stage. Are we able to push reshape ops down through named ops? That doesn't look easy to me, so I think we can implement an (unpack, collapse_shape) propagation pattern for this case.

The propagation patterns are implemented for Linalg ops, but we can add propagation patterns for other ops as well. If possible, I'd like to consolidate all the propagation patterns in one place. We can still add those patterns, but they should be usable from the reshape propagation passes.