Support MLIR
tucan9389 opened this issue · comments
TODO
- Graph Data Structure
- Tokenizer
- Parser
- Setting up mlir.Model with parsed object
- Support basic types (f32, f16, i16, etc.)
- Operation successor-list (by skipping)
- Operation dictionary-properties (by skipping)
- Operation region-list (by skipping)
- Operation dictionary-attribute
- Support tokens starting with the `#` keyword #1 (comment)
- Support tokens starting with the `!` keyword #1 (comment)
- Support basic blocks (starting with `^` at the operation level) #1 (comment)
- Test with TensorFlow's MLIR files (1.55k files)
- Test with all MLIR files on GitHub (58.9k files)
Runnable ratio with `mlir.Parser` and `mlir.Tokenizer`
(updated 23.05.10)
repo name | tensorflow | mlir-hlo | openxla/iree | llvm-project | llvm/circt |
---|---|---|---|---|---|
# of mlir files | 1.5k | 4.4k | 1.1k | 1.3k | 500 |
success ratio | 0.523 | 0.049 | 0.088 | 0.106 | 0.044 |
netron term vs. MLIR term

| netron term | MLIR term | Graph term | Note |
|---|---|---|---|
| `mlir.Graph` | Region or Function | Graph or Model | |
| `mlir.Parameter` | Value (Operation's Operand or Result) | Edge | |
| `mlir.Argument` | Value (Operation's Operand or Result) | Edge | |
| `mlir.Node` | Operation | Node | |
| `mlir.Attributes` | Operation's Attribute | Node's metadata | |
| `mlir.Tensor` | Constant Operation | Kind of node | |
| `mlir.TensorType` | | Edge's metadata | |
| `mlir.TensorShape` | | Edge's metadata | |
MLIR graph structure
netron graph structure
Test the `mlir.Parser` and `mlir.Tokenizer`
Add the following code at the bottom of `source/mlir.js` and run `node source/mlir.js`.
const input = `
func.func @main(%arg0: f32) -> f32 {
%result:2 = "foo_div"() : () -> (f32, i32)
// Pretty form that defines a unique name for each result.
%foo, %bar = "foo_div"() : () -> (f32, i32)
// Invoke a TensorFlow function called tf.scramble with two inputs
// and an attribute "fruit" stored in properties.
%2 = "tf.scramble"(%result#0, %bar) <{fruit = "banana"}> : (f32, i32) -> f32
// Invoke an operation with some discardable attributes
%foo2, %bar2 = "foo_div"() {some_attr = "value", other_attr = 42 : i64} : () -> (f32, i32)
return %bar2: f32
}
`
const decoder = text.Decoder.open(input);
const parser = new mlir.Parser(decoder);
const obj = parser.read();
console.log(JSON.stringify(obj, null, 2));
cc @chococigar
Example MLIR (stablehlo_sample.mlir)
https://github.com/openxla/stablehlo/blob/main/docs/spec.md
stablehlo.func @main(
%image: tensor<28x28xf32>,
%weights: tensor<784x10xf32>,
%bias: tensor<1x10xf32>
) -> tensor<1x10xf32> {
%0 = "stablehlo.reshape"(%image) : (tensor<28x28xf32>) -> tensor<1x784xf32>
%1 = "stablehlo.dot"(%0, %weights) : (tensor<1x784xf32>, tensor<784x10xf32>) -> tensor<1x10xf32>
%2 = "stablehlo.add"(%1, %bias) : (tensor<1x10xf32>, tensor<1x10xf32>) -> tensor<1x10xf32>
%3 = "stablehlo.constant"() { value = dense<0.0> : tensor<1x10xf32> } : () -> tensor<1x10xf32>
%4 = "stablehlo.maximum"(%2, %3) : (tensor<1x10xf32>, tensor<1x10xf32>) -> tensor<1x10xf32>
"stablehlo.return"(%4): (tensor<1x10xf32>) -> ()
}
Examples
- input: the content of an input file with the `.mlir` extension
- intermediate output: the parsed object from which we can build `mlir.Graph`, `mlir.Node`, etc. to visualize the graph
issue_1043.mlir (onnx dialect, 7 lines)
- input: https://gist.github.com/tucan9389/942268b6131152acc213d55ab50015d8
- intermediate output: https://gist.github.com/tucan9389/ad3d7d9e5312b4477bf6e149e2e91e30
stablehlo_sample.mlir (stablehlo dialect, 14 lines)
- input: https://gist.github.com/tucan9389/f2204a2291de89c30f084749bef860b3
- intermediate output: https://gist.github.com/tucan9389/86a078580476fd8e3f5b8200cc1fe589
examples.mnist_xla.mlir (tf dialect, mhlo dialect, 208 lines)
Test Inputs for mlir.Tokenizer
module attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, producer = 440 : i32}, tf_saved_model.semantics} {
"tf_saved_model.global_tensor"() {is_mutable, sym_name = "__sm_node4__optimizer.iter", tf_saved_model.exported_names = [], type = tensor<i64>, value = dense<0> : tensor<i64>} : () -> ()
"tf_saved_model.global_tensor"() {sym_name = "__sm_node6__optimizer.learning_rate", tf_saved_model.exported_names = [], type = tensor<f32>, value = dense<0.00999999977> : tensor<f32>} : () -> ()
func @__inference_predict_3320(%arg0: tensor<32x28x28x1xf32> {tf._user_specified_name = "inputs", tf_saved_model.index_path = [0]}, %arg1: tensor<32x1xf32> {tf._user_specified_name = "targets", tf_saved_model.index_path = [1]}, %arg2: tensor<!tf.resource<tensor<5x5x1x32xf32>>> {tf_saved_model.bound_input = @__sm_node17__model.conv1.kernel}, %arg3: tensor<!tf.resource<tensor<5x5x32x32xf32>>> {tf_saved_model.bound_input = @__sm_node26__model.conv2.kernel}, %arg4: tensor<!tf.resource<tensor<1568x1024xf32>>> {tf_saved_model.bound_input = @__sm_node39__model.dense1.kernel}, %arg5: tensor<!tf.resource<tensor<1024xf32>>> {tf_saved_model.bound_input = @__sm_node40__model.dense1.bias}, %arg6: tensor<!tf.resource<tensor<1024x10xf32>>> {tf_saved_model.bound_input = @__sm_node49__model.dense2.kernel}, %arg7: tensor<!tf.resource<tensor<10xf32>>> {tf_saved_model.bound_input = @__sm_node50__model.dense2.bias}, %arg8: tensor<!tf.resource<tensor<f32>>> {tf_saved_model.bound_input = @__sm_node6__optimizer.learning_rate}, %arg9: tensor<!tf.resource<tensor<i64>>> {tf_saved_model.bound_input = @__sm_node4__optimizer.iter}) -> (tensor<f32> {tf_saved_model.index_path = []}) attributes {tf._input_shapes = [#tf.shape<32x28x28x1>, #tf.shape<32x1>, #tf.shape<*>, #tf.shape<*>, #tf.shape<*>, #tf.shape<*>, #tf.shape<*>, #tf.shape<*>, #tf.shape<*>, #tf.shape<*>, #tf.shape<*>], tf.signature.is_stateful, tf_saved_model.exported_names = ["predict"]} {
%0 = mhlo.constant dense<3.125000e-02> : tensor<32x10xf32>
%0 = mhlo.constant dense<3.125000e-02> : tensor<32x10xf32>
%1 = mhlo.constant dense<3.200000e+01> : tensor<f32>
%2 = mhlo.constant dense<1> : tensor<i64>
%14 = "tf.Cast"(%arg2) {Truncate = false} : (tensor<!tf.resource<tensor<5x5x1x32xf32>>>) -> tensor<!tf.resource>
%27 = "mhlo.convolution"(%arg0, %22) {batch_group_count = 1 : i64, dimension_numbers = {input_batch_dimension = 0 : i64, input_feature_dimension = 3 : i64, input_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>, kernel_input_feature_dimension = 2 : i64, kernel_output_feature_dimension = 3 : i64, kernel_spatial_dimensions = dense<[0, 1]> : tensor<2xi64>, output_batch_dimension = 0 : i64, output_feature_dimension = 3 : i64, output_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>}, feature_group_count = 1 : i64, padding = dense<2> : tensor<2x2xi64>, rhs_dilation = dense<1> : tensor<2xi64>, window_strides = dense<1> : tensor<2xi64>} : (tensor<32x28x28x1xf32>, tensor<5x5x1x32xf32>) -> tensor<32x28x28x32xf32>
%28 = mhlo.maximum %27, %11 : tensor<32x28x28x32xf32>
%29 = "mhlo.reduce_window"(%28, %8) ( {
^bb0(%arg10: tensor<f32>, %arg11: tensor<f32>): // no predecessors
%130 = mhlo.maximum %arg10, %arg11 : tensor<f32>
"mhlo.return"(%130) : (tensor<f32>) -> ()
}) {padding = dense<0> : tensor<4x2xi64>, window_dimensions = dense<[1, 2, 2, 1]> : tensor<4xi64>, window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64>} : (tensor<32x28x28x32xf32>, tensor<f32>) -> tensor<32x14x14x32xf32>
%30 = "mhlo.convolution"(%29, %21) {batch_group_count = 1 : i64, dimension_numbers = {input_batch_dimension = 0 : i64, input_feature_dimension = 3 : i64, input_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>, kernel_input_feature_dimension = 2 : i64, kernel_output_feature_dimension = 3 : i64, kernel_spatial_dimensions = dense<[0, 1]> : tensor<2xi64>, output_batch_dimension = 0 : i64, output_feature_dimension = 3 : i64, output_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>}, feature_group_count = 1 : i64, padding = dense<2> : tensor<2x2xi64>, rhs_dilation = dense<1> : tensor<2xi64>, window_strides = dense<1> : tensor<2xi64>} : (tensor<32x14x14x32xf32>, tensor<5x5x32x32xf32>) -> tensor<32x14x14x32xf32>
"tf.AssignVariableOp"(%20, %126) : (tensor<!tf.resource>, tensor<*xi64>) -> ()
%127 = "mhlo.reduce"(%68, %12) ( {
^bb0(%arg10: tensor<f32>, %arg11: tensor<f32>): // no predecessors
%130 = mhlo.add %arg10, %arg11 : tensor<f32>
"mhlo.return"(%130) : (tensor<f32>) -> ()
}) {dimensions = dense<0> : tensor<1xi64>} : (tensor<32xf32>, tensor<f32>) -> tensor<f32>
%128 = mhlo.divide %127, %1 : tensor<f32>
%129 = "mhlo.select"(%13, %12, %128) : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
return %129 : tensor<f32>
}
}
// Compute A*B using an implementation of multiply kernel and print the
// result using a TensorFlow op. The dimensions of A and B are partially
// known. The shapes are assumed to match.
func.func @mul(%A: tensor<100x?xf32>, %B: tensor<?x50xf32>) -> (tensor<100x50xf32>) {
// Compute the inner dimension of %A using the dim operation.
%n = memref.dim %A, 1 : tensor<100x?xf32>
// Allocate addressable "buffers" and copy tensors %A and %B into them.
%A_m = memref.alloc(%n) : memref<100x?xf32>
memref.tensor_store %A to %A_m : memref<100x?xf32>
%B_m = memref.alloc(%n) : memref<?x50xf32>
memref.tensor_store %B to %B_m : memref<?x50xf32:<>>
// Call function @multiply passing memrefs as arguments,
// and getting returned the result of the multiplication.
%C_m = call @multiply(%A_m, %B_m)
: (memref<100x?xf32>, memref<?x50xf32>) -> (memref<100x50xf32>)
memref.dealloc %A_m : memref<100x?xf32>
memref.dealloc %B_m : memref<?x50xf32>
// Load the buffer data into a higher level "tensor" value.
%C = memref.tensor_load %C_m : memref<100x50xf32>
memref.dealloc %C_m : memref<100x50xf32>
// Call TensorFlow built-in function to print the result tensor.
"tf.Print"(%C){message: "mul result"} : (tensor<100x50xf32>) -> (tensor<100x50xf32>)
return %C : tensor<100x50xf32>
}
module {
func @my_function(%arg0: memref<32x32xf32>) attributes {affine_map_attr = affine_map<(d0, d1) -> (d0 + d1, d0 - d1)>} {
// function body
}
func @my_function() attributes {array_attr = [1, 2, 3]} {
// function body
}
func @my_function() attributes {nested_attr = {inner_attr1 = true, inner_attr2 = 3.14 : f64}} {
// function body
}
func @my_function(%arg0: f32, %arg1: f32) -> f32 attributes {attr1 = "value", attr2 = 42 : i32, attr3 = dense<[1, 2, 3]> : tensor<3xi32>} {
// function body
}
stablehlo.func @main(%image: tensor<28x28xf32>, %weights: tensor<784x10xf32>,%bias: tensor<1x10xf32>) -> tensor<1x10xf32> {
%99 = "mhlo.broadcast_in_dim"(%41) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f32>) -> tensor<5x5x32x32xf32>
%98 = "mhlo.broadcast_in_dim"(%41) : (tensor<f32>) -> tensor<5x5x32x32xf32> {broadcast_dimensions = dense<> : tensor<0xi64>}
"stablehlo.return"(%1): (tensor<1x10xf32>) -> ()
}
func @example(%arg1: i32, %arg2: i32) -> (i32) { }
func @example(%arg1: i32, %arg2: i32) -> (i32, i32) { }
func @example(%arg1: i32, %arg2: i32) -> i32 { }
}
module attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, producer = 440 : i32}, tf_saved_model.semantics} {
"tf_saved_model.global_tensor"() {is_mutable, sym_name = "__sm_node4__optimizer.iter", tf_saved_model.exported_names = [], type = tensor<i64>, value = dense<0> : tensor<i64>} : () -> ()
"tf_saved_model.global_tensor"() {sym_name = "__sm_node6__optimizer.learning_rate", tf_saved_model.exported_names = [], type = tensor<f32>, value = dense<0.00999999977> : tensor<f32>} : () -> ()
func @__inference_predict_3320(%arg0: tensor<32x28x28x1xf32> {tf._user_specified_name = "inputs", tf_saved_model.index_path = [0]}, %arg1: tensor<32x1xf32> {tf._user_specified_name = "targets", tf_saved_model.index_path = [1]}, %arg2: tensor<!tf.resource<tensor<5x5x1x32xf32>>> {tf_saved_model.bound_input = @__sm_node17__model.conv1.kernel}, %arg3: tensor<!tf.resource<tensor<5x5x32x32xf32>>> {tf_saved_model.bound_input = @__sm_node26__model.conv2.kernel}, %arg4: tensor<!tf.resource<tensor<1568x1024xf32>>> {tf_saved_model.bound_input = @__sm_node39__model.dense1.kernel}, %arg5: tensor<!tf.resource<tensor<1024xf32>>> {tf_saved_model.bound_input = @__sm_node40__model.dense1.bias}, %arg6: tensor<!tf.resource<tensor<1024x10xf32>>> {tf_saved_model.bound_input = @__sm_node49__model.dense2.kernel}, %arg7: tensor<!tf.resource<tensor<10xf32>>> {tf_saved_model.bound_input = @__sm_node50__model.dense2.bias}, %arg8: tensor<!tf.resource<tensor<f32>>> {tf_saved_model.bound_input = @__sm_node6__optimizer.learning_rate}, %arg9: tensor<!tf.resource<tensor<i64>>> {tf_saved_model.bound_input = @__sm_node4__optimizer.iter}) -> (tensor<f32> {tf_saved_model.index_path = []}) attributes {tf._input_shapes = [#tf.shape<32x28x28x1>, #tf.shape<32x1>, #tf.shape<*>, #tf.shape<*>, #tf.shape<*>, #tf.shape<*>, #tf.shape<*>, #tf.shape<*>, #tf.shape<*>, #tf.shape<*>, #tf.shape<*>], tf.signature.is_stateful, tf_saved_model.exported_names = ["predict"]} {
%0 = mhlo.constant dense<3.125000e-02> : tensor<32x10xf32>
%0 = mhlo.constant dense<3.125000e-02> : tensor<32x10xf32>
%1 = mhlo.constant dense<3.200000e+01> : tensor<f32>
%2 = mhlo.constant dense<1> : tensor<i64>
%14 = "tf.Cast"(%arg2) {Truncate = false} : (tensor<!tf.resource<tensor<5x5x1x32xf32>>>) -> tensor<!tf.resource>
%27 = "mhlo.convolution"(%arg0, %22) {batch_group_count = 1 : i64, dimension_numbers = {input_batch_dimension = 0 : i64, input_feature_dimension = 3 : i64, input_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>, kernel_input_feature_dimension = 2 : i64, kernel_output_feature_dimension = 3 : i64, kernel_spatial_dimensions = dense<[0, 1]> : tensor<2xi64>, output_batch_dimension = 0 : i64, output_feature_dimension = 3 : i64, output_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>}, feature_group_count = 1 : i64, padding = dense<2> : tensor<2x2xi64>, rhs_dilation = dense<1> : tensor<2xi64>, window_strides = dense<1> : tensor<2xi64>} : (tensor<32x28x28x1xf32>, tensor<5x5x1x32xf32>) -> tensor<32x28x28x32xf32>
%28 = mhlo.maximum %27, %11 : tensor<32x28x28x32xf32>
%29 = "mhlo.reduce_window"(%28, %8) ( {
^bb0(%arg10: tensor<f32>, %arg11: tensor<f32>): // no predecessors
%130 = mhlo.maximum %arg10, %arg11 : tensor<f32>
"mhlo.return"(%130) : (tensor<f32>) -> ()
}) {padding = dense<0> : tensor<4x2xi64>, window_dimensions = dense<[1, 2, 2, 1]> : tensor<4xi64>, window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64>} : (tensor<32x28x28x32xf32>, tensor<f32>) -> tensor<32x14x14x32xf32>
%30 = "mhlo.convolution"(%29, %21) {batch_group_count = 1 : i64, dimension_numbers = {input_batch_dimension = 0 : i64, input_feature_dimension = 3 : i64, input_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>, kernel_input_feature_dimension = 2 : i64, kernel_output_feature_dimension = 3 : i64, kernel_spatial_dimensions = dense<[0, 1]> : tensor<2xi64>, output_batch_dimension = 0 : i64, output_feature_dimension = 3 : i64, output_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>}, feature_group_count = 1 : i64, padding = dense<2> : tensor<2x2xi64>, rhs_dilation = dense<1> : tensor<2xi64>, window_strides = dense<1> : tensor<2xi64>} : (tensor<32x14x14x32xf32>, tensor<5x5x32x32xf32>) -> tensor<32x14x14x32xf32>
"tf.AssignVariableOp"(%20, %126) : (tensor<!tf.resource>, tensor<*xi64>) -> ()
%127 = "mhlo.reduce"(%68, %12) ( {
^bb0(%arg10: tensor<f32>, %arg11: tensor<f32>): // no predecessors
%130 = mhlo.add %arg10, %arg11 : tensor<f32>
"mhlo.return"(%130) : (tensor<f32>) -> ()
}) {dimensions = dense<0> : tensor<1xi64>} : (tensor<32xf32>, tensor<f32>) -> tensor<f32>
%128 = mhlo.divide %127, %1 : tensor<f32>
%129 = "mhlo.select"(%13, %12, %128) : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
return %129 : tensor<f32>
}
}
no args, no parens operation
Done: be9a3df
// CHECK: use_two_sep_ops_empty_tensor_list
func.func @use_two_sep_ops_empty_tensor_list() {
%one = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
%elem_shape = "tf.Const"() {value = dense<[-1, 1]> : tensor<2xi32>} : () -> tensor<2xi32>
%size = "tf.Const"() {value = dense<10> : tensor<i32>} : () -> tensor<i32>
// CHECK: EmptyTensorList
// CHECK: EmptyTensorList
// CHECK: EmptyTensorList
%tl = "tf.EmptyTensorList"(%elem_shape, %size) : (tensor<2xi32>, tensor<i32>) -> tensor<!tf_type.variant<tensor<?x1xf32>>>
%elem_1 = "tf._FirstOp"() : () -> tensor<8x1xf32>
%tl_set_item = "tf.TensorListSetItem"(%tl, %one, %elem_1) : (tensor<!tf_type.variant<tensor<?x1xf32>>>, tensor<i32>, tensor<8x1xf32>) -> tensor<!tf_type.variant<tensor<?x1xf32>>>
%elem_2 = "tf._SecondOp"() : () -> tensor<16x1xf32>
%tl_set_item2 = "tf.TensorListSetItem"(%tl, %one, %elem_2) : (tensor<!tf_type.variant<tensor<?x1xf32>>>, tensor<i32>, tensor<16x1xf32>) -> tensor<!tf_type.variant<tensor<?x1xf32>>>
func.return
}
while
Done: 1e06958
- `:integer-literal` of op-result
- region-list: `op_name(args) ({...}, {...}) : function-type`
func.func @while_region_op_two_sep_args_empty_tensor_list() {
%one = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
%elem_shape = "tf.Const"() {value = dense<[-1, 1]> : tensor<2xi32>} : () -> tensor<2xi32>
%size = "tf.Const"() {value = dense<10> : tensor<i32>} : () -> tensor<i32>
%tl = "tf.EmptyTensorList"(%elem_shape, %size) : (tensor<2xi32>, tensor<i32>) -> tensor<!tf_type.variant<tensor<?x1xf32>>>
%while:2 = "tf.WhileRegion"(%tl, %tl) ({
^bb0(%barg1: tensor<!tf_type.variant<tensor<?x1xf32>>>, %barg2: tensor<!tf_type.variant<tensor<?x1xf32>>>): // no predeceessors
%cond = "tf.false"():()-> tensor<i1>
"tf.Yield"(%cond) : (tensor<i1>) -> ()
}, {
^bb0(%barg1: tensor<!tf_type.variant<tensor<?x1xf32>>>, %barg2: tensor<!tf_type.variant<tensor<?x1xf32>>>): // no predeceessors
"tf.Yield"(%barg1, %barg2) : (tensor<!tf_type.variant<tensor<?x1xf32>>>, tensor<!tf_type.variant<tensor<?x1xf32>>>) -> ()
}) {is_stateless = false} : (tensor<!tf_type.variant<tensor<?x1xf32>>>, tensor<!tf_type.variant<tensor<?x1xf32>>>) -> (tensor<!tf_type.variant<tensor<?x1xf32>>>, tensor<!tf_type.variant<tensor<?x1xf32>>>)
func.return
}
Done: 1e06958
// https://github.com/tensorflow/tensorflow/blob/2d9d9c053c51ae417c9abfd4e0ea903e447d6952/tensorflow/compiler/mlir/tensorflow/tests/hoist_loop_invariant.mlir#L19
func.func @hoist_loop_invariant(%arg0: tensor<i32>, %arg1: tensor<i32>) -> (tensor<i32>, tensor<i32>) {
%cst_0 = "tf.Const"() { value = dense<1> : tensor<i32> } : () -> tensor<i32>
%0:2 = "tf.WhileRegion"(%arg0, %arg1) ({
^bb0(%arg2: tensor<i32>, %arg3: tensor<i32>):
%1 = "tf.OpA"() {is_stateless = true} : () -> tensor<i1>
"tf.Yield"(%1) : (tensor<i1>) -> ()
}, {
^bb0(%arg2: tensor<i32>, %arg3: tensor<i32>):
%cst_1 = "tf.Const"() { value = dense<0> : tensor<i32> } : () -> tensor<i32>
%1 = "tf.Add"(%cst_1, %cst_0) : (tensor<i32>, tensor<i32>) -> tensor<i32>
%2 = "tf.Mul"(%1, %cst_1) : (tensor<i32>, tensor<i32>) -> tensor<i32>
%3 = "tf.AddV2"(%arg2, %1) : (tensor<i32>, tensor<i32>) -> tensor<i32>
%4 = "tf.Div"(%arg3, %2) : (tensor<i32>, tensor<i32>) -> tensor<i32>
"tf.Yield"(%3, %4) : (tensor<i32>, tensor<i32>) -> ()
}) {is_stateless = true, parallel_iterations = 10 : i64} : (tensor<i32>, tensor<i32>) -> (tensor<i32>, tensor<i32>)
return %0#0, %0#1 : tensor<i32>, tensor<i32>
}
Sharp (`#`) keyword
- access to an individual result of an operation with multiple results (`%result#0`)
- Attribute Value Aliases
- Dialect Attribute Values
https://mlir.llvm.org/docs/LangRef/
Access to an individual result of an operation with multiple results (`%result#0`)
%2 = "tf.scramble"(%result#0, %bar) <{fruit = "banana"}> : (f32, i32) -> f32
https://mlir.llvm.org/docs/LangRef/
Attribute Value Aliases
#map = affine_map<(d0) -> (d0 + 10)>
// Using the original attribute.
%b = affine.apply affine_map<(d0) -> (d0 + 10)> (%a)
// Using the attribute alias.
%b = affine.apply #map(%a)
https://mlir.llvm.org/docs/LangRef/
Dialect Attribute Values
// A string attribute.
#foo<string<"">>
// A complex attribute.
#foo<"a123^^^" + bar>
// A string attribute.
#foo.string<"">
// CHECK-LABEL: func @skip_noncompiled_reduce_dataset
func.func @skip_noncompiled_reduce_dataset(
%arg0 : tensor<!tf_type.variant>,
%arg1: tensor<i64>
) {
// CHECK: tf.ReduceDataset
%1 = "tf.ReduceDataset"(%arg0, %arg1) {
Targuments = [],
Tstate = [i64], device = "",
f = @__reduce_func_0, f._tf_data_function = true,
output_shapes = [#tf_type.shape<>],
output_types = [i64], use_inter_op_parallelism = true } : (tensor<!tf_type.variant>, tensor<i64>) -> (tensor<i64>)
func.return
}
various op syntax in MLIR Ref doc
https://mlir.llvm.org/docs/LangRef/
Done: cc5daab
// An operation that produces two results.
// The results of %result can be accessed via the <name> `#` <opNo> syntax.
%result:2 = "foo_div"() : () -> (f32, i32)
// Pretty form that defines a unique name for each result.
%foo, %bar = "foo_div"() : () -> (f32, i32)
// Invoke a TensorFlow function called tf.scramble with two inputs
// and an attribute "fruit" stored in properties.
%2 = "tf.scramble"(%result#0, %bar) <{fruit = "banana"}> : (f32, i32) -> f32
// Invoke an operation with some discardable attributes
%foo, %bar = "foo_div"() {some_attr = "value", other_attr = 42 : i64} : () -> (f32, i32)
block syntax in MLIR Ref doc
https://mlir.llvm.org/docs/LangRef/
func.func @simple(i64, i1) -> i64 {
^bb0(%a: i64, %cond: i1): // Code dominated by ^bb0 may refer to %a
cf.cond_br %cond, ^bb1, ^bb2
^bb1:
cf.br ^bb3(%a: i64) // Branch passes %a as the argument
^bb2:
%b = arith.addi %a, %a : i64
cf.br ^bb3(%b: i64) // Branch passes %b as the argument
// ^bb3 receives an argument, named %c, from predecessors
// and passes it on to bb4 along with %a. %a is referenced
// directly from its defining operation and is not passed through
// an argument of ^bb3.
^bb3(%c: i64):
cf.br ^bb4(%c, %a : i64, i64)
^bb4(%d : i64, %e : i64):
%0 = arith.addi %d, %e : i64
return %0 : i64 // Return is also a terminator.
}
graph region in MLIR Ref doc
Done: cc5daab
"test.graph_region"() ({ // A Graph region
%1 = "op1"(%1, %3) : (i32, i32) -> (i32) // OK: %1, %3 allowed here
%2 = "test.ssacfg_region"() ({
%5 = "op2"(%1, %2, %3, %4) : (i32, i32, i32, i32) -> (i32) // OK: %1, %2, %3, %4 all defined in the containing region
}) : () -> (i32)
%3 = "op2"(%1, %4) : (i32, i32) -> (i32) // OK: %4 allowed here
%4 = "op3"(%1) : (i32) -> (i32)
}) : () -> ()
Exclamation mark (`!`)
https://mlir.llvm.org/docs/LangRef/#type-aliases
Type Aliases
!avx_m128 = vector<4 x f32>
// Using the original type.
"foo"(%x) : vector<4 x f32> -> ()
// Using the type alias.
"foo"(%x) : !avx_m128 -> ()
https://mlir.llvm.org/docs/LangRef/#dialect-types
Dialect Types
// A tensorflow string type.
!tf<string>
// A type with complex components.
!foo<something<abcd>>
// An even more complex type.
!foo<"a123^^^" + bar>
func.func @single_state_single_dataset_type_no_arguments(
%arg0: tensor<!tf_type.variant>,
%arg1: tensor<i64>
) {
%1 = "tf.ReduceDataset"(%arg0, %arg1) {
Targuments = [],
Tstate = [i64], device = "/job:localhost/replica:0/task:0/device:TPU:1",
f = @__reduce_func_1, f._tf_data_function = true,
output_shapes = [#tf_type.shape<>],
output_types = [i64], use_inter_op_parallelism = true, _xla_compile_device_type="TPU"} : (tensor<!tf_type.variant>, tensor<i64>) -> (tensor<i64>)
func.return
}