NVIDIA / Fuser

A Fusion Code Generator for NVIDIA GPUs (commonly known as "nvFuser")

Failure to find the right loop promotion with rfactor domains

naoyam opened this issue

Finding the right promotion of a loop group may fail when rfactor domains are involved.

  auto tv0 = makeConcreteTensor({5});
  fusion.addInput(tv0);
  auto tv1 = makeConcreteTensor({5, 2});
  fusion.addInput(tv1);

  auto tv2 = set(tv0);
  auto tv3 = broadcast(tv2, {false, true});
  auto tv4 = add(tv3, tv1);
  auto tv5 = reshape(tv4, {5, 2}, {10});
  fusion.addOutput(tv5);

  tv4->merge(0);
  tv3->merge(0);

  inlineMost();

TransformPrinter :
T0_g[ iS0{5} ]
 root domain : (iS0{5})
 contiguity: f
 leaf domain : (iS0{5})
T2_l[ iS3{5} ] ca_pos( 1 )
 root domain : (iS3{5})
 contiguity: t
 leaf domain : (iS3{5})
T3_l[ iS14{( 5 * 1 )} ] ca_pos( 1 ) produce_pos( 1 )
 root domain : (iS4{5}, bS5{1})
 contiguity: t n
  Merge: iS4{5} and bS5{1} -> iS14{( 5 * 1 )}
 leaf domain : (iS14{( 5 * 1 )})
T1_g[ iS1{5}, iS2{2} ]
 root domain : (iS1{5}, iS2{2})
 contiguity: f f
 leaf domain : (iS1{5}, iS2{2})
T4_l[ iS13{( 5 * 2 )} ] ca_pos( 1 ) produce_pos( 1 )
 root domain : (iS6{5}, iS7{2})
 contiguity: t t
  Merge: iS6{5} and iS7{2} -> iS13{( 5 * 2 )}
 leaf domain : (iS13{( 5 * 2 )})
T5_g[ iS12{( 5 * 2 )}rf ] ca_pos( 1 ) produce_pos( 1 )
 root domain : (iS10{5}rf, iS11{2}rf)
  Merge: iS10{5}rf and iS11{2}rf -> iS12{( 5 * 2 )}rf
 rfactor domain : (iS12{( 5 * 2 )}rf)
 contiguity: t
 leaf domain : (iS12{( 5 * 2 )}rf)
}

Here's what the transformation looks like:

[Figure: transform. The colors indicate exact groupings.]

There's only a single loop group in this case: idg{3 4 5 6 7 10 11 12 13 14}

Clearly, 12 is the right promotion domain, but the current logic fails to find it: when findPromotionOfLoopGroup examines that loop group, none of the domains is identified as a representative of all the domains in the group. The primary reason is that computeCoveredGroups stops its traversal at view rfactor domains, so the covered groups of 12 consist of just its own exact group, i.e., idg{12, 13}. However, we enumerate the covered exact groups of all the domains in the loop group. That includes domains such as 14, whose covered groups include idg{0, 1, 3, 4, 6, 10}. As a result, the true promotion domain, 12, fails to "cover" all the domains in this loop group.

The idea behind the handling of view rfactor domains in computeCoveredGroups is that a reshape transformation first squeezes all of the broadcast domains in the input tensor, so as long as a view rfactor domain is covered, all of the domains preceding the rfactor domain should also be covered. While that is true, it is not enough for findPromotionOfLoopGroup in a case like the one above. More specifically, the exact group of 12 does cover 14, but computeCoveredGroups doesn't convey that information to findPromotionOfLoopGroup.

A straightforward fix would be to not stop the analysis at view rfactor domains.
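
To make the failure concrete, here is a small standalone C++ sketch of the coverage check for the first example. It is a toy model, not nvFuser code: ExactGroup, computeCovered, findPromotion, and firstExampleGroups are hypothetical names, and the exact groups other than idg{12, 13} and idg{0, 1, 3, 4, 6, 10} (that is, idg{2, 7, 11}, idg{5}, and idg{14}) are assumed from the printed domains. The sketch also assumes that broadcast groups contribute nothing to coverage and that a promotion candidate must cover the union of the covered groups of every group in the loop group.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <set>
#include <string>
#include <vector>

// Toy stand-in for an exact group of iteration domains.
// Hypothetical type for illustration only, not an nvFuser class.
struct ExactGroup {
  std::string name;         // label such as "idg{12, 13}"
  std::vector<int> inputs;  // indices of the groups this one is derived from
  bool is_broadcast;        // assumed to contribute no coverage
  bool is_view_rfactor;     // contains a view rfactor domain
};

using Covered = std::set<int>;

// Computes, for each group, the set of input groups it covers.
// Groups must be listed in topological order (inputs before outputs).
std::vector<Covered> computeCovered(
    const std::vector<ExactGroup>& groups,
    bool stop_at_view_rfactor) {
  std::vector<Covered> covered(groups.size());
  for (std::size_t i = 0; i < groups.size(); ++i) {
    const ExactGroup& g = groups[i];
    if (g.is_broadcast) {
      continue;  // assumption: broadcast groups cover nothing
    }
    if (g.inputs.empty() || (stop_at_view_rfactor && g.is_view_rfactor)) {
      // Input groups, and view rfactor groups when the traversal stops
      // there, cover only themselves.
      covered[i].insert(static_cast<int>(i));
      continue;
    }
    for (int in : g.inputs) {
      assert(in >= 0 && static_cast<std::size_t>(in) < i);
      covered[i].insert(covered[in].begin(), covered[in].end());
    }
  }
  return covered;
}

// Returns the index of the first group whose covered set contains the union
// of all covered sets, or -1 if no such group exists.
int findPromotion(const std::vector<Covered>& covered) {
  Covered all;
  for (const Covered& c : covered) {
    all.insert(c.begin(), c.end());
  }
  for (std::size_t i = 0; i < covered.size(); ++i) {
    if (std::includes(
            covered[i].begin(), covered[i].end(), all.begin(), all.end())) {
      return static_cast<int>(i);
    }
  }
  return -1;
}

// Exact groups of the first example. Only idg{12, 13} and
// idg{0, 1, 3, 4, 6, 10} are stated in the text; the rest are assumed.
std::vector<ExactGroup> firstExampleGroups() {
  return {
      {"idg{0, 1, 3, 4, 6, 10}", {}, false, false},  // index 0
      {"idg{5}", {}, true, false},                   // index 1, broadcast
      {"idg{2, 7, 11}", {}, false, false},           // index 2
      {"idg{12, 13}", {0, 2}, false, true},          // index 3, view rfactor
      {"idg{14}", {0, 1}, false, false},             // index 4
  };
}
```

Under these assumptions, findPromotion(computeCovered(firstExampleGroups(), true)) returns -1: idg{12, 13} covers only itself, while the union of covered groups also contains idg{0, 1, 3, 4, 6, 10} and idg{2, 7, 11}. With the flag set to false, the same call returns 3, the index of idg{12, 13}, which matches the expected promotion to 12.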

Another example:

  auto tv0 = makeConcreteTensor({5});
  fusion.addInput(tv0);
  auto tv1 = makeConcreteTensor({5, 2});
  fusion.addInput(tv1);
  auto tv2 = makeConcreteTensor({10, 3});
  fusion.addInput(tv2);

  auto tv3 = set(tv0);
  auto tv4 = broadcast(tv3, {false, true});
  auto tv5 = add(tv4, tv1);
  auto tv6 = reshape(tv5, {5, 2}, {10});
  auto tv7 = broadcast(tv6, {false, true});
  auto tv8 = add(tv7, tv2);
  fusion.addOutput(tv8);

  tv4->merge(0);
  tv5->merge(0);
  tv8->merge(0);
  tv7->merge(0);

  inlineMost();

TransformPrinter :
T0_g[ iS0{5} ]
 root domain : (iS0{5})
 contiguity: f
 leaf domain : (iS0{5})
T3_l[ iS5{5} ] ca_pos( 1 )
 root domain : (iS5{5})
 contiguity: t
 leaf domain : (iS5{5})
T4_l[ iS19{( 5 * 1 )} ] ca_pos( 1 ) produce_pos( 1 )
 root domain : (iS6{5}, bS7{1})
 contiguity: t n
  Merge: iS6{5} and bS7{1} -> iS19{( 5 * 1 )}
 leaf domain : (iS19{( 5 * 1 )})
T1_g[ iS1{5}, iS2{2} ]
 root domain : (iS1{5}, iS2{2})
 contiguity: f f
 leaf domain : (iS1{5}, iS2{2})
T5_l[ iS20{( 5 * 2 )} ] ca_pos( 1 ) produce_pos( 1 )
 root domain : (iS8{5}, iS9{2})
 contiguity: t t
  Merge: iS8{5} and iS9{2} -> iS20{( 5 * 2 )}
 leaf domain : (iS20{( 5 * 2 )})
T6_l[ iS14{( 5 * 2 )}rf ] ca_pos( 1 ) produce_pos( 1 )
 root domain : (iS12{5}rf, iS13{2}rf)
  Merge: iS12{5}rf and iS13{2}rf -> iS14{( 5 * 2 )}rf
 rfactor domain : (iS14{( 5 * 2 )}rf)
 contiguity: t
 leaf domain : (iS14{( 5 * 2 )}rf)
T7_l[ iS22{( ( 5 * 2 ) * 1 )} ] ca_pos( 1 ) produce_pos( 1 )
 root domain : (iS15{( 5 * 2 )}, bS16{1})
 contiguity: t n
  Merge: iS15{( 5 * 2 )} and bS16{1} -> iS22{( ( 5 * 2 ) * 1 )}
 leaf domain : (iS22{( ( 5 * 2 ) * 1 )})
T2_g[ iS3{10}, iS4{3} ]
 root domain : (iS3{10}, iS4{3})
 contiguity: f f
 leaf domain : (iS3{10}, iS4{3})
T8_g[ iS21{( 10 * 3 )} ] ca_pos( 1 ) produce_pos( 1 )
 root domain : (iS17{10}, iS18{3})
 contiguity: t t
  Merge: iS17{10} and iS18{3} -> iS21{( 10 * 3 )}
 leaf domain : (iS21{( 10 * 3 )})
}

[Figure: transform. The colors indicate exact groupings.]

Again, all domains are mapped together, creating a single loop group: idg{5 6 7 8 9 12 13 14 15 16 17 18 19 20 21 22}. The correct promotion should be 21. However, due to the same issue as in the first example, no domain is identified as the promotion of this group.