xl0 / lovely-jax

JAX Arrays for human consumption

Home Page: https://xl0.github.io/lovely-jax

Geek Repo:Geek Repo

Github PK Tool:Github PK Tool

Fails to print multi-device arrays

GallagherCommaJack opened this issue · comments

It's not 100% clear what the right behavior here is, but I'm confident the current behavior isn't it. Printing an array that is sharded across multiple devices raises the following error:

File ~/.cache/pypoetry/virtualenvs/k-diffusion-jax-SXSjnun_-py3.8/lib/python3.8/site-packages/IPython/core/formatters.py:706, in PlainTextFormatter.__call__(self, obj)
    699 stream = StringIO()
    700 printer = pretty.RepresentationPrinter(stream, self.verbose,
    701     self.max_width, self.newline,
    702     max_seq_length=self.max_seq_length,
    703     singleton_pprinters=self.singleton_printers,
    704     type_pprinters=self.type_printers,
    705     deferred_pprinters=self.deferred_printers)
--> 706 printer.pretty(obj)
    707 printer.flush()
    708 return stream.getvalue()

File ~/.cache/pypoetry/virtualenvs/k-diffusion-jax-SXSjnun_-py3.8/lib/python3.8/site-packages/IPython/lib/pretty.py:393, in RepresentationPrinter.pretty(self, obj)
    390 for cls in _get_mro(obj_class):
    391     if cls in self.type_pprinters:
    392         # printer registered in self.type_pprinters
--> 393         return self.type_pprinters[cls](obj, self, cycle)
    394     else:
    395         # deferred printer
    396         printer = self._in_deferred_types(cls)

File ~/.cache/pypoetry/virtualenvs/k-diffusion-jax-SXSjnun_-py3.8/lib/python3.8/site-packages/IPython/lib/pretty.py:640, in _seq_pprinter_factory.<locals>.inner(obj, p, cycle)
    638         p.text(',')
    639         p.breakable()
--> 640     p.pretty(x)
    641 if len(obj) == 1 and isinstance(obj, tuple):
    642     # Special case for 1-item tuples.
    643     p.text(',')

File ~/.cache/pypoetry/virtualenvs/k-diffusion-jax-SXSjnun_-py3.8/lib/python3.8/site-packages/IPython/lib/pretty.py:410, in RepresentationPrinter.pretty(self, obj)
    407                         return meth(obj, self, cycle)
    408                 if cls is not object \
    409                         and callable(cls.__dict__.get('__repr__')):
--> 410                     return _repr_pprint(obj, self, cycle)
    412     return _default_pprint(obj, self, cycle)
    413 finally:

File ~/.cache/pypoetry/virtualenvs/k-diffusion-jax-SXSjnun_-py3.8/lib/python3.8/site-packages/IPython/lib/pretty.py:778, in _repr_pprint(obj, p, cycle)
    776 """A pprint that just redirects to the normal repr function."""
    777 # Find newlines and replace them with p.break_()
--> 778 output = repr(obj)
    779 lines = output.splitlines()
    780 with p.group():

File ~/.cache/pypoetry/virtualenvs/k-diffusion-jax-SXSjnun_-py3.8/lib/python3.8/site-packages/lovely_jax/patch.py:29, in _monkey_patch.<locals>.__repr__(self)
     27 @patch_to(cls)
     28 def __repr__(self: jax.Array):
---> 29     return str(StrProxy(self))

File ~/.cache/pypoetry/virtualenvs/k-diffusion-jax-SXSjnun_-py3.8/lib/python3.8/site-packages/lovely_jax/repr_str.py:175, in StrProxy.__repr__(self)
    174 def __repr__(self):
--> 175     return to_str(self.x, plain=self.plain, verbose=self.verbose,
    176                   depth=self.depth, lvl=self.lvl, color=self.color)

File ~/.cache/pypoetry/virtualenvs/k-diffusion-jax-SXSjnun_-py3.8/lib/python3.8/site-packages/lovely_jax/repr_str.py:107, in to_str(x, plain, verbose, depth, lvl, color)
    103 shape = str(list(x.shape)) if x.ndim else None
    104 type_str = sparse_join([tname, shape], sep="")
--> 107 dev = f"{x.device().platform}:{x.device().id}"
    108 dtype = short_dtype(x)
    109 # grad_fn = t.grad_fn.name() if t.grad_fn else None
    110 # PyTorch does not want you to know, but all `grad_fn``
    111 # tensors actuall have `requires_grad=True`` too.
    112 # grad = "grad" if t.requires_grad else None 

File ~/.cache/pypoetry/virtualenvs/k-diffusion-jax-SXSjnun_-py3.8/lib/python3.8/site-packages/jax/_src/array.py:370, in ArrayImpl.device(self)
    368   single_device, = device_set
    369   return single_device
--> 370 raise ValueError('Length of devices is greater than 1. '
    371                  'Please use `.devices()`.')

ValueError: Length of devices is greater than 1. Please use `.devices()`.

Thank you for the report.

I'm working on support for sharded arrays at the moment and will give you an update soon.

@GallagherCommaJack, sorry it took so long. Could you please confirm that the current git version, installed with

pip install git+https://github.com/xl0/lovely-jax

works for you?

I'm not 100% sure what the best way to display the array layout is, so for now it just lists the devices.
Do you have any ideas for a better way to show it?

The fix is in the latest release, which also fixes the JAX 0.4.8 issues.