google / orbax

Orbax provides common utility libraries for JAX users.

Home Page: https://orbax.readthedocs.io/

Orbax cannot save NumPy arrays with dtype=np.object_

amifalk opened this issue

Are there any plans to support object-dtype arrays in Orbax? TensorStore can already represent strings: https://google.github.io/tensorstore/python/api/tensorstore.string.html

I realize I could pull the object arrays out into a JSON file and stitch them back together after loading, but that isn't ergonomic, because the arrays are logically connected to the rest of the saved data in my workflow.

import os

import numpy as np
import orbax.checkpoint as ocp

test = {'a': np.array([True, False, np.nan], dtype=np.object_),
        'b': np.array(['x', 'y', 'z'], dtype=np.object_)}

ckptr = ocp.StandardCheckpointer()
ckptr.save(f'{os.getcwd()}/test', test)

ValueError: Error parsing object member "dtype": Unsupported data type: "object" [source locations='tensorstore/internal/json_binding/json_binding.h:383\ntensorstore/internal/json_binding/json_binding.h:524\ntensorstore/internal/json_binding/json_binding.h:861\ntensorstore/internal/json_binding/json_binding.h:825']
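The JSON workaround mentioned in the issue (pulling the object arrays out and stitching them back together after loading) can be sketched in plain NumPy. This is a minimal sketch, assuming a flat dict like `test` above whose object arrays hold only JSON-serializable Python values; the helper names `split_object_leaves` and `stitch_object_leaves` are hypothetical and not part of Orbax.

```python
import json

import numpy as np


def split_object_leaves(tree: dict) -> tuple[dict, dict]:
    """Split a flat dict into (regular leaves, JSON-ready object arrays).

    Hypothetical helper: assumes a flat dict and that every element of an
    object-dtype array is a JSON-serializable Python value.
    """
    regular, objects = {}, {}
    for key, value in tree.items():
        if isinstance(value, np.ndarray) and value.dtype == np.object_:
            # Record the shape so the array can be rebuilt exactly.
            objects[key] = {"shape": list(value.shape),
                            "values": value.ravel().tolist()}
        else:
            regular[key] = value
    return regular, objects


def stitch_object_leaves(regular: dict, objects: dict) -> dict:
    """Rebuild the object-dtype arrays and merge them back into one dict."""
    merged = dict(regular)
    for key, entry in objects.items():
        flat = np.empty(len(entry["values"]), dtype=object)
        for i, item in enumerate(entry["values"]):
            # Element-wise assignment stops NumPy from re-inferring a shape.
            flat[i] = item
        merged[key] = flat.reshape(tuple(entry["shape"]))
    return merged


test = {"a": np.array([True, False, np.nan], dtype=np.object_),
        "b": np.array(["x", "y", "z"], dtype=np.object_),
        "c": np.arange(3.0)}
regular, objects = split_object_leaves(test)
# `regular` could go to the checkpointer and `objects` to a JSON file.
restored = stitch_object_leaves(regular, json.loads(json.dumps(objects)))
print(restored["a"], restored["b"], restored["c"])
```

The two halves have to be saved and loaded separately, which is exactly the ergonomic cost the issue describes; nested pytrees would also need a tree-flattening step that this sketch leaves out.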

I recognize it's not the most convenient solution, but you could also implement a custom TypeHandler for this. It would be a relatively simple override of the existing NumpyHandler.
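A handler along those lines would need to turn an object array into something TensorStore accepts, and back again. One possible conversion packs the array's shape and elements as UTF-8 JSON bytes into a plain `uint8` array. This is a sketch under the assumption that every element is a JSON-serializable Python value. The function names are hypothetical, and the Orbax-specific parts (subclassing `NumpyHandler` and registering the handler) are left out because their method signatures vary between Orbax versions. The same two functions could instead be applied to the tree by hand before saving and after restoring.

```python
import json

import numpy as np


def encode_object_array(arr: np.ndarray) -> np.ndarray:
    """Pack an object-dtype array into a uint8 array of UTF-8 JSON bytes."""
    if arr.dtype != np.object_:
        raise TypeError(f"expected dtype=object, got {arr.dtype}")
    payload = {"shape": list(arr.shape), "values": arr.ravel().tolist()}
    # json writes float('nan') as the non-standard token NaN and reads it back.
    data = json.dumps(payload).encode("utf-8")
    # frombuffer returns a read-only view; copy so the result owns its memory.
    return np.frombuffer(data, dtype=np.uint8).copy()


def decode_object_array(buf: np.ndarray) -> np.ndarray:
    """Inverse of encode_object_array."""
    text = np.asarray(buf, dtype=np.uint8).tobytes().decode("utf-8")
    payload = json.loads(text)
    flat = np.empty(len(payload["values"]), dtype=object)
    for i, item in enumerate(payload["values"]):
        flat[i] = item  # element-wise, so nested lists stay single elements
    return flat.reshape(tuple(payload["shape"]))


original = np.array([True, False, np.nan], dtype=np.object_)
packed = encode_object_array(original)  # an ordinary numeric array
restored = decode_object_array(packed)
print(packed.dtype, restored)
```

Storing bytes rather than a TensorStore string array keeps the on-disk data in a dtype the existing numeric path already handles, at the cost of the data not being readable as strings outside this decoder.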