NVIDIA / Fuser

A Fusion Code Generator for NVIDIA GPUs (commonly known as "nvFuser")

Squeezed IterDomain ?S536{1} must concretize to IterType::Broadcast but found ?S536{1}.

wujingyue opened this issue

This happened when I ran the transformer block with batch_size=1. It can be reproduced by

  1. checking out https://github.com/Lightning-AI/lightning-thunder/tree/wjy/sharded, and
  2. running pytest thunder/benchmarks/targets.py -k test_nanogpt_block_grad[thunder] -s.

I'm unsure whether it's a Thunder bug or an nvFuser bug. I suspect define_tensor needs to say shape=[1,...] when the batch size is one?

import torch
from nvfuser import FusionDefinition, DataType

def nvfuser_fusion_id0(fd : FusionDefinition) -> None :
    T0 = fd.define_tensor(shape=[-1], contiguity=[True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[0])
    T1 = fd.define_tensor(shape=[-1, -1], contiguity=[True, True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[1, 0])
    T2 = fd.define_tensor(shape=[-1], contiguity=[True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[0])
    T3 = fd.define_tensor(shape=[-1, -1], contiguity=[True, True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[1, 0])
    T4 = fd.define_tensor(shape=[-1], contiguity=[True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[0])
    T5 = fd.define_tensor(shape=[-1], contiguity=[True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[0])
    T6 = fd.define_tensor(shape=[-1], contiguity=[True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[0])
    T7 = fd.define_tensor(shape=[-1], contiguity=[True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[0])
    T8 = fd.define_tensor(shape=[-1], contiguity=[True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[0])
    T9 = fd.define_tensor(shape=[-1, -1], contiguity=[True, True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[1, 0])
    T10 = fd.define_tensor(shape=[-1], contiguity=[True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[0])
    T11 = fd.define_tensor(shape=[-1, -1], contiguity=[True, True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[1, 0])
    T12 = fd.define_tensor(shape=[1, -1, -1], contiguity=[None, True, True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[2, 1, 0])
    T13 = fd.ops.cast(T12, dtype=DataType.Float)
    T14, T15 = fd.ops.var_mean(T13, dims=[2], correction=0, keepdim=False)
    S16 = fd.define_scalar(1, dtype=DataType.Int)
    S17 = fd.define_scalar(2048, dtype=DataType.Int)
    S18 = fd.define_scalar(1, dtype=DataType.Int)
    V19 = fd.define_vector([S16, S17, S18], dtype=DataType.Int)
    T20 = fd.ops.broadcast_in_dim(T14, shape=V19, broadcast_dims=[0, 1])
    S21 = fd.define_scalar(1, dtype=DataType.Int)
    S22 = fd.define_scalar(2048, dtype=DataType.Int)
    S23 = fd.define_scalar(1, dtype=DataType.Int)
    V24 = fd.define_vector([S21, S22, S23], dtype=DataType.Int)
    T25 = fd.ops.broadcast_in_dim(T15, shape=V24, broadcast_dims=[0, 1])
    S26 = fd.define_scalar(1.00000e-05, dtype=DataType.Double)
    T27 = fd.ops.add(T20, S26)
    T28 = fd.ops.rsqrt(T27)
    S29 = fd.define_scalar(1, dtype=DataType.Int)
    S30 = fd.define_scalar(2048, dtype=DataType.Int)
    S31 = fd.define_scalar(12288, dtype=DataType.Int)
    V32 = fd.define_vector([S29, S30, S31], dtype=DataType.Int)
    T33 = fd.ops.broadcast_in_dim(T25, shape=V32, broadcast_dims=[0, 1, 2])
    T34 = fd.ops.sub(T13, T33)
    S35 = fd.define_scalar(1, dtype=DataType.Int)
    S36 = fd.define_scalar(2048, dtype=DataType.Int)
    S37 = fd.define_scalar(12288, dtype=DataType.Int)
    V38 = fd.define_vector([S35, S36, S37], dtype=DataType.Int)
    T39 = fd.ops.broadcast_in_dim(T28, shape=V38, broadcast_dims=[0, 1, 2])
    T40 = fd.ops.mul(T34, T39)
    S41 = fd.define_scalar(1, dtype=DataType.Int)
    S42 = fd.define_scalar(2048, dtype=DataType.Int)
    S43 = fd.define_scalar(12288, dtype=DataType.Int)
    V44 = fd.define_vector([S41, S42, S43], dtype=DataType.Int)
    T45 = fd.ops.broadcast_in_dim(T5, shape=V44, broadcast_dims=[2])
    T46 = fd.ops.cast(T45, dtype=DataType.Float)
    T47 = fd.ops.mul(T40, T46)
    S48 = fd.define_scalar(1, dtype=DataType.Int)
    S49 = fd.define_scalar(2048, dtype=DataType.Int)
    S50 = fd.define_scalar(12288, dtype=DataType.Int)
    V51 = fd.define_vector([S48, S49, S50], dtype=DataType.Int)
    T52 = fd.ops.broadcast_in_dim(T4, shape=V51, broadcast_dims=[2])
    T53 = fd.ops.cast(T52, dtype=DataType.Float)
    T54 = fd.ops.add(T47, T53)
    T55 = fd.ops.cast(T54, dtype=DataType.BFloat16)
    S56 = fd.define_scalar(2048, dtype=DataType.Int)
    S57 = fd.define_scalar(12288, dtype=DataType.Int)
    V58 = fd.define_vector([S56, S57], dtype=DataType.Int)
    T59 = fd.ops.reshape(T55, new_shape=V58)
    T60 = fd.ops.linear(T59, T1, T0)
    S61 = fd.define_scalar(1, dtype=DataType.Int)
    S62 = fd.define_scalar(2048, dtype=DataType.Int)
    S63 = fd.define_scalar(36864, dtype=DataType.Int)
    V64 = fd.define_vector([S61, S62, S63], dtype=DataType.Int)
    T65 = fd.ops.reshape(T60, new_shape=V64)
    T66 = fd.ops.slice(T65, start_indices=[0, 0, 0], end_indices=[1, 2048, 12288], strides=[1, 1, 1])
    T67 = fd.ops.slice(T65, start_indices=[0, 0, 12288], end_indices=[1, 2048, 24576], strides=[1, 1, 1])
    T68 = fd.ops.slice(T65, start_indices=[0, 0, 24576], end_indices=[1, 2048, 36864], strides=[1, 1, 1])
    S69 = fd.define_scalar(1, dtype=DataType.Int)
    S70 = fd.define_scalar(2048, dtype=DataType.Int)
    S71 = fd.define_scalar(96, dtype=DataType.Int)
    S72 = fd.define_scalar(128, dtype=DataType.Int)
    V73 = fd.define_vector([S69, S70, S71, S72], dtype=DataType.Int)
    T74 = fd.ops.reshape(T67, new_shape=V73)
    T75 = fd.ops.permute(T74, dims=[0, 2, 1, 3])
    S76 = fd.define_scalar(1, dtype=DataType.Int)
    S77 = fd.define_scalar(2048, dtype=DataType.Int)
    S78 = fd.define_scalar(96, dtype=DataType.Int)
    S79 = fd.define_scalar(128, dtype=DataType.Int)
    V80 = fd.define_vector([S76, S77, S78, S79], dtype=DataType.Int)
    T81 = fd.ops.reshape(T66, new_shape=V80)
    T82 = fd.ops.permute(T81, dims=[0, 2, 1, 3])
    S83 = fd.define_scalar(1, dtype=DataType.Int)
    S84 = fd.define_scalar(2048, dtype=DataType.Int)
    S85 = fd.define_scalar(96, dtype=DataType.Int)
    S86 = fd.define_scalar(128, dtype=DataType.Int)
    V87 = fd.define_vector([S83, S84, S85, S86], dtype=DataType.Int)
    T88 = fd.ops.reshape(T68, new_shape=V87)
    T89 = fd.ops.permute(T88, dims=[0, 2, 1, 3])
    T90 = fd.ops.cast(T82, dtype=DataType.Float)
    S91 = fd.define_scalar(0.297302, dtype=DataType.Double)
    T92 = fd.ops.mul(T90, S91)
    T93 = fd.ops.cast(T92, dtype=DataType.BFloat16)
    T94 = fd.ops.permute(T75, dims=[0, 1, 3, 2])
    T95 = fd.ops.cast(T94, dtype=DataType.Float)
    S96 = fd.define_scalar(0.297302, dtype=DataType.Double)
    T97 = fd.ops.mul(T95, S96)
    T98 = fd.ops.cast(T97, dtype=DataType.BFloat16)
    T99 = fd.ops.matmul(T93, T98)
    S100 = fd.define_scalar(2048, dtype=DataType.Int)
    S101 = fd.define_scalar(0, dtype=DataType.Int)
    S102 = fd.define_scalar(1, dtype=DataType.Int)
    T103 = fd.ops.iota(S100, S101, S102, dtype=DataType.Int)
    S104 = fd.define_scalar(2048, dtype=DataType.Int)
    S105 = fd.define_scalar(1, dtype=DataType.Int)
    V106 = fd.define_vector([S104, S105], dtype=DataType.Int)
    T107 = fd.ops.broadcast_in_dim(T103, shape=V106, broadcast_dims=[0])
    S108 = fd.define_scalar(2048, dtype=DataType.Int)
    S109 = fd.define_scalar(0, dtype=DataType.Int)
    S110 = fd.define_scalar(1, dtype=DataType.Int)
    T111 = fd.ops.iota(S108, S109, S110, dtype=DataType.Int)
    S112 = fd.define_scalar(1, dtype=DataType.Int)
    S113 = fd.define_scalar(2048, dtype=DataType.Int)
    V114 = fd.define_vector([S112, S113], dtype=DataType.Int)
    T115 = fd.ops.broadcast_in_dim(T111, shape=V114, broadcast_dims=[1])
    S116 = fd.define_scalar(0, dtype=DataType.Int)
    T117 = fd.ops.add(T107, S116)
    S118 = fd.define_scalar(2048, dtype=DataType.Int)
    S119 = fd.define_scalar(2048, dtype=DataType.Int)
    V120 = fd.define_vector([S118, S119], dtype=DataType.Int)
    T121 = fd.ops.broadcast_in_dim(T117, shape=V120, broadcast_dims=[0, 1])
    S122 = fd.define_scalar(2048, dtype=DataType.Int)
    S123 = fd.define_scalar(2048, dtype=DataType.Int)
    V124 = fd.define_vector([S122, S123], dtype=DataType.Int)
    T125 = fd.ops.broadcast_in_dim(T115, shape=V124, broadcast_dims=[0, 1])
    T126 = fd.ops.ge(T121, T125)
    S127 = fd.define_scalar(1, dtype=DataType.Int)
    S128 = fd.define_scalar(96, dtype=DataType.Int)
    S129 = fd.define_scalar(2048, dtype=DataType.Int)
    S130 = fd.define_scalar(2048, dtype=DataType.Int)
    V131 = fd.define_vector([S127, S128, S129, S130], dtype=DataType.Int)
    T132 = fd.ops.broadcast_in_dim(T126, shape=V131, broadcast_dims=[2, 3])
    S133 = fd.define_scalar(float("-inf"), dtype=DataType.Double)
    T134 = fd.ops.where(T132, T99, S133)
    T135 = fd.ops.cast(T134, dtype=DataType.Float)
    T136 = fd.ops.max(T135, dims=[3], keepdim=False, dtype=DataType.Null)
    S137 = fd.define_scalar(1, dtype=DataType.Int)
    S138 = fd.define_scalar(96, dtype=DataType.Int)
    S139 = fd.define_scalar(2048, dtype=DataType.Int)
    S140 = fd.define_scalar(1, dtype=DataType.Int)
    V141 = fd.define_vector([S137, S138, S139, S140], dtype=DataType.Int)
    T142 = fd.ops.broadcast_in_dim(T136, shape=V141, broadcast_dims=[0, 1, 2])
    S143 = fd.define_scalar(1, dtype=DataType.Int)
    S144 = fd.define_scalar(96, dtype=DataType.Int)
    S145 = fd.define_scalar(2048, dtype=DataType.Int)
    S146 = fd.define_scalar(2048, dtype=DataType.Int)
    V147 = fd.define_vector([S143, S144, S145, S146], dtype=DataType.Int)
    T148 = fd.ops.broadcast_in_dim(T142, shape=V147, broadcast_dims=[0, 1, 2, 3])
    T149 = fd.ops.sub(T135, T148)
    T150 = fd.ops.exp(T149)
    T151 = fd.ops.sum(T150, dims=[3], keepdim=False, dtype=DataType.Null)
    S152 = fd.define_scalar(1, dtype=DataType.Int)
    S153 = fd.define_scalar(96, dtype=DataType.Int)
    S154 = fd.define_scalar(2048, dtype=DataType.Int)
    S155 = fd.define_scalar(1, dtype=DataType.Int)
    V156 = fd.define_vector([S152, S153, S154, S155], dtype=DataType.Int)
    T157 = fd.ops.broadcast_in_dim(T151, shape=V156, broadcast_dims=[0, 1, 2])
    S158 = fd.define_scalar(1, dtype=DataType.Int)
    S159 = fd.define_scalar(96, dtype=DataType.Int)
    S160 = fd.define_scalar(2048, dtype=DataType.Int)
    S161 = fd.define_scalar(2048, dtype=DataType.Int)
    V162 = fd.define_vector([S158, S159, S160, S161], dtype=DataType.Int)
    T163 = fd.ops.broadcast_in_dim(T157, shape=V162, broadcast_dims=[0, 1, 2, 3])
    T164 = fd.ops.reciprocal(T163)
    T165 = fd.ops.mul(T150, T164)
    T166 = fd.ops.cast(T165, dtype=DataType.BFloat16)
    S167 = fd.define_scalar(0.00000, dtype=DataType.Double)
    S168 = fd.define_scalar(1.00000, dtype=DataType.Double)
    S169 = fd.define_scalar(1, dtype=DataType.Int)
    S170 = fd.define_scalar(96, dtype=DataType.Int)
    S171 = fd.define_scalar(2048, dtype=DataType.Int)
    S172 = fd.define_scalar(2048, dtype=DataType.Int)
    V173 = fd.define_vector([S169, S170, S171, S172], dtype=DataType.Int)
    T174 = fd.ops.uniform(S167, S168, shape=V173, dtype=DataType.BFloat16)
    S175 = fd.define_scalar(0.900000, dtype=DataType.Double)
    T176 = fd.ops.lt(T174, S175)
    T177 = fd.ops.cast(T176, dtype=DataType.Float)
    T178 = fd.ops.mul(T165, T177)
    S179 = fd.define_scalar(1.11111, dtype=DataType.Double)
    T180 = fd.ops.mul(T178, S179)
    T181 = fd.ops.cast(T180, dtype=DataType.BFloat16)
    T182 = fd.ops.matmul(T181, T89)
    T183 = fd.ops.permute(T182, dims=[0, 2, 1, 3])
    T184 = fd.ops.stride_order(T183, stride_order=[3, 2, 1, 0])
    S185 = fd.define_scalar(1, dtype=DataType.Int)
    S186 = fd.define_scalar(2048, dtype=DataType.Int)
    S187 = fd.define_scalar(12288, dtype=DataType.Int)
    V188 = fd.define_vector([S185, S186, S187], dtype=DataType.Int)
    T189 = fd.ops.reshape(T184, new_shape=V188)
    S190 = fd.define_scalar(2048, dtype=DataType.Int)
    S191 = fd.define_scalar(12288, dtype=DataType.Int)
    V192 = fd.define_vector([S190, S191], dtype=DataType.Int)
    T193 = fd.ops.reshape(T189, new_shape=V192)
    T194 = fd.ops.linear(T193, T3, T2)
    S195 = fd.define_scalar(1, dtype=DataType.Int)
    S196 = fd.define_scalar(2048, dtype=DataType.Int)
    S197 = fd.define_scalar(12288, dtype=DataType.Int)
    V198 = fd.define_vector([S195, S196, S197], dtype=DataType.Int)
    T199 = fd.ops.reshape(T194, new_shape=V198)
    S200 = fd.define_scalar(0.00000, dtype=DataType.Double)
    S201 = fd.define_scalar(1.00000, dtype=DataType.Double)
    S202 = fd.define_scalar(1, dtype=DataType.Int)
    S203 = fd.define_scalar(2048, dtype=DataType.Int)
    S204 = fd.define_scalar(12288, dtype=DataType.Int)
    V205 = fd.define_vector([S202, S203, S204], dtype=DataType.Int)
    T206 = fd.ops.uniform(S200, S201, shape=V205, dtype=DataType.BFloat16)
    S207 = fd.define_scalar(0.900000, dtype=DataType.Double)
    T208 = fd.ops.lt(T206, S207)
    T209 = fd.ops.cast(T199, dtype=DataType.Float)
    T210 = fd.ops.cast(T208, dtype=DataType.Float)
    T211 = fd.ops.mul(T209, T210)
    S212 = fd.define_scalar(1.11111, dtype=DataType.Double)
    T213 = fd.ops.mul(T211, S212)
    T214 = fd.ops.add(T13, T213)
    T215, T216 = fd.ops.var_mean(T214, dims=[2], correction=0, keepdim=False)
    S217 = fd.define_scalar(1, dtype=DataType.Int)
    S218 = fd.define_scalar(2048, dtype=DataType.Int)
    S219 = fd.define_scalar(1, dtype=DataType.Int)
    V220 = fd.define_vector([S217, S218, S219], dtype=DataType.Int)
    T221 = fd.ops.broadcast_in_dim(T215, shape=V220, broadcast_dims=[0, 1])
    S222 = fd.define_scalar(1, dtype=DataType.Int)
    S223 = fd.define_scalar(2048, dtype=DataType.Int)
    S224 = fd.define_scalar(1, dtype=DataType.Int)
    V225 = fd.define_vector([S222, S223, S224], dtype=DataType.Int)
    T226 = fd.ops.broadcast_in_dim(T216, shape=V225, broadcast_dims=[0, 1])
    S227 = fd.define_scalar(1.00000e-05, dtype=DataType.Double)
    T228 = fd.ops.add(T221, S227)
    T229 = fd.ops.rsqrt(T228)
    S230 = fd.define_scalar(1, dtype=DataType.Int)
    S231 = fd.define_scalar(2048, dtype=DataType.Int)
    S232 = fd.define_scalar(12288, dtype=DataType.Int)
    V233 = fd.define_vector([S230, S231, S232], dtype=DataType.Int)
    T234 = fd.ops.broadcast_in_dim(T226, shape=V233, broadcast_dims=[0, 1, 2])
    T235 = fd.ops.sub(T214, T234)
    S236 = fd.define_scalar(1, dtype=DataType.Int)
    S237 = fd.define_scalar(2048, dtype=DataType.Int)
    S238 = fd.define_scalar(12288, dtype=DataType.Int)
    V239 = fd.define_vector([S236, S237, S238], dtype=DataType.Int)
    T240 = fd.ops.broadcast_in_dim(T229, shape=V239, broadcast_dims=[0, 1, 2])
    T241 = fd.ops.mul(T235, T240)
    S242 = fd.define_scalar(1, dtype=DataType.Int)
    S243 = fd.define_scalar(2048, dtype=DataType.Int)
    S244 = fd.define_scalar(12288, dtype=DataType.Int)
    V245 = fd.define_vector([S242, S243, S244], dtype=DataType.Int)
    T246 = fd.ops.broadcast_in_dim(T7, shape=V245, broadcast_dims=[2])
    T247 = fd.ops.cast(T246, dtype=DataType.Float)
    T248 = fd.ops.mul(T241, T247)
    S249 = fd.define_scalar(1, dtype=DataType.Int)
    S250 = fd.define_scalar(2048, dtype=DataType.Int)
    S251 = fd.define_scalar(12288, dtype=DataType.Int)
    V252 = fd.define_vector([S249, S250, S251], dtype=DataType.Int)
    T253 = fd.ops.broadcast_in_dim(T6, shape=V252, broadcast_dims=[2])
    T254 = fd.ops.cast(T253, dtype=DataType.Float)
    T255 = fd.ops.add(T248, T254)
    T256 = fd.ops.cast(T255, dtype=DataType.BFloat16)
    S257 = fd.define_scalar(2048, dtype=DataType.Int)
    S258 = fd.define_scalar(12288, dtype=DataType.Int)
    V259 = fd.define_vector([S257, S258], dtype=DataType.Int)
    T260 = fd.ops.reshape(T256, new_shape=V259)
    T261 = fd.ops.linear(T260, T9, T8)
    S262 = fd.define_scalar(1, dtype=DataType.Int)
    S263 = fd.define_scalar(2048, dtype=DataType.Int)
    S264 = fd.define_scalar(49152, dtype=DataType.Int)
    V265 = fd.define_vector([S262, S263, S264], dtype=DataType.Int)
    T266 = fd.ops.reshape(T261, new_shape=V265)
    T267 = fd.ops.cast(T266, dtype=DataType.Float)
    T268 = fd.ops.mul(T267, T267)
    T269 = fd.ops.mul(T268, T267)
    S270 = fd.define_scalar(0.500000, dtype=DataType.Double)
    T271 = fd.ops.mul(S270, T267)
    S272 = fd.define_scalar(0.0447150, dtype=DataType.Double)
    T273 = fd.ops.mul(S272, T269)
    T274 = fd.ops.add(T267, T273)
    S275 = fd.define_scalar(0.797885, dtype=DataType.Double)
    T276 = fd.ops.mul(S275, T274)
    T277 = fd.ops.tanh(T276)
    S278 = fd.define_scalar(1.00000, dtype=DataType.Double)
    T279 = fd.ops.add(S278, T277)
    T280 = fd.ops.mul(T271, T279)
    T281 = fd.ops.cast(T280, dtype=DataType.BFloat16)
    S282 = fd.define_scalar(2048, dtype=DataType.Int)
    S283 = fd.define_scalar(49152, dtype=DataType.Int)
    V284 = fd.define_vector([S282, S283], dtype=DataType.Int)
    T285 = fd.ops.reshape(T281, new_shape=V284)
    T286 = fd.ops.linear(T285, T11, T10)
    S287 = fd.define_scalar(1, dtype=DataType.Int)
    S288 = fd.define_scalar(2048, dtype=DataType.Int)
    S289 = fd.define_scalar(12288, dtype=DataType.Int)
    V290 = fd.define_vector([S287, S288, S289], dtype=DataType.Int)
    T291 = fd.ops.reshape(T286, new_shape=V290)
    S292 = fd.define_scalar(0.00000, dtype=DataType.Double)
    S293 = fd.define_scalar(1.00000, dtype=DataType.Double)
    S294 = fd.define_scalar(1, dtype=DataType.Int)
    S295 = fd.define_scalar(2048, dtype=DataType.Int)
    S296 = fd.define_scalar(12288, dtype=DataType.Int)
    V297 = fd.define_vector([S294, S295, S296], dtype=DataType.Int)
    T298 = fd.ops.uniform(S292, S293, shape=V297, dtype=DataType.BFloat16)
    S299 = fd.define_scalar(0.900000, dtype=DataType.Double)
    T300 = fd.ops.lt(T298, S299)
    T301 = fd.ops.cast(T291, dtype=DataType.Float)
    T302 = fd.ops.cast(T300, dtype=DataType.Float)
    T303 = fd.ops.mul(T301, T302)
    S304 = fd.define_scalar(1.11111, dtype=DataType.Double)
    T305 = fd.ops.mul(T303, S304)
    T306 = fd.ops.add(T214, T305)
    T307 = fd.ops.cast(T306, dtype=DataType.BFloat16)
    fd.add_output(T216)
    fd.add_output(T229)
    fd.add_output(T300)
    fd.add_output(T307)
    fd.add_output(T15)
    fd.add_output(T166)
    fd.add_output(T176)
    fd.add_output(T28)
    fd.add_output(T181)
    fd.add_output(T208)

with FusionDefinition() as fd:
    nvfuser_fusion_id0(fd)

inputs = [
    torch.randn((36864,), dtype=torch.bfloat16, device='cuda:0').as_strided((36864,), (1,)),
    torch.randn((452984832,), dtype=torch.bfloat16, device='cuda:0').as_strided((36864, 12288), (12288, 1)),
    torch.randn((12288,), dtype=torch.bfloat16, device='cuda:0').as_strided((12288,), (1,)),
    torch.randn((150994944,), dtype=torch.bfloat16, device='cuda:0').as_strided((12288, 12288), (12288, 1)),
    torch.randn((12288,), dtype=torch.bfloat16, device='cuda:0').as_strided((12288,), (1,)),
    torch.randn((12288,), dtype=torch.bfloat16, device='cuda:0').as_strided((12288,), (1,)),
    torch.randn((12288,), dtype=torch.bfloat16, device='cuda:0').as_strided((12288,), (1,)),
    torch.randn((12288,), dtype=torch.bfloat16, device='cuda:0').as_strided((12288,), (1,)),
    torch.randn((49152,), dtype=torch.bfloat16, device='cuda:0').as_strided((49152,), (1,)),
    torch.randn((603979776,), dtype=torch.bfloat16, device='cuda:0').as_strided((49152, 12288), (12288, 1)),
    torch.randn((12288,), dtype=torch.bfloat16, device='cuda:0').as_strided((12288,), (1,)),
    torch.randn((603979776,), dtype=torch.bfloat16, device='cuda:0').as_strided((12288, 49152), (49152, 1)),
    torch.randn((25165824,), dtype=torch.bfloat16, device='cuda:0').as_strided((1, 2048, 12288), (25165824, 12288, 1)),
]
fd.execute(inputs)
Traceback (most recent call last):
  File "/opt/pytorch/nvfuser/nvfuser/__init__.py", line 145, in execute
    result = self._execute(
RuntimeError: Squeezed IterDomain ?S536{1} must concretize to IterType::Broadcast but found ?S536{1}
Exception raised from checkConcretization at /opt/pytorch/nvfuser/csrc/ir/nodes.cpp:1406 (most recent call first):
frame #0: nvfuser::nvfCheckFail(char const*, char const*, unsigned int, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&) + 0xf3 (0x74fdf03fca67 in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #1: nvfuser::SqueezeOp::checkConcretization(nvfuser::Val*, nvfuser::Val*) const + 0x654 (0x74fdf08b7db4 in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #2: <unknown function> + 0x41627b (0x74fdf06ec27b in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #3: <unknown function> + 0x41976a (0x74fdf06ef76a in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #4: nvfuser::DynamicTransform::concretizeFusion(nvfuser::Fusion*, nvfuser::DynamicTransformConcretizationInfo const*) + 0xa2 (0x74fdf06ef9e2 in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #5: <unknown function> + 0x6601f2 (0x74fdf09361f2 in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #6: nvfuser::FusionExecutorCache::runFusionWithInputs(c10::ArrayRef<c10::IValue> const&, std::optional<nvfuser::PrimDataType>, std::optional<signed char>) + 0x1e7 (0x74fdf09373f7 in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #7: nvfuser::python_frontend::FusionDefinition::execute(c10::ArrayRef<c10::IValue> const&, std::optional<signed char>, bool, bool, bool) const + 0x3ec (0x74fdf0b283fc in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #8: <unknown function> + 0x19e88e (0x74fdf047488e in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #9: <unknown function> + 0x2153ff (0x74fdf04eb3ff in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #10: <unknown function> + 0x2a9be0 (0x74fdf057fbe0 in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #11: <unknown function> + 0x15a10e (0x5778b421b10e in /usr/bin/python3)
frame #12: _PyObject_MakeTpCall + 0x25b (0x5778b4211a7b in /usr/bin/python3)
frame #13: <unknown function> + 0x168acb (0x5778b4229acb in /usr/bin/python3)
frame #14: _PyEval_EvalFrameDefault + 0x198c (0x5778b420553c in /usr/bin/python3)
frame #15: <unknown function> + 0x1687f1 (0x5778b42297f1 in /usr/bin/python3)
frame #16: PyObject_Call + 0x122 (0x5778b422a492 in /usr/bin/python3)
frame #17: _PyEval_EvalFrameDefault + 0x2a27 (0x5778b42065d7 in /usr/bin/python3)
frame #18: _PyObject_FastCallDictTstate + 0xc4 (0x5778b4210c14 in /usr/bin/python3)
frame #19: _PyObject_Call_Prepend + 0xc1 (0x5778b42268d1 in /usr/bin/python3)
frame #20: <unknown function> + 0x280700 (0x5778b4341700 in /usr/bin/python3)
frame #21: _PyObject_MakeTpCall + 0x25b (0x5778b4211a7b in /usr/bin/python3)
frame #22: _PyEval_EvalFrameDefault + 0x64e6 (0x5778b420a096 in /usr/bin/python3)
frame #23: _PyFunction_Vectorcall + 0x7c (0x5778b421b9fc in /usr/bin/python3)
frame #24: _PyEval_EvalFrameDefault + 0x2a27 (0x5778b42065d7 in /usr/bin/python3)
frame #25: _PyFunction_Vectorcall + 0x7c (0x5778b421b9fc in /usr/bin/python3)
frame #26: _PyEval_EvalFrameDefault + 0x2a27 (0x5778b42065d7 in /usr/bin/python3)
frame #27: _PyFunction_Vectorcall + 0x7c (0x5778b421b9fc in /usr/bin/python3)
frame #28: _PyEval_EvalFrameDefault + 0x2a27 (0x5778b42065d7 in /usr/bin/python3)
frame #29: _PyFunction_Vectorcall + 0x7c (0x5778b421b9fc in /usr/bin/python3)
frame #30: _PyEval_EvalFrameDefault + 0x2a27 (0x5778b42065d7 in /usr/bin/python3)
frame #31: _PyFunction_Vectorcall + 0x7c (0x5778b421b9fc in /usr/bin/python3)
frame #32: _PyEval_EvalFrameDefault + 0x2a27 (0x5778b42065d7 in /usr/bin/python3)
frame #33: <unknown function> + 0x16893e (0x5778b422993e in /usr/bin/python3)
frame #34: _PyEval_EvalFrameDefault + 0x2a27 (0x5778b42065d7 in /usr/bin/python3)
frame #35: <unknown function> + 0x16893e (0x5778b422993e in /usr/bin/python3)
frame #36: _PyEval_EvalFrameDefault + 0x2a27 (0x5778b42065d7 in /usr/bin/python3)
frame #37: _PyObject_FastCallDictTstate + 0xc4 (0x5778b4210c14 in /usr/bin/python3)
frame #38: _PyObject_Call_Prepend + 0x5c (0x5778b422686c in /usr/bin/python3)
frame #39: <unknown function> + 0x280700 (0x5778b4341700 in /usr/bin/python3)
frame #40: PyObject_Call + 0xbb (0x5778b422a42b in /usr/bin/python3)
frame #41: _PyEval_EvalFrameDefault + 0x2a27 (0x5778b42065d7 in /usr/bin/python3)
frame #42: _PyFunction_Vectorcall + 0x7c (0x5778b421b9fc in /usr/bin/python3)
frame #43: _PyEval_EvalFrameDefault + 0x2a27 (0x5778b42065d7 in /usr/bin/python3)
frame #44: _PyFunction_Vectorcall + 0x7c (0x5778b421b9fc in /usr/bin/python3)
frame #45: _PyEval_EvalFrameDefault + 0x2a27 (0x5778b42065d7 in /usr/bin/python3)
frame #46: _PyFunction_Vectorcall + 0x7c (0x5778b421b9fc in /usr/bin/python3)
frame #47: _PyEval_EvalFrameDefault + 0x6bd (0x5778b420426d in /usr/bin/python3)
frame #48: <unknown function> + 0x1687f1 (0x5778b42297f1 in /usr/bin/python3)
frame #49: _PyEval_EvalFrameDefault + 0x198c (0x5778b420553c in /usr/bin/python3)
frame #50: <unknown function> + 0x1687f1 (0x5778b42297f1 in /usr/bin/python3)
frame #51: _PyEval_EvalFrameDefault + 0x198c (0x5778b420553c in /usr/bin/python3)
frame #52: _PyFunction_Vectorcall + 0x7c (0x5778b421b9fc in /usr/bin/python3)
frame #53: PyObject_Call + 0x122 (0x5778b422a492 in /usr/bin/python3)
frame #54: _PyEval_EvalFrameDefault + 0x2a27 (0x5778b42065d7 in /usr/bin/python3)
frame #55: _PyFunction_Vectorcall + 0x7c (0x5778b421b9fc in /usr/bin/python3)
frame #56: _PyEval_EvalFrameDefault + 0x2a27 (0x5778b42065d7 in /usr/bin/python3)
frame #57: _PyFunction_Vectorcall + 0x7c (0x5778b421b9fc in /usr/bin/python3)
frame #58: _PyEval_EvalFrameDefault + 0x614a (0x5778b4209cfa in /usr/bin/python3)
frame #59: <unknown function> + 0x1687f1 (0x5778b42297f1 in /usr/bin/python3)
frame #60: _PyEval_EvalFrameDefault + 0x614a (0x5778b4209cfa in /usr/bin/python3)
frame #61: _PyFunction_Vectorcall + 0x7c (0x5778b421b9fc in /usr/bin/python3)
frame #62: _PyObject_FastCallDictTstate + 0x16d (0x5778b4210cbd in /usr/bin/python3)
frame #63: _PyObject_Call_Prepend + 0x5c (0x5778b422686c in /usr/bin/python3)

@jjsjann123 tagging myself. I think it's the reshape that isn't specifying the output IterDomain properly.
Let me see if I can simplify the example for @jacobhinkle.

Also, a note to myself: how would reshape work with a dynamic scalar? Should we special-case instances where a scalar is 1? Should this be an operation-imposed constraint added to the prologue trace?

For reshape(a, new_shape=[i, j, k]): if we encounter an entry in new_shape that is 1, we cannot leave it as a symbolic extent down the road; otherwise we might also run into squeeze asserting on it. Or does this mean we should have relaxed the check in squeeze? Linking issue Lightning-AI/lightning-thunder#262.
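A toy Python model of the trade-off raised above, assuming (not taken from nvFuser's source) that a reshape output dimension stays symbolic until concretization and then becomes a Broadcast exactly when its runtime extent is 1. In the error message, the leading `?` of `?S536{1}` appears to mark a dimension that is still symbolic. `IterType`, `concretize_reshape`, and both `check_squeeze_*` functions are hypothetical: the strict check mirrors the assertion in the traceback, and the relaxed one is just one possible reading of "relaxing the check in squeeze".

```python
from enum import Enum
from typing import List, Sequence


class IterType(Enum):
    # Labels loosely mirror the prefixes in nvFuser's printout ("?S536{1}").
    SYMBOLIC = "?"   # extent not known yet: fusion not concretized
    ITERATION = "i"  # ordinary iterated dimension
    BROADCAST = "b"  # size-1 dimension that may be broadcast or squeezed


def concretize_reshape(extents: Sequence[int]) -> List[IterType]:
    """Toy rule: a reshape output dim whose runtime extent is 1 becomes a
    Broadcast; every other dim becomes an ordinary Iteration dim."""
    return [IterType.BROADCAST if e == 1 else IterType.ITERATION for e in extents]


def check_squeeze_strict(t: IterType) -> None:
    """Strict: the squeezed dim must already be a Broadcast."""
    if t is not IterType.BROADCAST:
        raise RuntimeError(f"Squeezed dim must concretize to Broadcast but found {t.value}")


def check_squeeze_relaxed(t: IterType) -> None:
    """Relaxed: also accept a still-symbolic dim, since it may resolve to a
    Broadcast later; only a dim known to iterate is rejected."""
    if t is IterType.ITERATION:
        raise RuntimeError(f"Cannot squeeze an iterated dim ({t.value})")


# The same symbolic new_shape [batch, 2048, 12288] concretizes differently
# depending on the runtime batch size.
print(concretize_reshape([1, 2048, 12288])[0])  # IterType.BROADCAST
print(concretize_reshape([2, 2048, 12288])[0])  # IterType.ITERATION

check_squeeze_relaxed(IterType.SYMBOLIC)  # accepted by the relaxed check
try:
    check_squeeze_strict(IterType.SYMBOLIC)  # rejected by the strict check
except RuntimeError as err:
    print(err)
```

The strict check fails on a dimension that would have become a Broadcast had it been concretized first, which is the shape of the failure reported here.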

> I suspect define_tensor needs to say shape=[1,...] when the batch size is one?

Never mind. The logic at https://github.com/Lightning-AI/lightning-thunder/blob/126940750c8e498a89376e6c787985448c79808a/thunder/executors/nvfuserex_impl.py#L299 indeed kicked in: the first dimension of T12 is declared as 1 in the example above.
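A rough sketch of the declaration rule that appears to have applied to T12 (the helper below is hypothetical, not Thunder's code): a dimension known to have size 1 is declared as a literal 1, and therefore as a broadcast, while every other dimension stays symbolic (-1). Assuming a contiguous input and ignoring strides, this reproduces the shape and contiguity that T12 is declared with in the repro.

```python
from typing import List, Optional, Sequence, Tuple


def define_tensor_spec(sizes: Sequence[int]) -> Tuple[List[int], List[Optional[bool]]]:
    """Hypothetical helper: derive define_tensor's shape/contiguity for a
    contiguous tensor, pinning size-1 dims as broadcast dims."""
    shape: List[int] = []
    contiguity: List[Optional[bool]] = []
    for size in sizes:
        if size == 1:
            # A known size-1 dim is declared as 1 (a broadcast);
            # broadcast dims carry no contiguity flag, hence None.
            shape.append(1)
            contiguity.append(None)
        else:
            # Any other extent stays symbolic (-1) and is bound at execution.
            shape.append(-1)
            contiguity.append(True)
    return shape, contiguity


shape, contiguity = define_tensor_spec((1, 2048, 12288))
print(shape, contiguity)  # [1, -1, -1] [None, True, True]
```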

I don't see an issue with this part. The reshape to size [1, xxx] correctly translates the size-1 dimension into a broadcast.

A smaller repro.

import torch
from nvfuser import FusionDefinition, DataType

def nvfuser_fusion_id0(fd : FusionDefinition) -> None :
    T0 = fd.define_tensor(shape=[-1, -1, -1], contiguity=[True, True, True], dtype=DataType.Float, is_cpu=False, stride_order=[2, 1, 0])
    T1 = fd.ops.slice(T0, start_indices=[0, 0, 0], end_indices=[1, 2, 4], strides=[1, 1, 1])
    S2 = fd.define_scalar(1, dtype=DataType.Int)
    S3 = fd.define_scalar(8, dtype=DataType.Int)
    V4 = fd.define_vector([S2, S3], dtype=DataType.Int)
    V5 = fd.define_vector([S3], dtype=DataType.Int)
    T6 = fd.ops.reshape(T1, new_shape=V4)
    T7 = fd.ops.reshape(T6, new_shape=V5)
    # Reshaping T1 directly to [8], skipping the intermediate [1, 8], works fine:
    # T7 = fd.ops.reshape(T1, new_shape=V5)
    fd.add_output(T7)

with FusionDefinition() as fd:
    nvfuser_fusion_id0(fd)

inputs = [
    torch.randn((100,), dtype=torch.float32, device='cuda:0').as_strided((2, 5, 10), (50, 10, 1)),
]
fd.execute(inputs)
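Because both reshapes only regroup elements in row-major order, T7 should contain the sliced elements in row-major order. The sketch below computes, by hand, which storage offsets of the (2, 5, 10) view with strides (50, 10, 1) end up in T7; `storage_offsets` is a hypothetical helper written for this check, not part of nvFuser.

```python
from typing import List, Sequence


def storage_offsets(start: Sequence[int], end: Sequence[int],
                    strides: Sequence[int]) -> List[int]:
    """Storage offsets, in row-major order, of a unit-stride 3-D slice."""
    offsets = []
    for i in range(start[0], end[0]):
        for j in range(start[1], end[1]):
            for k in range(start[2], end[2]):
                offsets.append(i * strides[0] + j * strides[1] + k * strides[2])
    return offsets


# slice(start=[0, 0, 0], end=[1, 2, 4]) -> reshape [1, 8] -> reshape [8]
offsets = storage_offsets((0, 0, 0), (1, 2, 4), (50, 10, 1))
print(offsets)  # [0, 1, 2, 3, 10, 11, 12, 13]
```

With PyTorch at hand, the same reference is simply `inputs[0][0:1, 0:2, 0:4].reshape(8)`.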

cc'ing @jacobhinkle: this does look like a concretization bug.

I'm back at the keyboard. It looks like this is just a check we should skip, at least in this case: commenting out the call to checkConcretizedUses inside concretizeReshape makes the test pass. I'll push a PR soon.

> A smaller repro.

Thank you!