pad propagation and replay issues.
jjsjann123 opened this issue · comments
Mostly just an issue for myself to record progress while working on #1597:
On the preseg transformation branch, I'm applying a transformation that pushes `pad` ops up toward the fusion inputs and replaces the `CatOp` with a binary `add`.
Repro branch: #2373
Repro script:
import torch
from nvfuser import FusionDefinition, DataType


def nvfuser_fusion_id0(fd: FusionDefinition) -> None:
    """Build the repro fusion ``T4 = cat([T0, relu(T1)], dim=-1)``.

    Both inputs are defined as contiguous rank-3 float tensors with symbolic
    (-1) sizes. ``S2`` is only consumed by the commented-out ``mul`` variant
    of the repro; in the active ``relu`` variant it is defined but unused.
    """
    T0 = fd.define_tensor(shape=[-1, -1, -1], contiguity=[True, True, True], dtype=DataType.Float, is_cpu=False, stride_order=[2, 1, 0])
    T1 = fd.define_tensor(shape=[-1, -1, -1], contiguity=[True, True, True], dtype=DataType.Float, is_cpu=False, stride_order=[2, 1, 0])
    S2 = fd.define_scalar(0.297302, dtype=DataType.Double)
    # Alternative variant of the repro: scale T1 instead of applying relu.
    # T3 = fd.ops.mul(T1, S2)
    T3 = fd.ops.relu(T1)
    T4 = fd.ops.cat([T0, T3], -1)
    fd.add_output(T4)


with FusionDefinition() as fd:
    nvfuser_fusion_id0(fd)

