Usage with JAX
Edenhofer opened this issue · comments
I would like to dive deeper into the memory consumption of a program that uses JAX's numpy. Unfortunately, I cannot get any information out of guppy, because formatting the result fails when it is converted to a string and printed. The error is easy to reproduce with the following simple program:
from jax import numpy as jnp
from guppy import hpy
a = jnp.linspace(0.0, 1.0)
h = hpy() # succeeds
print(h.heap())
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "/usr/lib/python3.9/site-packages/guppy/heapy/UniSet.py", line 349, in __str__
return self.fam.c_str(self)
File "/usr/lib/python3.9/site-packages/guppy/heapy/UniSet.py", line 1616, in c_str
return a.more._oh_printer.get_str_of_top()
File "/usr/lib/python3.9/site-packages/guppy/heapy/UniSet.py", line 74, in __getattr__
return self.fam.mod.View.enter(lambda: self.fam.c_getattr(self, other))
File "/usr/lib/python3.9/site-packages/guppy/heapy/View.py", line 256, in enter
retval = func()
File "/usr/lib/python3.9/site-packages/guppy/heapy/UniSet.py", line 74, in <lambda>
return self.fam.mod.View.enter(lambda: self.fam.c_getattr(self, other))
File "/usr/lib/python3.9/site-packages/guppy/heapy/UniSet.py", line 800, in c_getattr
return self.c_getattr2(a, b)
File "/usr/lib/python3.9/site-packages/guppy/heapy/UniSet.py", line 803, in c_getattr2
raise AttributeError(b)
AttributeError: more
I know it might be quite a stretch to ask guppy to support third-party libraries like JAX, so I would already be happy with insights into how to resolve it myself, or at least how to work around it.
Sorry for the late response. I missed the message.
This took a bit to debug and this is the true error:
File "/usr/lib/python3.9/runpy.py", line 197, in _run_module_as_main
return _run_code(code, main_globals, None,
File "/usr/lib/python3.9/runpy.py", line 87, in _run_code
exec(code, run_globals)
File "/usr/lib/python3.9/pdb.py", line 1750, in <module>
pdb.main()
File "/usr/lib/python3.9/pdb.py", line 1723, in main
pdb._runscript(mainpyfile)
File "/usr/lib/python3.9/pdb.py", line 1583, in _runscript
self.run(statement)
File "/usr/lib/python3.9/bdb.py", line 580, in run
exec(cmd, globals, locals)
File "<string>", line 1, in <module>
File "/home/zhuyifei1999/guppy3/test.py", line 1, in <module>
from jax import numpy as jnp
File "/home/zhuyifei1999/guppy3/guppy/heapy/UniSet.py", line 349, in __str__
return self.fam.c_str(self)
File "/home/zhuyifei1999/guppy3/guppy/heapy/UniSet.py", line 1623, in c_str
return a.more._oh_printer.get_str_of_top()
File "/home/zhuyifei1999/guppy3/guppy/etc/Descriptor.py", line 32, in __get__
return super().__get__(instance, owner)
File "/home/zhuyifei1999/guppy3/guppy/heapy/UniSet.py", line 516, in <lambda>
more = property_exp(lambda self: self.fam.get_more(self), doc="""\
File "/home/zhuyifei1999/guppy3/guppy/heapy/UniSet.py", line 1704, in get_more
m = self.mod.OutputHandling.more_printer(a, a.partition)
File "/home/zhuyifei1999/guppy3/guppy/etc/Descriptor.py", line 32, in __get__
return super().__get__(instance, owner)
File "/home/zhuyifei1999/guppy3/guppy/heapy/UniSet.py", line 531, in <lambda>
partition = property_exp(lambda self: self.fam.get_partition(self), doc="""\
File "/home/zhuyifei1999/guppy3/guppy/heapy/UniSet.py", line 1719, in get_partition
p = a.fam.Part.partition(a, a.er)
File "/home/zhuyifei1999/guppy3/guppy/heapy/Part.py", line 763, in partition
return SetPartition(self, set, er)
File "/home/zhuyifei1999/guppy3/guppy/heapy/Part.py", line 672, in __init__
tosort = [(-part.size, classifier.get_tabrendering(kind, ''), kind, part)
File "/home/zhuyifei1999/guppy3/guppy/heapy/Part.py", line 672, in <listcomp>
tosort = [(-part.size, classifier.get_tabrendering(kind, ''), kind, part)
File "/home/zhuyifei1999/guppy3/guppy/heapy/Classifiers.py", line 91, in get_tabrendering
return cla.brief
File "/home/zhuyifei1999/guppy3/guppy/etc/Descriptor.py", line 32, in __get__
return super().__get__(instance, owner)
File "/home/zhuyifei1999/guppy3/guppy/heapy/UniSet.py", line 213, in <lambda>
brief = property_exp(lambda self: self.fam.c_get_brief(self),
File "/home/zhuyifei1999/guppy3/guppy/heapy/Classifiers.py", line 509, in c_get_brief
return 'dict of ' + ka.brief
File "/home/zhuyifei1999/guppy3/guppy/etc/Descriptor.py", line 32, in __get__
return super().__get__(instance, owner)
File "/home/zhuyifei1999/guppy3/guppy/heapy/UniSet.py", line 213, in <lambda>
brief = property_exp(lambda self: self.fam.c_get_brief(self),
File "/home/zhuyifei1999/guppy3/guppy/heapy/Classifiers.py", line 409, in c_get_brief
__import__('traceback').print_stack()
enter
Traceback (most recent call last):
File "/home/zhuyifei1999/guppy3/guppy/heapy/UniSet.py", line 1702, in get_more
m = a._more
File "/home/zhuyifei1999/guppy3/guppy/heapy/UniSet.py", line 74, in __getattr__
return self.fam.mod.View.enter(lambda: self.fam.c_getattr(self, other))
File "/home/zhuyifei1999/guppy3/guppy/heapy/View.py", line 256, in enter
retval = func()
File "/home/zhuyifei1999/guppy3/guppy/heapy/UniSet.py", line 74, in <lambda>
return self.fam.mod.View.enter(lambda: self.fam.c_getattr(self, other))
File "/home/zhuyifei1999/guppy3/guppy/heapy/UniSet.py", line 800, in c_getattr
return self.c_getattr2(a, b)
File "/home/zhuyifei1999/guppy3/guppy/heapy/UniSet.py", line 803, in c_getattr2
raise AttributeError(b)
AttributeError: _more
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "/home/zhuyifei1999/guppy3/guppy/heapy/UniSet.py", line 1716, in get_partition
p = a._partition
File "/home/zhuyifei1999/guppy3/guppy/heapy/UniSet.py", line 74, in __getattr__
return self.fam.mod.View.enter(lambda: self.fam.c_getattr(self, other))
File "/home/zhuyifei1999/guppy3/guppy/heapy/View.py", line 256, in enter
retval = func()
File "/home/zhuyifei1999/guppy3/guppy/heapy/UniSet.py", line 74, in <lambda>
return self.fam.mod.View.enter(lambda: self.fam.c_getattr(self, other))
File "/home/zhuyifei1999/guppy3/guppy/heapy/UniSet.py", line 800, in c_getattr
return self.c_getattr2(a, b)
File "/home/zhuyifei1999/guppy3/guppy/heapy/UniSet.py", line 803, in c_getattr2
raise AttributeError(b)
AttributeError: _partition
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "/home/zhuyifei1999/guppy3/guppy/heapy/Classifiers.py", line 412, in c_get_brief
return self.mod.summary_str(type(a.arg))(a.arg)
File "/home/zhuyifei1999/guppy3/guppy/heapy/UniSet.py", line 2081, in str_type
return '%s.%s' % (x.__module__, x.__name__)
AttributeError: __module__
> /home/zhuyifei1999/guppy3/guppy/heapy/UniSet.py(2081)str_type()
-> return '%s.%s' % (x.__module__, x.__name__)
(Pdb) p x
<class 'CompiledFunction'>
The real error is hidden for the following reason. A descriptor's __get__ (here, the `more` property) raises an AttributeError. The class also defines __getattr__, so Python falls back to it, and it raises its own AttributeError. That second exception masks the first one, and no 'During handling of the above exception, another exception occurred' chain is shown.
I am afraid I know very little about guppy, so I cannot fully interpret this error message. Am I correct that guppy aborts because class 'CompiledFunction' is missing a __module__ attribute, which in turn is required by IdentitySetFamily.get_more?
CompiledFunction is defined in JAX (as a C extension type); this is its PyTypeObject:
gef➤ p *(PyTypeObject *)93824997935776
$2 = {
ob_base = {
ob_base = {
ob_refcnt = 0x120,
ob_type = 0x7ffff7f76040 <PyType_Type>
},
ob_size = 0x0
},
tp_name = 0x7ffff4ea0652 "CompiledFunction",
tp_basicsize = 0xc8,
tp_itemsize = 0x0,
tp_dealloc = 0x7fffef683e80 <JaxCompiledFunction_tp_dealloc>,
tp_vectorcall_offset = 0x0,
tp_getattr = 0x0,
tp_setattr = 0x0,
tp_as_async = 0x0,
tp_repr = 0x7fffef685950 <JaxCompiledFunction_tp_repr>,
tp_as_number = 0x0,
tp_as_sequence = 0x0,
tp_as_mapping = 0x0,
tp_hash = 0x7ffff7c2bfc0 <_Py_HashPointer>,
tp_call = 0x7fffef69d310 <JaxCompiledFunction_tp_call>,
tp_str = 0x7ffff7c2a8a0 <object_str>,
tp_getattro = 0x7ffff7db5960 <PyObject_GenericGetAttr>,
tp_setattro = 0x7ffff7db4620 <PyObject_GenericSetAttr>,
tp_as_buffer = 0x0,
tp_flags = 0xc5200,
tp_doc = 0x0,
tp_traverse = 0x7fffef682870 <JaxCompiledFunction_tp_traverse>,
tp_clear = 0x7fffef682fb0 <JaxCompiledFunction_tp_clear>,
tp_richcompare = 0x7ffff7c3a440 <object_richcompare>,
tp_weaklistoffset = 0x18,
tp_iter = 0x0,
tp_iternext = 0x0,
tp_methods = 0x0,
tp_members = 0x0,
tp_getset = 0x7ffff62ff140 <jax::(anonymous namespace)::JaxCompiledFunction_tp_getset>,
tp_base = 0x7ffff7f76380 <PyBaseObject_Type>,
tp_dict = {'__repr__': <wrapper_descriptor at remote 0x7fffda1b56d0>, '__call__': <wrapper_descriptor at remote 0x7fffda1b5720>, '__get__': <wrapper_descriptor at remote 0x7fffda1b5770>, '__new__': <built-in method __new__ of type object at remote 0x555555ac4aa0>, '__dict__': <getset_descriptor at remote 0x7fffda1b3e80>, '__doc__': None, '__signature__': <property at remote 0x7fffda1b5900>, '_cache_miss': <property at remote 0x7fffda1b59a0>, '__getstate__': <instancemethod at remote 0x7fffda1b2ca0>, '__setstate__': <instancemethod at remote 0x7fffda1b2d00>, '_cache_size': <instancemethod at remote 0x7fffda1b7820>, '_clear_cache': <instancemethod at remote 0x7fffda1b7880>},
tp_descr_get = 0x7fffef682840 <JaxCompiledFunction_tp_descr_get>,
tp_descr_set = 0x0,
tp_dictoffset = 0x10,
tp_init = 0x7ffff7cebe80 <object_init>,
tp_alloc = 0x7ffff7cc70c0 <PyType_GenericAlloc>,
tp_new = 0x7fffef682560 <JaxCompiledFunction_tp_new>,
tp_free = 0x7ffff7c2cc00 <PyObject_GC_Del>,
tp_is_gc = 0x0,
tp_bases = (<type at remote 0x7ffff7f76380>,),
tp_mro = (<type at remote 0x555555ac4aa0>, <type at remote 0x7ffff7f76380>),
tp_cache = 0x0,
tp_subclasses = 0x0,
tp_weaklist = <weakref at remote 0x7fffda1b5810>,
tp_del = 0x0,
tp_version_tag = 0xaaa,
tp_finalize = 0x0,
tp_vectorcall = 0x0
}
It seems like the relevant source code is https://github.com/tensorflow/tensorflow/blob/ebdbd61a48f8d9e438fd02fa31eb4075b25bd278/tensorflow/compiler/xla/python/jax_jit.cc#L1202-L1293
This is a heap type. The getter for a type's __module__ is https://github.com/python/cpython/blob/5a2a65096c3ec2d37f33615f2a420d2ffcabecf2/Objects/typeobject.c#L567-L596
For a heap type, it reads the type's __dict__['__module__'] and raises AttributeError if that key does not exist. (For a non-heap type, it looks for a "." in the type's name and returns 'builtins' if there is none.)
The docs do not say whether the __module__ attribute is actually optional, but the source code of the inspect module suggests that it is. I'll ask whether they are willing to add __module__ to the type's dict; otherwise, I'll add a catch for it here.
I am afraid I know very little about guppy, so I cannot fully interpret this error message. Am I correct that guppy aborts because
class 'CompiledFunction'
is missing a __module__
attribute, which in turn is required by IdentitySetFamily.get_more
?
Yes, that is correct. For example, if I do:
diff --git a/guppy/heapy/UniSet.py b/guppy/heapy/UniSet.py
index 216c22a..9c004b7 100644
--- a/guppy/heapy/UniSet.py
+++ b/guppy/heapy/UniSet.py
@@ -2071,6 +2071,8 @@ class Summary_str:
return self.shorter_invtypes[x]
if x in self.invtypes:
return self.invtypes[x]
+ if not hasattr(x, '__module__'):
+ return f'<unknown module>.{x.__name__}'
return '%s.%s' % (x.__module__, x.__name__)
str_type._idpart_header = 'Name'
And the test:
from jax import numpy as jnp
from guppy import hpy
a = jnp.linspace(0.0, 1.0)
h = hpy() # succeeds
print(h.heap().all)
I get: https://gist.github.com/zhuyifei1999/3fc04e7e4e4ab0dbfc29b2cf4426f63a
I'll ask whether they are willing to add
__module__
to the type's dict; otherwise, I'll add a catch for it here.
Awesome! Would you mind referencing the issue here as well? I would like to keep track of it as well.
The docs do not say whether the
__module__
attribute is actually optional, but the source code of the inspect module suggests that it is.
Since you pointed out that the inspect module treats __module__ as optional, would it not make sense to implement a fallback mechanism, e.g. similar to the one in inspect (maybe even using __file__), regardless of whether TensorFlow adds the __module__ attribute to its CompiledFunction?
Awesome! Would you mind referencing the issue here as well? I would like to keep track of it as well.
tensorflow/tensorflow@f21d21f 😉
Since you pointed out that the inspect module treats
__module__
as optional, would it not make sense to implement a fallback mechanism, e.g. similar to the one in inspect (maybe even using __file__
), regardless of whether TensorFlow adds the __module__
attribute to its CompiledFunction
?
Hmm. Would an <unknown module>.
prefix, like what I did above, make sense?
Awesome! Would you mind referencing the issue here as well? I would like to keep track of it as well.
That was fast! :)
Hmm. Would an
<unknown module>.
prefix, like what I did above, make sense?
I think <unknown module>.
would be a good placeholder for anything that does not set __module__.
I think
<unknown module>.
would be a good placeholder for anything that does not set __module__.
Done.