NVIDIA / Fuser

A Fusion Code Generator for NVIDIA GPUs (commonly known as "nvFuser")


RuntimeError: producer->getMemoryType() == MemoryType::Global

wujingyue opened this issue

This blocks Lightning-AI/lightning-thunder#191.

nvFuser reproducer

Check out https://github.com/NVIDIA/Fuser/tree/wjy/memory and run python repro.py.

import torch
from nvfuser import FusionDefinition, DataType

def nvfuser_fusion_id9(fd : FusionDefinition) -> None :
    T0 = fd.define_tensor(shape=[-1, -1, -1], contiguity=[True, True, None], dtype=DataType.Bool, is_cpu=False, stride_order=[2, 1, 0])
    T1 = fd.define_tensor(shape=[-1, -1, -1], contiguity=[True, True, True], dtype=DataType.Float, is_cpu=False, stride_order=[2, 1, 0])
    T2 = fd.define_tensor(shape=[-1, -1], contiguity=[True, True], dtype=DataType.Bool, is_cpu=False, stride_order=[1, 0])
    S3 = fd.define_scalar(16, dtype=DataType.Int)
    S4 = fd.define_scalar(16, dtype=DataType.Int)
    S5 = fd.define_scalar(1, dtype=DataType.Int)
    V6 = fd.define_vector([S3, S4, S5], dtype=DataType.Int)
    T7 = fd.ops.broadcast_in_dim(T2, shape=V6, broadcast_dims=[0, 1])
    S8 = fd.define_scalar(1, dtype=DataType.Int)
    S9 = fd.define_scalar(16, dtype=DataType.Int)
    S10 = fd.define_scalar(1, dtype=DataType.Int)
    S11 = fd.define_scalar(16, dtype=DataType.Int)
    S12 = fd.define_scalar(32, dtype=DataType.Int)
    S13 = fd.define_scalar(1, dtype=DataType.Int)
    V14 = fd.define_vector([S8, S9, S10, S11, S12, S13], dtype=DataType.Int)
    T15 = fd.ops.broadcast_in_dim(T7, shape=V14, broadcast_dims=[1, 3, 5])
    S16 = fd.define_scalar(16, dtype=DataType.Int)
    S17 = fd.define_scalar(16, dtype=DataType.Int)
    S18 = fd.define_scalar(32, dtype=DataType.Int)
    V19 = fd.define_vector([S16, S17, S18], dtype=DataType.Int)
    T20 = fd.ops.reshape(T15, new_shape=V19)
    T21 = fd.ops.slice(T1, start_indices=[0, 0, 16], end_indices=[16, 16, 32], strides=[1, 1, 1])
    S22 = fd.define_scalar(0.00000, dtype=DataType.Double)
    T23 = fd.ops.where(T20, S22, T1)
    S24 = fd.define_scalar(0.00000, dtype=DataType.Double)
    T25 = fd.ops.where(T0, S24, T23)
    fd.add_output(T21)
    fd.add_output(T25)

with FusionDefinition() as fd:
    nvfuser_fusion_id9(fd)

inputs = [
    torch.randint(0, 2, (256,), dtype=torch.bool, device='cuda:1').as_strided((16, 16, 32), (16, 1, 0)),
    torch.randn((8192,), dtype=torch.float32, device='cuda:1').as_strided((16, 16, 32), (512, 32, 1)),
    torch.randint(0, 2, (256,), dtype=torch.bool, device='cuda:1').as_strided((16, 16), (16, 1)),
]
fd.execute(inputs)
Traceback (most recent call last):
  File "/opt/pytorch/nvfuser/nvfuser/__init__.py", line 145, in execute
    result = self._execute(
RuntimeError: producer->getMemoryType() == MemoryType::Global INTERNAL ASSERT FAILED at "/opt/pytorch/nvfuser/csrc/device_lower/analysis/sync_information.cpp":699, please report a bug with repro script to NVFuser at https://github.com/NVIDIA/Fuser/issues. Inconsistent parallelization found between TV12 (T12_l[ iblockIdx.x180{( ceilDiv(( ceilDiv(( ceilDiv(( 16 * 16 ), 128) ), 1) ), 1) )}, iUS181{1}, iS179{1}, ithreadIdx.x177{128} ]) and TV3(T3_l[ iblockIdx.x173{( ceilDiv(( ceilDiv(( ceilDiv(( 16 * 16 ), 128) ), 1) ), 1) )}, iUS174{1}, iS172{1}, ithreadIdx.x170{128}, bS10{1} ] ca_pos( 4 )). Producer is required to be in Global Memory based on parallelization strategy. RAW flags: (blockIdx.x threadIdx.x)
Exception raised from SyncMap at /opt/pytorch/nvfuser/csrc/device_lower/analysis/sync_information.cpp:699 (most recent call first):
frame #0: nvfuser::nvfCheckFail(char const*, char const*, unsigned int, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&) + 0xf3 (0x7742424c9149 in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #1: nvfuser::nvfErrorFail(char const*, char const*, unsigned int, char const*, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&) + 0x53 (0x7742427fc8a3 in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #2: nvfuser::SyncMap::SyncMap(nvfuser::Fusion*) + 0x24a4 (0x77424270e0c4 in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #3: <unknown function> + 0x37f560 (0x77424271e560 in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #4: nvfuser::GpuLower::GpuLower(nvfuser::Fusion*, nvfuser::CompileParams const&) + 0xca8 (0x77424271fb48 in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #5: nvfuser::FusionExecutor::compileFusion(nvfuser::Fusion*, nvfuser::KernelArgumentHolder const&, nvfuser::LaunchParams const&, nvfuser::CompileParams, nvfuser::ScheduleHeuristic, long, long, long, long) + 0x3ed (0x77424281417d in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #6: <unknown function> + 0x65dd42 (0x7742429fcd42 in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #7: nvfuser::FusionKernelRuntime::compileFusionParallel(nvfuser::KernelArgumentHolder) + 0x492 (0x774242a05fb2 in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #8: nvfuser::FusionExecutorCache::runFusionWithInputs(c10::ArrayRef<c10::IValue> const&, std::optional<nvfuser::PrimDataType>, std::optional<signed char>) + 0xad3 (0x774242a10eb3 in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #9: nvfuser::python_frontend::FusionDefinition::execute(c10::ArrayRef<c10::IValue> const&, std::optional<signed char>, bool, bool, bool) const + 0x3ec (0x774242c0264c in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #10: <unknown function> + 0x1a42ce (0x7742425432ce in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #11: <unknown function> + 0x21ae3f (0x7742425b9e3f in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #12: <unknown function> + 0x2aeb60 (0x77424264db60 in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
<omitting python frames>
frame #28: <unknown function> + 0x29d90 (0x774351ea1d90 in /usr/lib/x86_64-linux-gnu/libc.so.6)
frame #29: __libc_start_main + 0x80 (0x774351ea1e40 in /usr/lib/x86_64-linux-gnu/libc.so.6)

Thunder reproducer

Check out https://github.com/Lightning-AI/lightning-thunder/tree/wjy/bookend and run python thunder/tests/distributed/test_tensor_parallel.py -k TensorParallelTest.test_both_column_and_row_bias_True

This is interesting. Here's a slightly smaller repro that removes the reshape, which is just a squeeze:

def nvfuser_fusion_id9(fd : FusionDefinition) -> None :
    T0 = fd.define_tensor(shape=[-1, -1, -1], contiguity=[True, True, None], dtype=DataType.Bool, is_cpu=False, stride_order=[2, 1, 0])
    T1 = fd.define_tensor(shape=[-1, -1, -1], contiguity=[True, True, True], dtype=DataType.Float, is_cpu=False, stride_order=[2, 1, 0])
    T2 = fd.define_tensor(shape=[-1, -1], contiguity=[True, True], dtype=DataType.Bool, is_cpu=False, stride_order=[1, 0])
    S9 = fd.define_scalar(16, dtype=DataType.Int)
    S11 = fd.define_scalar(16, dtype=DataType.Int)
    S12 = fd.define_scalar(32, dtype=DataType.Int)
    V14 = fd.define_vector([S9, S11, S12], dtype=DataType.Int)
    T15 = fd.ops.broadcast_in_dim(T2, shape=V14, broadcast_dims=[0, 1])
    T21 = fd.ops.slice(T1, start_indices=[0, 0, 16], end_indices=[16, 16, 32], strides=[1, 1, 1])
    S22 = fd.define_scalar(0.00000, dtype=DataType.Double)
    T23 = fd.ops.where(T15, S22, T1)
    S24 = fd.define_scalar(0.00000, dtype=DataType.Double)
    T25 = fd.ops.where(T0, S24, T23)
    fd.add_output(T21)
    fd.add_output(T25)

Notice that the T21 output is a slice of the input, which can be computed as a metadata operation on its own. Unfortunately, it looks like the pointwise scheduler is using that as the reference tensor when it should be using one of the other tensors. For example, I tried forcing the reference tv to T25 and this avoided the error.

The reference tensor for the pointwise scheduler is an output tensor. Both T21 and T25 are considered valid references, since their output axes map to all of the input tensors' axes using the PERMISSIVE map. I am not sure why T1 causes an invalid schedule, or what the fix would be.

Thanks for looking into this, @jacobhinkle .

I see two solutions so far:

  1. I'll try to complete this TODO:
    // TODO(#1401): We could let segmentation split a partially alias-producing
    // fusion into an alias-only segment and the rest. This way, the rest of the
    // fusion (which has fewer expressions) can potentially find a better
    // scheduler and we need to call markAliases only in NoOpScheduler.
    markAliases(fusion);
    I think this will work around this particular bug and make the pointwise scheduler more effective in general, since it will have a smaller fusion to schedule. (A user-level sketch of this split follows the list.)
  2. Figure out why T1 causes an invalid schedule and fix the root cause. However, I'll defer this work to someone who's more familiar with the pointwise scheduler than me.
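
For illustration, here is roughly what that split means at the Python-frontend level for the smaller repro above. This is a hand-written sketch of the idea, not the planned segmenter change: the names fusion_without_alias, out_slice, and out_where are made up, and the alias output T21 is produced as a plain strided view outside the fusion while only the two where ops stay in a pointwise-scheduled fusion.

import torch
from nvfuser import FusionDefinition, DataType

# Non-alias part only: the two `where` ops. T21, a pure slice of T1, is
# computed outside the fusion as a metadata-only view.
def fusion_without_alias(fd: FusionDefinition) -> None:
    T0 = fd.define_tensor(shape=[-1, -1, -1], contiguity=[True, True, None], dtype=DataType.Bool, is_cpu=False, stride_order=[2, 1, 0])
    T1 = fd.define_tensor(shape=[-1, -1, -1], contiguity=[True, True, True], dtype=DataType.Float, is_cpu=False, stride_order=[2, 1, 0])
    T2 = fd.define_tensor(shape=[-1, -1], contiguity=[True, True], dtype=DataType.Bool, is_cpu=False, stride_order=[1, 0])
    S3 = fd.define_scalar(16, dtype=DataType.Int)
    S4 = fd.define_scalar(16, dtype=DataType.Int)
    S5 = fd.define_scalar(32, dtype=DataType.Int)
    V6 = fd.define_vector([S3, S4, S5], dtype=DataType.Int)
    T7 = fd.ops.broadcast_in_dim(T2, shape=V6, broadcast_dims=[0, 1])
    S8 = fd.define_scalar(0.00000, dtype=DataType.Double)
    T9 = fd.ops.where(T7, S8, T1)
    S10 = fd.define_scalar(0.00000, dtype=DataType.Double)
    T11 = fd.ops.where(T0, S10, T9)
    fd.add_output(T11)

with FusionDefinition() as fd:
    fusion_without_alias(fd)

inputs = [
    torch.randint(0, 2, (256,), dtype=torch.bool, device='cuda:1').as_strided((16, 16, 32), (16, 1, 0)),
    torch.randn((8192,), dtype=torch.float32, device='cuda:1').as_strided((16, 16, 32), (512, 32, 1)),
    torch.randint(0, 2, (256,), dtype=torch.bool, device='cuda:1').as_strided((16, 16), (16, 1)),
]

# The alias output: a slice of the second input, i.e., a metadata-only view
# that needs no kernel.
out_slice = inputs[1][0:16, 0:16, 16:32]
(out_where,) = fd.execute(inputs)

This is what segmenting the alias-only part away would effectively give the scheduler: a fusion in which the where result (T25 in the repro) is the only output, so the pointwise scheduler no longer has the chance to pick the slice as its reference.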

Assigning to Jingyue based on his comment that he'll at least get a workaround here.

Jingyue: feel free to duplicate for just the workaround if you'd rather not close this until your "2" item is fixed.

(Brain dump before I forget)

// This is suboptimal in many uncommon cases. My head hurts when thinking
// about them, so go simple for now :)
//
// Legend:
// M* = a meta op defining a **fusion output**
// N/M = a non-meta op defining a **fusion output**
//
// Case 1:
//
// N/M -> N/M
// |
// --> M
//
// We should put a `segment_set` on the **edge** from N/M to M, so the two
// `N/M`s go to the same kernel.
//
// Case 2:
//
// N/M -> M1 -> M2
// |
// --> N/M
//
// We should change it to
//
// N/M -> M1 -> M2
// |
// --> M1' (non-output copy of M1) -> N/M
//
// and then put a `segment_set` on N/M->M1.
aliased_out->cacheBefore(LoadStoreOpType::SegmenterSet);
}
The snippet above was a simple solution to segment out meta ops from non-meta ops. It was just enough for the use cases at that time -- mostly TEST_F(AliasTest, OutputAliasesAnotherOutput). With more use cases coming up (this bug and #2375), the solution no longer works well and I'm trying to revisit it.

Likely, I'll slightly extend what's there today to fix this bug. But the below is worth considering for things down the road, e.g., #2375. cc @liqiangxl

Goal

Segment out shape ops from a fusion so that the remaining ops, now fewer, are hopefully easier to schedule.

Difficulty

Doing this carelessly would increase the number of kernels and/or the size of I/O.

Example 1

t0 = Pointwise(in)
t1 = Permute(t0)
out = Pointwise(t1)

Permute is a meta op but shouldn't be segmented out: that would lead to two pointwise kernels and many more global reads and writes.
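
For concreteness, Example 1 in the Python frontend might look like the following sketch. The function name example1, the shapes, and the choice of relu/sin as stand-ins for arbitrary pointwise ops are all made up for illustration.

from nvfuser import FusionDefinition, DataType

def example1(fd: FusionDefinition) -> None:
    t_in = fd.define_tensor(shape=[-1, -1], contiguity=[True, True], dtype=DataType.Float, is_cpu=False, stride_order=[1, 0])
    t0 = fd.ops.relu(t_in)                 # Pointwise
    t1 = fd.ops.permute(t0, dims=[1, 0])   # Permute: a meta op
    out = fd.ops.sin(t1)                   # Pointwise
    fd.add_output(out)

Segmenting the permute out here would turn one pointwise kernel into two, with an extra round trip of t0/t1 through global memory.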

Example 2

t0 = NonMeta(in)
out0 = Meta(t0)
t1 = Meta(t0)
out1 = Meta(t1)
out2 = NonMeta(t1)

Segmenting out t1's definition would break the fusion into two kernels.

Example 3

t0 = Pointwise(in)
out = Slice(t0)

I understand nvFuser currently enforces a segment boundary before a slice due to practical limitations. Ideally, however, the two ought to be in the same segment so that the Pointwise only needs to compute the portion that ends up in the slice.

Example 4

out0 = NonMeta(in)
t0 = Meta0(out0)
out1 = Meta1(t0)
out2 = NonMeta(t0)

Ideally, we'd like to compile it to the following two segments, one of which is meta-only and the other a kernel with two outputs. Note Meta0 is cloned.

// A kernel
out0 = NonMeta(in)
t0 = Meta0(out0)
out2 = NonMeta(t0)
// A meta-only segment
t1 = Meta0(out0)
out1 = Meta1(t1)

Proposal

The algorithm below to segment meta ops tries to minimize three things in the following priority order:

  1. The number of kernels. Note that meta-op-only segments are not kernels, so having more of them alone is OK.
  2. The size of I/O.
  3. The number of ops in kernels, so fewer ops need to go through schedulers and codegen.

Steps (a toy code sketch follows this list):

  1. Run alias analysis as we do today.
  2. Mark all ops that are used for computing a non-alias output, i.e., that are on the path from an input to a non-alias output.
  3. For each alias output, traverse up until a slice or a used-for-non-alias op and put a segment_set where the traversal stopped.
  4. For each input, traverse down until a non-meta op or a fork and place a segment_set where the traversal stopped. We stop at a fork to minimize the number (and thus the size) of kernel inputs.
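
To make these steps concrete, below is a toy, self-contained sketch. The dict-based graph, the propose_cuts helper, and the textual "cuts" are invented for illustration; they are not nvFuser's IR or API. Step 1 (alias analysis) is assumed to have already produced alias_outputs, and each op is assumed to produce one tensor that shares its name.

def propose_cuts(ops, fusion_inputs, fusion_outputs, alias_outputs):
    """ops: {name: {"inputs": [...], "is_meta": bool, "is_slice": bool}}."""
    consumers = {}
    for name, op in ops.items():
        for producer in op["inputs"]:
            consumers.setdefault(producer, []).append(name)

    def all_producer_ops(name):
        found, stack = set(), [name]
        while stack:
            for producer in ops[stack.pop()]["inputs"]:
                if producer in ops and producer not in found:
                    found.add(producer)
                    stack.append(producer)
        return found

    # Step 2: mark every op on a path from an input to a non-alias output.
    used_for_nonalias = set()
    for out in fusion_outputs:
        if out not in alias_outputs:
            used_for_nonalias |= {out} | all_producer_ops(out)

    cuts = []

    # Step 3: from each alias output, walk up; stop at a slice or at an op
    # also used for a non-alias output, and cut where the traversal stopped.
    for out in alias_outputs:
        prev, cur = None, out
        while (cur in ops and not ops[cur]["is_slice"]
               and cur not in used_for_nonalias and ops[cur]["inputs"]):
            prev, cur = cur, ops[cur]["inputs"][0]  # meta ops: one tensor input
        if cur in ops and ops[cur]["is_slice"]:
            # The slice is metadata-only, so it stays on the alias-only side.
            cuts.append(f"cut between {ops[cur]['inputs'][0]} and {cur} (alias chain of {out})")
        elif prev is not None:
            cuts.append(f"cut between {cur} and {prev} (alias chain of {out})")

    # Step 4: from each input, walk down through single-consumer meta ops; stop
    # at a non-meta op or a fork and cut there, to keep kernel inputs small.
    for inp in fusion_inputs:
        cur = inp
        while (len(consumers.get(cur, [])) == 1
               and ops[consumers[cur][0]]["is_meta"]):
            cur = consumers[cur][0]
        if cur != inp:
            cuts.append(f"cut after {cur} (meta chain from input {inp})")
    return cuts

# Example 2 from above: the meta definitions of out0 and out1 are cut away,
# while t1's definition stays in the kernel because it also feeds out2.
ops = {
    "t0":   {"inputs": ["in"], "is_meta": False, "is_slice": False},
    "out0": {"inputs": ["t0"], "is_meta": True,  "is_slice": False},
    "t1":   {"inputs": ["t0"], "is_meta": True,  "is_slice": False},
    "out1": {"inputs": ["t1"], "is_meta": True,  "is_slice": False},
    "out2": {"inputs": ["t1"], "is_meta": False, "is_slice": False},
}
print(propose_cuts(ops, ["in"], ["out0", "out1", "out2"], {"out0", "out1"}))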

Limitations

The proposal doesn't consider the cloning opportunity mentioned in Example 4. I'm all ears for a better solution.