fails to print multi-device arrays
GallagherCommaJack opened this issue · comments
It's not 100% clear what the right behavior is here, but I'm confident the current behavior isn't it.
File ~/.cache/pypoetry/virtualenvs/k-diffusion-jax-SXSjnun_-py3.8/lib/python3.8/site-packages/IPython/core/formatters.py:706, in PlainTextFormatter.__call__(self, obj)
699 stream = StringIO()
700 printer = pretty.RepresentationPrinter(stream, self.verbose,
701 self.max_width, self.newline,
702 max_seq_length=self.max_seq_length,
703 singleton_pprinters=self.singleton_printers,
704 type_pprinters=self.type_printers,
705 deferred_pprinters=self.deferred_printers)
--> 706 printer.pretty(obj)
707 printer.flush()
708 return stream.getvalue()
File ~/.cache/pypoetry/virtualenvs/k-diffusion-jax-SXSjnun_-py3.8/lib/python3.8/site-packages/IPython/lib/pretty.py:393, in RepresentationPrinter.pretty(self, obj)
390 for cls in _get_mro(obj_class):
391 if cls in self.type_pprinters:
392 # printer registered in self.type_pprinters
--> 393 return self.type_pprinters[cls](obj, self, cycle)
394 else:
395 # deferred printer
396 printer = self._in_deferred_types(cls)
File ~/.cache/pypoetry/virtualenvs/k-diffusion-jax-SXSjnun_-py3.8/lib/python3.8/site-packages/IPython/lib/pretty.py:640, in _seq_pprinter_factory.<locals>.inner(obj, p, cycle)
638 p.text(',')
639 p.breakable()
--> 640 p.pretty(x)
641 if len(obj) == 1 and isinstance(obj, tuple):
642 # Special case for 1-item tuples.
643 p.text(',')
File ~/.cache/pypoetry/virtualenvs/k-diffusion-jax-SXSjnun_-py3.8/lib/python3.8/site-packages/IPython/lib/pretty.py:410, in RepresentationPrinter.pretty(self, obj)
407 return meth(obj, self, cycle)
408 if cls is not object \
409 and callable(cls.__dict__.get('__repr__')):
--> 410 return _repr_pprint(obj, self, cycle)
412 return _default_pprint(obj, self, cycle)
413 finally:
File ~/.cache/pypoetry/virtualenvs/k-diffusion-jax-SXSjnun_-py3.8/lib/python3.8/site-packages/IPython/lib/pretty.py:778, in _repr_pprint(obj, p, cycle)
776 """A pprint that just redirects to the normal repr function."""
777 # Find newlines and replace them with p.break_()
--> 778 output = repr(obj)
779 lines = output.splitlines()
780 with p.group():
File ~/.cache/pypoetry/virtualenvs/k-diffusion-jax-SXSjnun_-py3.8/lib/python3.8/site-packages/lovely_jax/patch.py:29, in _monkey_patch.<locals>.__repr__(self)
27 @patch_to(cls)
28 def __repr__(self: jax.Array):
---> 29 return str(StrProxy(self))
File ~/.cache/pypoetry/virtualenvs/k-diffusion-jax-SXSjnun_-py3.8/lib/python3.8/site-packages/lovely_jax/repr_str.py:175, in StrProxy.__repr__(self)
174 def __repr__(self):
--> 175 return to_str(self.x, plain=self.plain, verbose=self.verbose,
176 depth=self.depth, lvl=self.lvl, color=self.color)
File ~/.cache/pypoetry/virtualenvs/k-diffusion-jax-SXSjnun_-py3.8/lib/python3.8/site-packages/lovely_jax/repr_str.py:107, in to_str(x, plain, verbose, depth, lvl, color)
103 shape = str(list(x.shape)) if x.ndim else None
104 type_str = sparse_join([tname, shape], sep="")
--> 107 dev = f"{x.device().platform}:{x.device().id}"
108 dtype = short_dtype(x)
109 # grad_fn = t.grad_fn.name() if t.grad_fn else None
110 # PyTorch does not want you to know, but all `grad_fn``
111 # tensors actuall have `requires_grad=True`` too.
112 # grad = "grad" if t.requires_grad else None
File ~/.cache/pypoetry/virtualenvs/k-diffusion-jax-SXSjnun_-py3.8/lib/python3.8/site-packages/jax/_src/array.py:370, in ArrayImpl.device(self)
368 single_device, = device_set
369 return single_device
--> 370 raise ValueError('Length of devices is greater than 1. '
371 'Please use `.devices()`.')
ValueError: Length of devices is greater than 1. Please use `.devices()`.
Thank you for the report.
I'm working on sharded arrays at the moment and will give you an update soon.
@GallagherCommaJack, sorry it took so long. Could you please confirm that the current git version, installed with
pip install git+https://github.com/xl0/lovely-jax
works for you?
I'm not 100% sure what the best way to display the array layout is, so for now it just lists the devices.
Do you have any ideas for a better way?
The fix is in the latest release, which also fixes the JAX 0.4.8 issues.