zhuyifei1999 / guppy3

guppy / heapy ported to Python3. It works for real!

Home Page: https://zhuyifei1999.github.io/guppy3/


Usage with JAX

Edenhofer opened this issue

I would like to dig deeper into the memory consumption of a program that uses JAX's numpy. Unfortunately, I cannot get any information out of guppy, because converting the result to a string for printing fails. The error is easy to reproduce with the following simple program:

from jax import numpy as jnp

from guppy import hpy

a = jnp.linspace(0.0, 1.0)

h = hpy()  # succeeds
print(h.heap())

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/usr/lib/python3.9/site-packages/guppy/heapy/UniSet.py", line 349, in __str__
    return self.fam.c_str(self)
  File "/usr/lib/python3.9/site-packages/guppy/heapy/UniSet.py", line 1616, in c_str
    return a.more._oh_printer.get_str_of_top()
  File "/usr/lib/python3.9/site-packages/guppy/heapy/UniSet.py", line 74, in __getattr__
    return self.fam.mod.View.enter(lambda: self.fam.c_getattr(self, other))
  File "/usr/lib/python3.9/site-packages/guppy/heapy/View.py", line 256, in enter
    retval = func()
  File "/usr/lib/python3.9/site-packages/guppy/heapy/UniSet.py", line 74, in <lambda>
    return self.fam.mod.View.enter(lambda: self.fam.c_getattr(self, other))
  File "/usr/lib/python3.9/site-packages/guppy/heapy/UniSet.py", line 800, in c_getattr
    return self.c_getattr2(a, b)
  File "/usr/lib/python3.9/site-packages/guppy/heapy/UniSet.py", line 803, in c_getattr2
    raise AttributeError(b)
AttributeError: more

I know it might be quite a stretch to ask guppy to support custom libraries like JAX, so I would already be happy with insights into how to resolve this myself, or at least how to work around it.

Sorry for the late response. I missed the message.

This took a while to debug. Here is the true error:

  File "/usr/lib/python3.9/runpy.py", line 197, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/usr/lib/python3.9/runpy.py", line 87, in _run_code
    exec(code, run_globals)
  File "/usr/lib/python3.9/pdb.py", line 1750, in <module>
    pdb.main()
  File "/usr/lib/python3.9/pdb.py", line 1723, in main
    pdb._runscript(mainpyfile)
  File "/usr/lib/python3.9/pdb.py", line 1583, in _runscript
    self.run(statement)
  File "/usr/lib/python3.9/bdb.py", line 580, in run
    exec(cmd, globals, locals)
  File "<string>", line 1, in <module>
  File "/home/zhuyifei1999/guppy3/test.py", line 1, in <module>
    from jax import numpy as jnp
  File "/home/zhuyifei1999/guppy3/guppy/heapy/UniSet.py", line 349, in __str__
    return self.fam.c_str(self)
  File "/home/zhuyifei1999/guppy3/guppy/heapy/UniSet.py", line 1623, in c_str
    return a.more._oh_printer.get_str_of_top()
  File "/home/zhuyifei1999/guppy3/guppy/etc/Descriptor.py", line 32, in __get__
    return super().__get__(instance, owner)
  File "/home/zhuyifei1999/guppy3/guppy/heapy/UniSet.py", line 516, in <lambda>
    more = property_exp(lambda self: self.fam.get_more(self), doc="""\
  File "/home/zhuyifei1999/guppy3/guppy/heapy/UniSet.py", line 1704, in get_more
    m = self.mod.OutputHandling.more_printer(a, a.partition)
  File "/home/zhuyifei1999/guppy3/guppy/etc/Descriptor.py", line 32, in __get__
    return super().__get__(instance, owner)
  File "/home/zhuyifei1999/guppy3/guppy/heapy/UniSet.py", line 531, in <lambda>
    partition = property_exp(lambda self: self.fam.get_partition(self), doc="""\
  File "/home/zhuyifei1999/guppy3/guppy/heapy/UniSet.py", line 1719, in get_partition
    p = a.fam.Part.partition(a, a.er)
  File "/home/zhuyifei1999/guppy3/guppy/heapy/Part.py", line 763, in partition
    return SetPartition(self, set, er)
  File "/home/zhuyifei1999/guppy3/guppy/heapy/Part.py", line 672, in __init__
    tosort = [(-part.size, classifier.get_tabrendering(kind, ''), kind, part)
  File "/home/zhuyifei1999/guppy3/guppy/heapy/Part.py", line 672, in <listcomp>
    tosort = [(-part.size, classifier.get_tabrendering(kind, ''), kind, part)
  File "/home/zhuyifei1999/guppy3/guppy/heapy/Classifiers.py", line 91, in get_tabrendering
    return cla.brief
  File "/home/zhuyifei1999/guppy3/guppy/etc/Descriptor.py", line 32, in __get__
    return super().__get__(instance, owner)
  File "/home/zhuyifei1999/guppy3/guppy/heapy/UniSet.py", line 213, in <lambda>
    brief = property_exp(lambda self: self.fam.c_get_brief(self),
  File "/home/zhuyifei1999/guppy3/guppy/heapy/Classifiers.py", line 509, in c_get_brief
    return 'dict of ' + ka.brief
  File "/home/zhuyifei1999/guppy3/guppy/etc/Descriptor.py", line 32, in __get__
    return super().__get__(instance, owner)
  File "/home/zhuyifei1999/guppy3/guppy/heapy/UniSet.py", line 213, in <lambda>
    brief = property_exp(lambda self: self.fam.c_get_brief(self),
  File "/home/zhuyifei1999/guppy3/guppy/heapy/Classifiers.py", line 409, in c_get_brief
    __import__('traceback').print_stack()
enter
Traceback (most recent call last):
  File "/home/zhuyifei1999/guppy3/guppy/heapy/UniSet.py", line 1702, in get_more
    m = a._more
  File "/home/zhuyifei1999/guppy3/guppy/heapy/UniSet.py", line 74, in __getattr__
    return self.fam.mod.View.enter(lambda: self.fam.c_getattr(self, other))
  File "/home/zhuyifei1999/guppy3/guppy/heapy/View.py", line 256, in enter
    retval = func()
  File "/home/zhuyifei1999/guppy3/guppy/heapy/UniSet.py", line 74, in <lambda>
    return self.fam.mod.View.enter(lambda: self.fam.c_getattr(self, other))
  File "/home/zhuyifei1999/guppy3/guppy/heapy/UniSet.py", line 800, in c_getattr
    return self.c_getattr2(a, b)
  File "/home/zhuyifei1999/guppy3/guppy/heapy/UniSet.py", line 803, in c_getattr2
    raise AttributeError(b)
AttributeError: _more

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/home/zhuyifei1999/guppy3/guppy/heapy/UniSet.py", line 1716, in get_partition
    p = a._partition
  File "/home/zhuyifei1999/guppy3/guppy/heapy/UniSet.py", line 74, in __getattr__
    return self.fam.mod.View.enter(lambda: self.fam.c_getattr(self, other))
  File "/home/zhuyifei1999/guppy3/guppy/heapy/View.py", line 256, in enter
    retval = func()
  File "/home/zhuyifei1999/guppy3/guppy/heapy/UniSet.py", line 74, in <lambda>
    return self.fam.mod.View.enter(lambda: self.fam.c_getattr(self, other))
  File "/home/zhuyifei1999/guppy3/guppy/heapy/UniSet.py", line 800, in c_getattr
    return self.c_getattr2(a, b)
  File "/home/zhuyifei1999/guppy3/guppy/heapy/UniSet.py", line 803, in c_getattr2
    raise AttributeError(b)
AttributeError: _partition

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/home/zhuyifei1999/guppy3/guppy/heapy/Classifiers.py", line 412, in c_get_brief
    return self.mod.summary_str(type(a.arg))(a.arg)
  File "/home/zhuyifei1999/guppy3/guppy/heapy/UniSet.py", line 2081, in str_type
    return '%s.%s' % (x.__module__, x.__name__)
AttributeError: __module__
> /home/zhuyifei1999/guppy3/guppy/heapy/UniSet.py(2081)str_type()
-> return '%s.%s' % (x.__module__, x.__name__)
(Pdb) p x
<class 'CompiledFunction'>

The error is hidden because of how Python handles attribute lookup: when a descriptor's __get__ raises an AttributeError and the class also defines __getattr__, Python falls back to __getattr__. If that raises AttributeError too, the second exception masks the first, without the usual 'During handling of the above exception, another exception occurred' chaining.
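The masking can be reproduced without guppy. The following is a minimal sketch: the class `Masked` and its `more` property are hypothetical stand-ins, not guppy's actual code.

```python
class Masked:
    @property
    def more(self):
        # The real failure: an AttributeError raised inside the getter.
        raise AttributeError("__module__")

    def __getattr__(self, name):
        # Python calls this only after normal lookup (including the
        # property above) has failed with AttributeError.
        raise AttributeError(name)


try:
    Masked().more
except AttributeError as exc:
    caught = exc

print(caught)              # more
print(caught.__context__)  # None
```

The original `AttributeError: __module__` is cleared before `__getattr__` runs, so it does not survive even as `__context__`.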

I am afraid I know too little about guppy to fully interpret this error message. Am I correct in my interpretation that guppy aborts because the class 'CompiledFunction' is missing a __module__ attribute, which in turn is required by IdentitySetFamily.get_more?

CompiledFunction is a type defined in JAX. Its PyTypeObject looks like this:

gef➤  p *(PyTypeObject *)93824997935776
$2 = {
  ob_base = {
    ob_base = {
      ob_refcnt = 0x120,
      ob_type = 0x7ffff7f76040 <PyType_Type>
    },
    ob_size = 0x0
  },
  tp_name = 0x7ffff4ea0652 "CompiledFunction",
  tp_basicsize = 0xc8,
  tp_itemsize = 0x0,
  tp_dealloc = 0x7fffef683e80 <JaxCompiledFunction_tp_dealloc>,
  tp_vectorcall_offset = 0x0,
  tp_getattr = 0x0,
  tp_setattr = 0x0,
  tp_as_async = 0x0,
  tp_repr = 0x7fffef685950 <JaxCompiledFunction_tp_repr>,
  tp_as_number = 0x0,
  tp_as_sequence = 0x0,
  tp_as_mapping = 0x0,
  tp_hash = 0x7ffff7c2bfc0 <_Py_HashPointer>,
  tp_call = 0x7fffef69d310 <JaxCompiledFunction_tp_call>,
  tp_str = 0x7ffff7c2a8a0 <object_str>,
  tp_getattro = 0x7ffff7db5960 <PyObject_GenericGetAttr>,
  tp_setattro = 0x7ffff7db4620 <PyObject_GenericSetAttr>,
  tp_as_buffer = 0x0,
  tp_flags = 0xc5200,
  tp_doc = 0x0,
  tp_traverse = 0x7fffef682870 <JaxCompiledFunction_tp_traverse>,
  tp_clear = 0x7fffef682fb0 <JaxCompiledFunction_tp_clear>,
  tp_richcompare = 0x7ffff7c3a440 <object_richcompare>,
  tp_weaklistoffset = 0x18,
  tp_iter = 0x0,
  tp_iternext = 0x0,
  tp_methods = 0x0,
  tp_members = 0x0,
  tp_getset = 0x7ffff62ff140 <jax::(anonymous namespace)::JaxCompiledFunction_tp_getset>,
  tp_base = 0x7ffff7f76380 <PyBaseObject_Type>,
  tp_dict = {'__repr__': <wrapper_descriptor at remote 0x7fffda1b56d0>, '__call__': <wrapper_descriptor at remote 0x7fffda1b5720>, '__get__': <wrapper_descriptor at remote 0x7fffda1b5770>, '__new__': <built-in method __new__ of type object at remote 0x555555ac4aa0>, '__dict__': <getset_descriptor at remote 0x7fffda1b3e80>, '__doc__': None, '__signature__': <property at remote 0x7fffda1b5900>, '_cache_miss': <property at remote 0x7fffda1b59a0>, '__getstate__': <instancemethod at remote 0x7fffda1b2ca0>, '__setstate__': <instancemethod at remote 0x7fffda1b2d00>, '_cache_size': <instancemethod at remote 0x7fffda1b7820>, '_clear_cache': <instancemethod at remote 0x7fffda1b7880>},
  tp_descr_get = 0x7fffef682840 <JaxCompiledFunction_tp_descr_get>,
  tp_descr_set = 0x0,
  tp_dictoffset = 0x10,
  tp_init = 0x7ffff7cebe80 <object_init>,
  tp_alloc = 0x7ffff7cc70c0 <PyType_GenericAlloc>,
  tp_new = 0x7fffef682560 <JaxCompiledFunction_tp_new>,
  tp_free = 0x7ffff7c2cc00 <PyObject_GC_Del>,
  tp_is_gc = 0x0,
  tp_bases = (<type at remote 0x7ffff7f76380>,),
  tp_mro = (<type at remote 0x555555ac4aa0>, <type at remote 0x7ffff7f76380>),
  tp_cache = 0x0,
  tp_subclasses = 0x0,
  tp_weaklist = <weakref at remote 0x7fffda1b5810>,
  tp_del = 0x0,
  tp_version_tag = 0xaaa,
  tp_finalize = 0x0,
  tp_vectorcall = 0x0
}

It seems like the relevant source code is https://github.com/tensorflow/tensorflow/blob/ebdbd61a48f8d9e438fd02fa31eb4075b25bd278/tensorflow/compiler/xla/python/jax_jit.cc#L1202-L1293

This is a heap type, and the descriptor for a type's __module__ is https://github.com/python/cpython/blob/5a2a65096c3ec2d37f33615f2a420d2ffcabecf2/Objects/typeobject.c#L567-L596

For a heap type, the descriptor reads the type's __dict__['__module__'] and raises AttributeError if that key does not exist. (For a non-heap type, it looks for a "." in tp_name and returns the part before the last "."; if there is no ".", it returns 'builtins'.)
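Both behaviours can be observed from pure Python. This is a hedged sketch: the constant name `HEAPTYPE` and the class `NoModule` are made up for illustration, and the `eval` call is only a trick to obtain a heap type whose dict lacks `__module__`; it is not how JAX creates `CompiledFunction`.

```python
HEAPTYPE = 1 << 9  # value of Py_TPFLAGS_HEAPTYPE in CPython's object.h

# The tp_flags value from the gdb dump has the heap-type bit set.
print(bool(0xC5200 & HEAPTYPE))  # True

# A static builtin type is not a heap type; its __module__ comes from tp_name.
print(bool(int.__flags__ & HEAPTYPE), int.__module__)  # False builtins

# type() copies __name__ from the calling frame's globals into the new
# type's dict as __module__. With globals that lack __name__, the key
# is simply left out.
NoModule = eval("type('NoModule', (), {})", {})

print(bool(NoModule.__flags__ & HEAPTYPE))  # True
print("__module__" in vars(NoModule))       # False
print(hasattr(NoModule, "__module__"))      # False
```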

The docs do not say whether the __module__ attribute is actually optional, but the source code of the inspect module seems to suggest that it is. Let me ask if they are willing to add __module__ to the type's dict; otherwise I'll add a catch here.

> I am afraid I know too little about guppy to fully interpret this error message. Am I correct in my interpretation that guppy aborts because the class 'CompiledFunction' is missing a __module__ attribute, which in turn is required by IdentitySetFamily.get_more?

Yes, that is correct. For example, if I do:

diff --git a/guppy/heapy/UniSet.py b/guppy/heapy/UniSet.py
index 216c22a..9c004b7 100644
--- a/guppy/heapy/UniSet.py
+++ b/guppy/heapy/UniSet.py
@@ -2071,6 +2071,8 @@ class Summary_str:
             return self.shorter_invtypes[x]
         if x in self.invtypes:
             return self.invtypes[x]
+        if not hasattr(x, '__module__'):
+            return f'<unknown module>.{x.__name__}'
         return '%s.%s' % (x.__module__, x.__name__)
     str_type._idpart_header = 'Name'
 

And the test:

from jax import numpy as jnp

from guppy import hpy

a = jnp.linspace(0.0, 1.0)

h = hpy()  # succeeds
print(h.heap().all)

I get: https://gist.github.com/zhuyifei1999/3fc04e7e4e4ab0dbfc29b2cf4426f63a

> Let me ask if they are willing to add __module__ to the type's dict; otherwise I'll add a catch here.

Awesome! Would you mind referencing the issue here as well? I would like to keep track of it as well.

> The docs do not say whether the __module__ attribute is actually optional, but the source code of the inspect module seems to suggest that it is.

Considering that you pointed out that the inspect module treats __module__ as optional, would it not make sense to implement a fall-back mechanism e.g. similar to the one in inspect (maybe even with __file__) regardless of whether tensorflow adds the __module__ attribute to their CompiledFunction?

> Awesome! Would you mind referencing the issue here as well? I would like to keep track of it as well.

tensorflow/tensorflow@f21d21f 😉

> Considering that you pointed out that the inspect module treats __module__ as optional, would it not make sense to implement a fall-back mechanism e.g. similar to the one in inspect (maybe even with __file__) regardless of whether tensorflow adds the __module__ attribute to their CompiledFunction?

Hmm. Would an <unknown module>. prefix, like what I did above, make sense?

> > Awesome! Would you mind referencing the issue here as well? I would like to keep track of it as well.
>
> tensorflow/tensorflow@f21d21f 😉

That was fast! :)

> Hmm. Would an <unknown module>. prefix, like what I did above, make sense?

I think <unknown module>. would be a good placeholder for anything that does not set __module__.

> I think <unknown module>. would be a good placeholder for anything that does not set __module__.

Done.
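
For reference, the fallback agreed on above can be sketched outside guppy roughly as follows. This is only an illustration under stated assumptions: the helper name `qualified_type_name` is hypothetical (it is not guppy's `str_type` method), and the `eval` trick just produces a stand-in type without `__module__`, unrelated to how JAX builds its real `CompiledFunction`.

```python
def qualified_type_name(x):
    """Return 'module.name' for a type, tolerating a missing __module__."""
    # getattr with a default covers heap types whose dict has no __module__.
    module = getattr(x, "__module__", "<unknown module>")
    return f"{module}.{x.__name__}"


# Stand-in for a type without __module__: type() finds no __name__ in
# the empty globals, so it does not set the attribute.
FakeCompiledFunction = eval("type('CompiledFunction', (), {})", {})

print(qualified_type_name(FakeCompiledFunction))  # <unknown module>.CompiledFunction
print(qualified_type_name(dict))                  # builtins.dict
```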