tinygrad / tinygrad

You like pytorch? You like micrograd? You love tinygrad! ❤️


multi uneven shard failed with invalid shrink

chenyuxyz opened this issue

repro

from tinygrad import Tensor, Device

# shard along axis 0 across 4 devices: 4 elements (even), 3 and 2 elements (uneven)
device = [f"{Device.DEFAULT}:{i}" for i in range(4)]
t = Tensor([1, 2, 3, 4]).shard(device, axis=0).realize()
print(t.numpy())
t = Tensor([1, 2, 3]).shard(device, axis=0).realize()
print(t.numpy())
t = Tensor([1, 2]).shard(device, axis=0).realize()
print(t.numpy())

output

[1 2 3 4]
[1 2 3]
Traceback (most recent call last):
  File "/Users/chenyu/code/tinygrad/test.py", line 18, in <module>
    t = Tensor([1, 2]).shard(device, axis=0).realize()
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/chenyu/code/tinygrad/tinygrad/tensor.py", line 328, in shard
    return Tensor(MultiLazyBuffer.from_sharded(self.lazydata, canonical_devices, axis), device=canonical_devices, requires_grad=self.requires_grad)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/chenyu/code/tinygrad/tinygrad/multi.py", line 75, in from_sharded
    sharded_lbs = [lb.copy_to_device(d) for lb,d in zip(to_sharded(lbs, axis) if axis is not None else lbs, devices)]
                                                        ^^^^^^^^^^^^^^^^^^^^^
  File "/Users/chenyu/code/tinygrad/tinygrad/multi.py", line 48, in to_sharded
    return [lb.shrink(tuple((0,s) if a != axis else (sz*i,min(s,sz*(i+1))) for a,s in enumerate(lb.shape))) for i,lb in enumerate(lbs)]
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/chenyu/code/tinygrad/tinygrad/multi.py", line 48, in <listcomp>
    return [lb.shrink(tuple((0,s) if a != axis else (sz*i,min(s,sz*(i+1))) for a,s in enumerate(lb.shape))) for i,lb in enumerate(lbs)]
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/chenyu/code/tinygrad/tinygrad/lazy.py", line 219, in shrink
    def shrink(self, arg:Tuple[Tuple[sint, sint], ...]): return self._view(self.st.shrink(arg))
                                                                           ^^^^^^^^^^^^^^^^^^^
  File "/Users/chenyu/code/tinygrad/tinygrad/shape/shapetracker.py", line 114, in shrink
    def shrink(self, arg: Tuple[Tuple[sint, sint], ...]) -> ShapeTracker: return ShapeTracker(self.views[0:-1] + (self.views[-1].shrink(arg), ))
                                                                                                                  ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/chenyu/code/tinygrad/tinygrad/shape/view.py", line 232, in shrink
    assert all((0<=b<=e<=s) for s,(b,e) in zip(self.shape,arg)) and len(arg) == len(self.shape), f"invalid shrink {arg} for {self.shape}"
                                                                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError: invalid shrink ((3, 2),) for (2,)
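For reference, the failing `(begin, end)` arithmetic from `to_sharded` can be replayed in plain Python. This is a sketch: `shard_bounds` and `is_valid_shrink` are hypothetical helpers, and the shard size `sz` is assumed to be `ceil(size / n_devices)`, since the traceback does not show how `sz` is computed. Only the end of each range is clamped to the axis size, so with 2 elements on 4 devices the last shard starts at 3, past the end of the axis.

```python
def shard_bounds(size: int, n_devices: int) -> list[tuple[int, int]]:
    # Assumption: shard size is ceil(size / n_devices).
    sz = -(-size // n_devices)
    # Same (begin, end) formula as to_sharded: only the end is clamped to size.
    return [(sz * i, min(size, sz * (i + 1))) for i in range(n_devices)]

def is_valid_shrink(size: int, begin: int, end: int) -> bool:
    # Mirrors the check in View.shrink: 0 <= b <= e <= s
    return 0 <= begin <= end <= size

for size in (4, 3, 2):
    bounds = shard_bounds(size, 4)
    bad = [b for b in bounds if not is_valid_shrink(size, *b)]
    print(size, bounds, "invalid:", bad)
# 4 [(0, 1), (1, 2), (2, 3), (3, 4)] invalid: []
# 3 [(0, 1), (1, 2), (2, 3), (3, 3)] invalid: []
# 2 [(0, 1), (1, 2), (2, 2), (3, 2)] invalid: [(3, 2)]
```

The 3-element case only survives because its last range is `(3, 3)`, an empty but valid slice; the 2-element case produces `(3, 2)`, which matches the `invalid shrink ((3, 2),) for (2,)` assertion above.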

fixed with #5131
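One way to avoid the invalid range, shown here only as a hedged sketch and not necessarily what #5131 changed, is to clamp the start of each range as well, so trailing shards become empty slices instead of ranges with begin > end. `clamped_shard_bounds` is a hypothetical helper that uses the same assumed `sz = ceil(size / n_devices)`.

```python
def clamped_shard_bounds(size: int, n_devices: int) -> list[tuple[int, int]]:
    # Assumption: shard size is ceil(size / n_devices), as in the sketch of the bug.
    sz = -(-size // n_devices)
    # Clamp both ends so trailing shards become empty (b == e == size) instead of b > e.
    return [(min(size, sz * i), min(size, sz * (i + 1))) for i in range(n_devices)]

for size in (4, 3, 2):
    bounds = clamped_shard_bounds(size, 4)
    # Every range passes the 0 <= b <= e <= s check, and the lengths add back up to size.
    assert all(0 <= b <= e <= size for b, e in bounds)
    assert sum(e - b for b, e in bounds) == size

print(clamped_shard_bounds(2, 4))  # [(0, 1), (1, 2), (2, 2), (2, 2)]
```

With this clamping, the shards for `Tensor([1, 2])` on 4 devices would hold 1, 1, 0 and 0 elements.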