# Concrete inputs of shape (2, 5, 10) and (2, 5, 3); the cat is along the
# last dimension.
inputs = [
    torch.randn((100,), dtype=torch.float32, device='cuda:0').as_strided((2, 5, 10), (50, 10, 1)),
    torch.randn((30,), dtype=torch.float32, device='cuda:0').as_strided((2, 5, 3), (15, 3, 1)),
]
fd.execute(inputs)
Here's the fusion IR before the transformation:
Inputs:
T0_g[ iS0{i0}, iS1{i1}, iS2{i2} ], float
T1_g[ iS3{i4}, iS4{i5}, iS5{i6} ], float
Outputs:
T5_g[ iS17{i0}, iS18{i1}, iS22{( i2 + i6 )} ], float
%kernel_math {
i15 = i2 + i6;
i17 = -i2;
i19 = i15 + i17;
T3_l[ iS9{i0}, iS10{i1}, iS20{( i2 + i6 )}rf ]
= pad( T0_g[ iS0{i0}, iS1{i1}, iS2{i2} ], {0, 0, 0, 0, 0, i19} )
T2_l[ iS6{i4}, iS7{i5}, iS8{i6} ]
= relu(T1_g[ iS3{i4}, iS4{i5}, iS5{i6} ]);
i21 = 0 + i2;
T4_l[ iS13{i4}, iS14{i5}, iS21{( i6 + ( 0 + i2 ) )}rf ]
= pad( T2_l[ iS6{i4}, iS7{i5}, iS8{i6} ], {0, 0, 0, 0, i21, 0} )
T5_g[ iS17{i0}, iS18{i1}, iS22{( i2 + i6 )} ]
= cat( T3_l[ iS9{i0}, iS10{i1}, iS20{( i2 + i6 )}rf ], T4_l[ iS13{i4}, iS14{i5}, iS21{( i6 + ( 0 + i2 ) )}rf ], 2 )
}
Here's the fusion IR after the transformation. I have:
- moved the `pad` to the beginning of the fusion;
- replayed `relu` on the padded input;
- replaced the `cat` with an `add`, and replaced the `cat` input that came from padding the old `relu` output with the new `relu` output.
Inputs:
T0_g[ iS0{i0}, iS1{i1}, iS2{i2} ], float
T1_g[ iS3{i4}, iS4{i5}, iS5{i6} ], float
Outputs:
T8_g[ iS9{i0}, iS10{i1}, iS20{( i2 + i6 )}rf ], float
%kernel_math {
i15 = i2 + i6;
i17 = -i2;
i19 = i15 + i17;
T3_l[ iS9{i0}, iS10{i1}, iS20{( i2 + i6 )}rf ]
= pad( T0_g[ iS0{i0}, iS1{i1}, iS2{i2} ], {0, 0, 0, 0, 0, i19} )
i21 = 0 + i2;
T6_l[ iS13{i4}, iS14{i5}, iS21{( i6 + ( 0 + i2 ) )}rf ]
= pad( T1_g[ iS3{i4}, iS4{i5}, iS5{i6} ], {0, 0, 0, 0, i21, 0} )
T7_l[ iS23{i4}, iS24{i5}, iS25{( i6 + ( 0 + i2 ) )} ]
= relu(T6_l[ iS13{i4}, iS14{i5}, iS21{( i6 + ( 0 + i2 ) )}rf ]);
T8_g[ iS9{i0}, iS10{i1}, iS20{( i2 + i6 )}rf ]
= T3_l[ iS9{i0}, iS10{i1}, iS20{( i2 + i6 )}rf ]
+ T7_l[ iS23{i4}, iS24{i5}, iS25{( i6 + ( 0 + i2 ) )} ];
}
I'm hitting an issue in `compute_at_map` here:
Traceback (most recent call last):
File "/opt/pytorch/nvfuser/nvfuser/__init__.py", line 145, in execute
result = self._execute(
RuntimeError: logical_id_uses.find(logical_inp_id) == logical_id_uses.end() INTERNAL ASSERT FAILED at "/opt/pytorch/nvfuser/csrc/compute_at_map.cpp":621, please report a bug with repro script to NVFuser at https://github.com/NVIDIA/Fuser/issues. Was expecting iter domains to only have one active transformation but found id iS11{i2}rf used in
Resize: iS11{i2}rf by 0 and ( ( i2 + i6 ) + ( -i2 ) ) -> iS20{( i2 + i6 )}rf
and
Resize: iS11{i2}rf by 0 and ( ( i2 + i6 ) + ( -i2 ) ) -> iS20{( i2 + i6 )}rf
It looks like we are seeing a redundant expression: the same `Resize` is reported twice as a use of `iS11{i2}rf`. Maybe something is wrong with the replay I was using.
More Context
I tried to verify that the transformation is legitimate. The C++ test below generates a similar fusion IR, and everything works fine with it.
// Hand-built C++ counterpart of the *transformed* Python repro. Both pads are
// applied directly to the fusion inputs, relu is applied to the padded tv1,
// and the former cat is expressed as an add of the two padded tensors.
// The local names i15/i17/i19/i21 intentionally mirror the scalar names in
// the printed fusion IR so that the two dumps can be compared side by side.
TEST_F(NVFuserTest, FusionPadDynamicShape) {
auto fusion = std::make_unique<Fusion>();
FusionGuard fg(fusion.get());
// Two contiguous rank-3 inputs with symbolic extents.
TensorView* tv0 = makeContigTensor(3);
fusion->addInput(tv0);
TensorView* tv1 = makeContigTensor(3);
fusion->addInput(tv1);
// i19 = (ext0 + ext1) + (-ext0), where ext0/ext1 are the last-axis extents
// of tv0/tv1. This is the right-pad width that grows tv0's last axis to the
// concatenated extent.
Val* i15 = add(tv0->axis(2)->extent(), tv1->axis(2)->extent());
Val* i17 = neg(tv0->axis(2)->extent());
Val* i19 = add(i15, i17);
Val* zero = IrBuilder::create<Val>(0);
// Pad widths here start at the innermost axis as {left, right} pairs: this
// call prints as pad(..., {0, 0, 0, 0, 0, i11}) in the segmented-fusion dump.
// NOTE(review): confirm the ordering against the pad() API documentation.
TensorView* tv3 = pad(tv0, {zero, i19, zero, zero, zero, zero});
// Left-pad tv1's last axis by ext0 (written as 0 + ext0 to mirror the IR's
// `i21 = 0 + i2`), so its data lands after tv0's.
Val* i21 = add(zero, tv0->axis(2)->extent());
TensorView* tv6 = pad(tv1, {i21, zero, zero, zero, zero, zero});
TensorView* tv7 = relu(tv6);
// Stand-in for the cat. NOTE(review): this assumes pad fills with zero and
// relu(0) == 0, so the padded regions do not overlap in the sum; confirm
// pad's default fill value.
TensorView* tv8 = add(tv3, tv7);
fusion->addOutput(tv8);
// Dump the fusion math for comparison with the Python-generated IR.
fusion->printMath();
auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
// Concrete shapes match the Python repro inputs: (2, 5, 10) and (2, 5, 3).
at::Tensor t0 = at::randn({2, 5, 10}, options);
at::Tensor t1 = at::randn({2, 5, 3}, options);
std::vector<c10::IValue> aten_inputs({t0, t1});
// Ownership of the fusion moves into the cache; use executor_cache.fusion()
// from here on instead of the moved-from `fusion`.
FusionExecutorCache executor_cache(std::move(fusion));
auto cg_outputs = executor_cache.runFusionWithInputs(aten_inputs);
testValidate(
executor_cache.fusion(), cg_outputs, {t0, t1}, __LINE__, __FILE__);
}
Running it with `NVFUSER_DUMP=segmented_fusion` gives:
Segmented_Fusion Dump: -- Re-written complete fusion:{
Inputs:
T0_g[ iS0{i0}, iS1{i2}, iS2{i3} ], float
T1_g[ iS24{i0}, iS25{i2}, iS5{i6} ], float
Outputs:
T5_g[ iS17{i0}, iS18{i2}, iS23{( i3 + i6 )} ], float
%kernel_math {
i7 = i3 + i6;
i9 = -i3;
i11 = i7 + i9;
T2_l[ iS6{i0}, iS7{i2}, iS20{( i3 + i6 )}rf ]
= pad( T0_g[ iS0{i0}, iS1{i2}, iS2{i3} ], {0, 0, 0, 0, 0, i11} )
i20 = 0 + i3;
T3_l[ iS26{i0}, iS27{i2}, iS21{( i6 + ( 0 + i3 ) )}rf ]
= pad( T1_g[ iS24{i0}, iS25{i2}, iS5{i6} ], {0, 0, 0, 0, i20, 0} )
T4_l[ iS28{i0}, iS29{i2}, iS22{( i6 + ( 0 + i3 ) )} ]
= relu(T3_l[ iS26{i0}, iS27{i2}, iS21{( i6 + ( 0 + i3 ) )}rf ]);
T5_g[ iS17{i0}, iS18{i2}, iS23{( i3 + i6 )} ]
= T2_l[ iS6{i0}, iS7{i2}, iS20{( i3 + i6 )}rf ]
+ T4_l[ iS28{i0}, iS29{i2}, iS22{( i6 + ( 0 + i3 ) )} ];
}
} // {Re-written complete fusion}
The `compute_at` issue looks like it comes from IdModel. I'll get a C++ repro and open a separate issue for it. For the time being, I'll just catch the exception and brute-force the prototype first.
Thread 1 "pt_main_thread" hit Catchpoint 1 (exception thrown), 0x00007ffff77fe4a1 in __cxa_throw () from /lib/x86_64-linux-gnu/libstdc++.so.6
(gdb) bt
#0 0x00007ffff77fe4a1 in __cxa_throw () from /lib/x86_64-linux-gnu/libstdc++.so.6
#1 0x00007ffdb76b86f2 in nvfuser::nvfCheckFail (func=0x7ffdb7eef705 "build", file=0x7ffdb7eef048 "/opt/pytorch/nvfuser/csrc/compute_at_map.cpp", line=621,
msg="logical_id_uses.find(logical_inp_id) == logical_id_uses.end() INTERNAL ASSERT FAILED at \"/opt/pytorch/nvfuser/csrc/compute_at_map.cpp\":621, please report a bug with repro script to NVFuser at https://
"...) at /opt/pytorch/nvfuser/csrc/exceptions.cpp:274
#2 0x00007ffdb76b8907 in nvfuser::nvfErrorFail (func=0x7ffdb7eef705 "build", file=0x7ffdb7eef048 "/opt/pytorch/nvfuser/csrc/compute_at_map.cpp", line=621,
condMsg=0x7ffdb7eef8e8 "logical_id_uses.find(logical_inp_id) == logical_id_uses.end() INTERNAL ASSERT FAILED at \"/opt/pytorch/nvfuser/csrc/compute_at_map.cpp\":621, please report a bug with repro script to
NVFuser at https://"...,
userMsg="Was expecting iter domains to only have one active transformation but found id iS11{i2}rf used in\nResize: iS11{i2}rf by 0 and ( ( i2 + i6 ) + ( -i2 ) ) -> iS20{( i2 + i6 )}rf\n\nand\nResize: iS11{i
2}rf b"...) at /opt/pytorch/nvfuser/csrc/exceptions.cpp:300
#3 0x00007ffdb73bf4dc in nvfuser::IterDomainGraph::build (this=0x7fffffffbbf0, fusion=0x7ffda8774e80) at /opt/pytorch/nvfuser/csrc/compute_at_map.cpp:621
#4 0x00007ffdb73bbb66 in nvfuser::IterDomainGraph::IterDomainGraph (this=0x7fffffffbbf0, fusion=0x7ffda8774e80, allow_self_mapping=true) at /opt/pytorch/nvfuser/csrc/compute_at_map.cpp:45
#5 0x00007ffdb7cc5b69 in nvfuser::(anonymous namespace)::checkCanSchedule<nvfuser::ExprEvalScheduler> (fusion=0x7ffda8774e80, runtime_info=..., data_cache=0x0)
at /opt/pytorch/nvfuser/csrc/scheduler/registry.cpp:185
#6 0x00007ffdb7cc37e0 in nvfuser::SchedulerEntry::canSchedule (sh=nvfuser::ScheduleHeuristic::ExprEval, fusion=0x7ffda8774e80, runtime_info=..., data_cache=0x0)
at /opt/pytorch/nvfuser/csrc/scheduler/registry.cpp:231
#7 0x00007ffdb7cc3c80 in nvfuser::SchedulerEntry::proposeHeuristics (fusion=0x7ffda8774e80, runtime_info=...) at /opt/pytorch/nvfuser/csrc/scheduler/registry.cpp:295
(gdb) frame 5
#5 0x00007ffdb7cc5b69 in nvfuser::(anonymous namespace)::checkCanSchedule<nvfuser::ExprEvalScheduler> (fusion=0x7ffda8774e80, runtime_info=..., data_cache=0x0)
at /opt/pytorch/nvfuser/csrc/scheduler/registry.cpp:185
185 if (IterDomainGraph(fusion, /*allow_self_mapping=*/true).hasSelfMapping()) {
(gdb) print fusion->printMath(0)
Inputs:
T0_g[ iS0{i0}, iS1{i1}, iS2{i2} ], float
T1_g[ iS26{i0}, iS27{i1}, iS5{i6} ], float
Outputs:
T8_g[ iS9{i0}, iS10{i1}, iS20{( i2 + i6 )}rf ], float
%kernel_math {
T2_l[ iS6{i4}, iS7{i5}, iS8{i6} ]
= relu(T1_g[ iS26{i0}, iS27{i1}, iS5{i6} ]);
i21 = 0 + i2;
i29 = i21 + i6;
i37 = i6 + i21;
T4_l[ iS13{i4}, iS14{i5}, iS21{( i6 + ( 0 + i2 ) )}rf ]
= pad( T2_l[ iS6{i4}, iS7{i5}, iS8{i6} ], {0, 0, 0, 0, i21, 0} )
i15 = i2 + i6;
i17 = -i2;
i19 = i15 + i17;
T3_l[ iS9{i0}, iS10{i1}, iS20{( i2 + i6 )}rf ]
= pad( T0_g[ iS0{i0}, iS1{i1}, iS2{i2} ], {0, 0, 0, 0, 0, i19} )
T5_l[ iS17{i0}, iS18{i1}, iS22{( i2 + i6 )} ]
= cat( T3_l[ iS9{i0}, iS10{i1}, iS20{( i2 + i6 )}rf ], T4_l[ iS13{i4}, iS14{i5}, iS21{( i6 + ( 0 + i2 ) )}rf ], 2 )
Resize: iS11{i2}rf by 0 and ( ( i2 + i6 ) + ( -i2 ) ) -> iS20{( i2 + i6 )}rf
i47 = i6 + i21;
Resize: iS15{i6}rf by ( 0 + i2 ) and 0 -> iS21{( i6 + ( 0 + i2 ) )}rf
b49 = blockIdx.x >= 0;
b51 = gridDim.x > 0;
b53 = blockIdx.x < gridDim.x;
b55 = blockIdx.y >= 0;
b57 = gridDim.y > 0;
b59 = blockIdx.y < gridDim.y;
b61 = blockIdx.z >= 0;
b63 = gridDim.z > 0;
b65 = blockIdx.z < gridDim.z;
b67 = threadIdx.x >= 0;
b69 = blockDim.x > 0;
b71 = threadIdx.x < blockDim.x;
b73 = threadIdx.y >= 0;
b75 = blockDim.y > 0;
b77 = threadIdx.y < blockDim.y;
b79 = threadIdx.z >= 0;
b81 = blockDim.z > 0;
b83 = threadIdx.z < blockDim.z;
b85 = i0 > 0;
b87 = i1 > 0;
b89 = i2 > 0;
b91 = i0 > 0;
b93 = i1 > 0;
b95 = i6 > 0;
T6_l[ iS28{i0}, iS29{i1}, iS21{( i6 + ( 0 + i2 ) )}rf ]
= pad( T1_g[ iS26{i0}, iS27{i1}, iS5{i6} ], {0, 0, 0, 0, i21, 0} )
T7_l[ iS30{i0}, iS31{i1}, iS25{( i6 + ( 0 + i2 ) )} ]
= relu(T6_l[ iS28{i0}, iS29{i1}, iS21{( i6 + ( 0 + i2 ) )}rf ]);
T8_g[ iS9{i0}, iS10{i1}, iS20{( i2 + i6 )}rf ]
= T3_l[ iS9{i0}, iS10{i1}, iS20{( i2 + i6 )}rf ]
+ T7_l[ iS30{i0}, iS31{i1}, iS25{( i6 + ( 0 + i2 ) )} ];
s98 = getMetaData(T0_g[ iS0{i0}, iS1{i1}, iS2{i2} ])
s99 = getMetaData(T1_g[ iS26{i0}, iS27{i1}, iS5{i6} ])
}
Never mind — I figured out what went wrong earlier:
T8_g[ iS9{i0}, iS10{i1}, iS20{( i2 + i6 )}rf ]
= T3_l[ iS9{i0}, iS10{i1}, iS20{( i2 + i6 )}rf ]
+ T7_l[ iS30{i0}, iS31{i1}, iS25{( i6 + ( 0 + i2 ) )} ];
The output `T8_g` shouldn't have its last root ID marked as an rfactor ID. As printed above, it reuses `iS20{( i2 + i6 )}rf` from `T3_l` as-is. Fixing that resolves the problem. That was a messy bug with a lot of noise.
Notes for myself:
- Double-check how the output domain should be propagated from a binary `add`. This covers both the case where the two inputs come from `pad` (so both padded IDs are rfactor IDs) and the case above, where only one of them is an rfactor ID.
- Double-check how multiple `cat`s would work out as well.
- Figure out some test cases that cover the propagation restrictions in the implementation. The unary-op propagation rule is pretty straightforward, but binary ops will be trickier. Maybe I can/should wait until binary-op support is added as well.
Closing this since it's functional in my PR #2490.