multi: uneven shard fails with invalid shrink
chenyuxyz opened this issue · comments
chenyu commented
Repro:
# Repro: shard 1-D tensors of length 4, 3 and 2 along axis 0 across 4 devices.
from tinygrad import Tensor, Device

device = [f"{Device.DEFAULT}:{i}" for i in range(4)]
# length 4 over 4 devices (even split): prints [1 2 3 4]
t = Tensor([1, 2, 3, 4]).shard(device, axis=0).realize()
print(f"{t.numpy()}")
# length 3 over 4 devices (uneven split): prints [1 2 3]
t = Tensor([1, 2, 3]).shard(device, axis=0).realize()
print(f"{t.numpy()}")
# length 2 over 4 devices (uneven split): raises
#   AssertionError: invalid shrink ((3, 2),) for (2,)
# multi.to_sharded slices shard i with bounds (sz*i, min(s, sz*(i+1))); only the
# end bound is clamped to the axis size s, so the start for device 3 (3) exceeds s (2).
# NOTE(review): presumably clamping the start as well, (min(s, sz*i), ...), would
# produce empty shards for the trailing devices instead of failing -- unverified.
t = Tensor([1, 2]).shard(device, axis=0).realize()
print(f"{t.numpy()}")
Output:
[1 2 3 4]
[1 2 3]
Traceback (most recent call last):
File "/Users/chenyu/code/tinygrad/test.py", line 18, in <module>
t = Tensor([1, 2]).shard(device, axis=0).realize()
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/chenyu/code/tinygrad/tinygrad/tensor.py", line 328, in shard
return Tensor(MultiLazyBuffer.from_sharded(self.lazydata, canonical_devices, axis), device=canonical_devices, requires_grad=self.requires_grad)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/chenyu/code/tinygrad/tinygrad/multi.py", line 75, in from_sharded
sharded_lbs = [lb.copy_to_device(d) for lb,d in zip(to_sharded(lbs, axis) if axis is not None else lbs, devices)]
^^^^^^^^^^^^^^^^^^^^^
File "/Users/chenyu/code/tinygrad/tinygrad/multi.py", line 48, in to_sharded
return [lb.shrink(tuple((0,s) if a != axis else (sz*i,min(s,sz*(i+1))) for a,s in enumerate(lb.shape))) for i,lb in enumerate(lbs)]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/chenyu/code/tinygrad/tinygrad/multi.py", line 48, in <listcomp>
return [lb.shrink(tuple((0,s) if a != axis else (sz*i,min(s,sz*(i+1))) for a,s in enumerate(lb.shape))) for i,lb in enumerate(lbs)]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/chenyu/code/tinygrad/tinygrad/lazy.py", line 219, in shrink
def shrink(self, arg:Tuple[Tuple[sint, sint], ...]): return self._view(self.st.shrink(arg))
^^^^^^^^^^^^^^^^^^^
File "/Users/chenyu/code/tinygrad/tinygrad/shape/shapetracker.py", line 114, in shrink
def shrink(self, arg: Tuple[Tuple[sint, sint], ...]) -> ShapeTracker: return ShapeTracker(self.views[0:-1] + (self.views[-1].shrink(arg), ))
^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/chenyu/code/tinygrad/tinygrad/shape/view.py", line 232, in shrink
assert all((0<=b<=e<=s) for s,(b,e) in zip(self.shape,arg)) and len(arg) == len(self.shape), f"invalid shrink {arg} for {self.shape}"
^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError: invalid shrink ((3, 2),) for (2,)