RuntimeError: producer->getMemoryType() == MemoryType::Global
wujingyue opened this issue
This blocks Lightning-AI/lightning-thunder#191.
nvFuser reproducer
Check out https://github.com/NVIDIA/Fuser/tree/wjy/memory and run python repro.py.
import torch
from nvfuser import FusionDefinition, DataType
def nvfuser_fusion_id9(fd : FusionDefinition) -> None :
    T0 = fd.define_tensor(shape=[-1, -1, -1], contiguity=[True, True, None], dtype=DataType.Bool, is_cpu=False, stride_order=[2, 1, 0])
    T1 = fd.define_tensor(shape=[-1, -1, -1], contiguity=[True, True, True], dtype=DataType.Float, is_cpu=False, stride_order=[2, 1, 0])
    T2 = fd.define_tensor(shape=[-1, -1], contiguity=[True, True], dtype=DataType.Bool, is_cpu=False, stride_order=[1, 0])
    S3 = fd.define_scalar(16, dtype=DataType.Int)
    S4 = fd.define_scalar(16, dtype=DataType.Int)
    S5 = fd.define_scalar(1, dtype=DataType.Int)
    V6 = fd.define_vector([S3, S4, S5], dtype=DataType.Int)
    T7 = fd.ops.broadcast_in_dim(T2, shape=V6, broadcast_dims=[0, 1])
    S8 = fd.define_scalar(1, dtype=DataType.Int)
    S9 = fd.define_scalar(16, dtype=DataType.Int)
    S10 = fd.define_scalar(1, dtype=DataType.Int)
    S11 = fd.define_scalar(16, dtype=DataType.Int)
    S12 = fd.define_scalar(32, dtype=DataType.Int)
    S13 = fd.define_scalar(1, dtype=DataType.Int)
    V14 = fd.define_vector([S8, S9, S10, S11, S12, S13], dtype=DataType.Int)
    T15 = fd.ops.broadcast_in_dim(T7, shape=V14, broadcast_dims=[1, 3, 5])
    S16 = fd.define_scalar(16, dtype=DataType.Int)
    S17 = fd.define_scalar(16, dtype=DataType.Int)
    S18 = fd.define_scalar(32, dtype=DataType.Int)
    V19 = fd.define_vector([S16, S17, S18], dtype=DataType.Int)
    T20 = fd.ops.reshape(T15, new_shape=V19)
    T21 = fd.ops.slice(T1, start_indices=[0, 0, 16], end_indices=[16, 16, 32], strides=[1, 1, 1])
    S22 = fd.define_scalar(0.00000, dtype=DataType.Double)
    T23 = fd.ops.where(T20, S22, T1)
    S24 = fd.define_scalar(0.00000, dtype=DataType.Double)
    T25 = fd.ops.where(T0, S24, T23)
    fd.add_output(T21)
    fd.add_output(T25)

with FusionDefinition() as fd:
    nvfuser_fusion_id9(fd)

inputs = [
    torch.randint(0, 2, (256,), dtype=torch.bool, device='cuda:1').as_strided((16, 16, 32), (16, 1, 0)),
    torch.randn((8192,), dtype=torch.float32, device='cuda:1').as_strided((16, 16, 32), (512, 32, 1)),
    torch.randint(0, 2, (256,), dtype=torch.bool, device='cuda:1').as_strided((16, 16), (16, 1)),
]
fd.execute(inputs)
Traceback (most recent call last):
File "/opt/pytorch/nvfuser/nvfuser/__init__.py", line 145, in execute
result = self._execute(
RuntimeError: producer->getMemoryType() == MemoryType::Global INTERNAL ASSERT FAILED at "/opt/pytorch/nvfuser/csrc/device_lower/analysis/sync_information.cpp":699, please report a bug with repro script to NVFuser at https://github.com/NVIDIA/Fuser/issues. Inconsistent parallelization found between TV12 (T12_l[ iblockIdx.x180{( ceilDiv(( ceilDiv(( ceilDiv(( 16 * 16 ), 128) ), 1) ), 1) )}, iUS181{1}, iS179{1}, ithreadIdx.x177{128} ]) and TV3(T3_l[ iblockIdx.x173{( ceilDiv(( ceilDiv(( ceilDiv(( 16 * 16 ), 128) ), 1) ), 1) )}, iUS174{1}, iS172{1}, ithreadIdx.x170{128}, bS10{1} ] ca_pos( 4 )). Producer is required to be in Global Memory based on parallelization strategy. RAW flags: (blockIdx.x threadIdx.x)
Exception raised from SyncMap at /opt/pytorch/nvfuser/csrc/device_lower/analysis/sync_information.cpp:699 (most recent call first):
frame #0: nvfuser::nvfCheckFail(char const*, char const*, unsigned int, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&) + 0xf3 (0x7742424c9149 in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #1: nvfuser::nvfErrorFail(char const*, char const*, unsigned int, char const*, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&) + 0x53 (0x7742427fc8a3 in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #2: nvfuser::SyncMap::SyncMap(nvfuser::Fusion*) + 0x24a4 (0x77424270e0c4 in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #3: <unknown function> + 0x37f560 (0x77424271e560 in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #4: nvfuser::GpuLower::GpuLower(nvfuser::Fusion*, nvfuser::CompileParams const&) + 0xca8 (0x77424271fb48 in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #5: nvfuser::FusionExecutor::compileFusion(nvfuser::Fusion*, nvfuser::KernelArgumentHolder const&, nvfuser::LaunchParams const&, nvfuser::CompileParams, nvfuser::ScheduleHeuristic, long, long, long, long) + 0x3ed (0x77424281417d in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #6: <unknown function> + 0x65dd42 (0x7742429fcd42 in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #7: nvfuser::FusionKernelRuntime::compileFusionParallel(nvfuser::KernelArgumentHolder) + 0x492 (0x774242a05fb2 in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #8: nvfuser::FusionExecutorCache::runFusionWithInputs(c10::ArrayRef<c10::IValue> const&, std::optional<nvfuser::PrimDataType>, std::optional<signed char>) + 0xad3 (0x774242a10eb3 in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #9: nvfuser::python_frontend::FusionDefinition::execute(c10::ArrayRef<c10::IValue> const&, std::optional<signed char>, bool, bool, bool) const + 0x3ec (0x774242c0264c in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #10: <unknown function> + 0x1a42ce (0x7742425432ce in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #11: <unknown function> + 0x21ae3f (0x7742425b9e3f in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #12: <unknown function> + 0x2aeb60 (0x77424264db60 in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
<omitting python frames>
frame #28: <unknown function> + 0x29d90 (0x774351ea1d90 in /usr/lib/x86_64-linux-gnu/libc.so.6)
frame #29: __libc_start_main + 0x80 (0x774351ea1e40 in /usr/lib/x86_64-linux-gnu/libc.so.6)
Thunder reproducer
Check out https://github.com/Lightning-AI/lightning-thunder/tree/wjy/bookend and run python thunder/tests/distributed/test_tensor_parallel.py -k TensorParallelTest.test_both_column_and_row_bias_True
This is interesting. Here's a slightly smaller repro that removes the reshape, which is just a squeeze:
def nvfuser_fusion_id9(fd : FusionDefinition) -> None :
    T0 = fd.define_tensor(shape=[-1, -1, -1], contiguity=[True, True, None], dtype=DataType.Bool, is_cpu=False, stride_order=[2, 1, 0])
    T1 = fd.define_tensor(shape=[-1, -1, -1], contiguity=[True, True, True], dtype=DataType.Float, is_cpu=False, stride_order=[2, 1, 0])
    T2 = fd.define_tensor(shape=[-1, -1], contiguity=[True, True], dtype=DataType.Bool, is_cpu=False, stride_order=[1, 0])
    S9 = fd.define_scalar(16, dtype=DataType.Int)
    S11 = fd.define_scalar(16, dtype=DataType.Int)
    S12 = fd.define_scalar(32, dtype=DataType.Int)
    V14 = fd.define_vector([S9, S11, S12], dtype=DataType.Int)
    T15 = fd.ops.broadcast_in_dim(T2, shape=V14, broadcast_dims=[0, 1])
    T21 = fd.ops.slice(T1, start_indices=[0, 0, 16], end_indices=[16, 16, 32], strides=[1, 1, 1])
    S22 = fd.define_scalar(0.00000, dtype=DataType.Double)
    T23 = fd.ops.where(T15, S22, T1)
    S24 = fd.define_scalar(0.00000, dtype=DataType.Double)
    T25 = fd.ops.where(T0, S24, T23)
    fd.add_output(T21)
    fd.add_output(T25)
Notice that the T21 output is a slice of the input, which can be computed as a metadata operation on its own. Unfortunately, it looks like the pointwise scheduler is actually using that as the reference tensor when it should be using one of the other tensors. For example, I tried forcing the reference tv to T25, and this avoided the error.
The reference tensor for pointwise scheduler is an output tensor. Both T21 and T25 are considered valid references since their output axes map to all of the input tensors' axes using the PERMISSIVE map. I am not sure why T1 causes an invalid schedule, or what the fix would be.
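For intuition, that slice output is just a strided view of the second fusion input in PyTorch terms, so no kernel work is needed to produce it. A minimal sketch (hypothetical variable names; shape and strides copied from the repro's second input):

import torch

# Same shape/strides as the repro's second input (the Float tensor T1).
t1 = torch.randn((8192,), dtype=torch.float32).as_strided((16, 16, 32), (512, 32, 1))

# T21 = slice(T1, start=[0, 0, 16], end=[16, 16, 32]) corresponds to a view:
# same storage, a 16-element offset, unchanged strides.
t21_view = t1[:, :, 16:32]
assert t21_view.data_ptr() == t1.data_ptr() + 16 * t1.element_size()
assert t21_view.stride() == (512, 32, 1)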
Thanks for looking into this, @jacobhinkle.
I see two solutions so far (a possible definition-level workaround is also sketched after this list):
- I'll try to complete Fuser/csrc/scheduler/pointwise.cpp, lines 920 to 924 at df79694.
- Figure out why T1 causes an invalid schedule and fix the root cause. However, I'll defer this work to someone who's more familiar with the pointwise scheduler than me.
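For completeness, here is a sketch of a definition-level workaround for the smaller repro: force the slice output into its own segment so the pointwise scheduler never considers it as the reference. This is only an illustration, assuming the Python frontend exposes fd.ops.segment_set; it is not the fix discussed above, but the intent mirrors what forcing the reference tv to T25 achieved in the experiment earlier.

from nvfuser import FusionDefinition, DataType

def fusion_with_segment_set(fd: FusionDefinition) -> None:
    T0 = fd.define_tensor(shape=[-1, -1, -1], contiguity=[True, True, None],
                          dtype=DataType.Bool, is_cpu=False, stride_order=[2, 1, 0])
    T1 = fd.define_tensor(shape=[-1, -1, -1], contiguity=[True, True, True],
                          dtype=DataType.Float, is_cpu=False, stride_order=[2, 1, 0])
    T2 = fd.define_tensor(shape=[-1, -1], contiguity=[True, True],
                          dtype=DataType.Bool, is_cpu=False, stride_order=[1, 0])
    out_shape = fd.define_vector([fd.define_scalar(16, dtype=DataType.Int),
                                  fd.define_scalar(16, dtype=DataType.Int),
                                  fd.define_scalar(32, dtype=DataType.Int)], dtype=DataType.Int)
    T15 = fd.ops.broadcast_in_dim(T2, shape=out_shape, broadcast_dims=[0, 1])
    T21 = fd.ops.slice(T1, start_indices=[0, 0, 16], end_indices=[16, 16, 32], strides=[1, 1, 1])
    zero = fd.define_scalar(0.0, dtype=DataType.Double)
    T23 = fd.ops.where(T15, zero, T1)
    T25 = fd.ops.where(T0, zero, T23)
    # Assumed API: segment_set forces a segment boundary, so the alias-like
    # slice output lands in its own segment instead of competing to be the
    # pointwise scheduler's reference tensor.
    fd.add_output(fd.ops.segment_set(T21))
    fd.add_output(T25)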
Assigning to Jingyue based on his comment that he'll at least get a workaround here.
Jingyue: feel free to duplicate for just the workaround if you'd rather not close this until your "2" item is fixed.
(Brain dump before I forget)
What's there today:
- Fuser/csrc/preseg_passes/mark_aliases_prepare.cpp, lines 75 to 105 at 1416d89
- Fuser/tests/cpp/test_alias.cpp, line 853 at 1416d89
Likely, I'll slightly extend what's there today to fix this bug. But the below is worth considering for things down the road, e.g., #2375. cc @liqiangxl
Goal
Segment out shape ops from a fusion so that the remaining ops are fewer and hopefully easier to schedule.
Difficulty
Doing this carelessly would increase the number of kernels and/or the size of I/O.
Example 1
t0 = Pointwise(in)
t1 = Permute(t0)
out = Pointwise(t1)
Permute is a meta-op but shouldn't be segmented out. That would lead to two pointwise kernels and many more global reads and writes.
Example 2
t0 = NonMeta(in)
out0 = Meta(t0)
t1 = Meta(t0)
out1 = Meta(t1)
out2 = NonMeta(t1)
Segmenting out t1's definition would break the fusion into two kernels.
Example 3
t0 = Pointwise(in)
out = Slice(t0)
I understand nvFuser currently enforces a segment boundary before slice due to practical limitations. Ideally, however, the two ought to be in the same segment so we only need to Pointwise the portion that's in the slice.
Example 4
out0 = NonMeta(in)
t0 = Meta0(out0)
out1 = Meta1(t0)
out2 = NonMeta(t0)
Ideally, we'd like to compile it to the following two segments, one of which is meta-only and the other a kernel with two outputs. Note Meta0 is cloned.
// A kernel
out0 = NonMeta(in)
t0 = Meta0(out0)
out2 = NonMeta(t0)
// A meta-only segment
t1 = Meta0(out0)
out1 = Meta1(t1)
Proposal
The algorithm below to segment meta ops tries to minimize three things, in the following priority order:
- The number of kernels. Note that meta-op-only segments are not kernels, so having more of them alone is OK.
- The size of I/O.
- The number of ops in kernels, so fewer ops need to go through schedulers and codegen.
Steps:
- Run alias analysis as we do today.
- Mark all ops that are used for computing a non-alias output, i.e., that are on the path from an input to a non-alias output.
- For each alias output, traverse up until a slice or a used-for-non-alias op, and put a segment_set where the traversal stopped.
- For each input, traverse down until a non-meta op or a fork, and place a segment_set where the traversal stopped. We stop at a fork to minimize the number (and thus the size) of kernel inputs. (A toy sketch of these steps follows this list.)
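Below is a toy sketch of these steps on a hand-rolled DAG rather than nvFuser's Fusion IR. All names here (Tensor, is_meta, place_segment_sets, ...) are hypothetical, and the sketch only records where a segment_set would go rather than actually inserting one:

from dataclasses import dataclass, field

@dataclass(eq=False)
class Tensor:
    name: str
    op: str = "input"                          # "input", "slice", "meta", or "nonmeta"
    args: list = field(default_factory=list)   # producer tensors
    users: list = field(default_factory=list)  # consumer tensors, filled by connect()

def connect(tensors):
    for t in tensors:
        for a in t.args:
            a.users.append(t)

def is_meta(t):
    return t.op in ("meta", "slice")

def used_for_non_alias(outputs, alias_outputs):
    # Step 2: mark everything on a path from an input to a non-alias output.
    marked, stack = set(), [o for o in outputs if o not in alias_outputs]
    while stack:
        t = stack.pop()
        if t not in marked:
            marked.add(t)
            stack.extend(t.args)
    return marked

def place_segment_sets(inputs, outputs, alias_outputs):
    marked = used_for_non_alias(outputs, alias_outputs)
    cuts = set()
    # Step 3: from each alias output, walk up producers until a slice or a
    # used-for-non-alias tensor, and cut there.
    for out in alias_outputs:
        t = out
        while t.args and t.op != "slice" and t not in marked:
            t = t.args[0]
        cuts.add(t)
    # Step 4: from each input, walk down through meta ops until a non-meta op
    # or a fork (more than one user), and cut there if we moved at all.
    for inp in inputs:
        t = inp
        while len(t.users) == 1 and is_meta(t.users[0]):
            t = t.users[0]
        if t is not inp:
            cuts.add(t)
    return cuts

# Example 4 from above, without cloning (matching the limitation noted below):
inp  = Tensor("in")
out0 = Tensor("out0", "nonmeta", [inp])
t0   = Tensor("t0",   "meta",    [out0])
out1 = Tensor("out1", "meta",    [t0])
out2 = Tensor("out2", "nonmeta", [t0])
connect([out0, t0, out1, out2])
print({t.name for t in place_segment_sets([inp], [out0, out1, out2], {out1})})
# -> {'t0'}: the walk up from out1 stops at t0 because t0 also feeds out2.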
Limitations
The proposal doesn't consider the cloning opportunity mentioned in Example 4. I'm all ears for a better solution.