google / flax

Flax is a neural network library for JAX that is designed for flexibility.

Home Page: https://flax.readthedocs.io


flax.core.FrozenDict.copy breaks when the new dictionary contains certain key names (e.g. 'cls')

marksandler2 opened this issue


Problem you have encountered:

Calling copy() with a dictionary that contains a 'cls' key fails.

What you expected to happen:

Expected copy() to add or update the 'cls' key.

Logs, error messages, etc:

Steps to reproduce:

flax.core.FrozenDict({}).copy({'cls': 'abc'})

One workaround is to build the merged FrozenDict manually instead of using copy():

flax.core.FrozenDict({**flax.core.FrozenDict({'def': '123', 'cls': 22}), **{'cls': 'abc'}})
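The workaround succeeds because, in a dict display such as {**a, **b}, every key is copied into one new dict and the right-hand mapping wins on duplicate keys; nothing is passed as a keyword argument. A plain-dict sketch of the same merge (no Flax involved; the values simply mirror the workaround above):

```python
original = {'def': '123', 'cls': 22}
updates = {'cls': 'abc'}

# ** unpacking inside a dict display builds a brand-new dict. Keys from
# `updates` overwrite matching keys from `original`; 'cls' is only ever a
# dict key here, never a keyword argument, so it cannot clash with anything.
merged = {**original, **updates}

print(merged)    # {'def': '123', 'cls': 'abc'}
print(original)  # {'def': '123', 'cls': 22}  (the input is not modified)
```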

Thanks for catching this bug!

With your snippet, line 98 (in FrozenDict.copy) resolves to: return FrozenDict(self, cls='abc'). This invokes the __new__ of a superclass from the typing library, whose first parameter is named cls. Since you also pass cls as a keyword argument, the interpreter complains that the same argument was passed twice.
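The same failure can be reproduced in plain Python. In this sketch, Base is a hypothetical stand-in (not Flax's or typing's actual class) whose __new__ takes cls as an ordinary first parameter; a keyword argument named cls then collides with it, while any other key name goes through.

```python
class Base:
    """Hypothetical stand-in: __new__ names its first parameter `cls`."""

    def __new__(cls, *args, **kwargs):
        # The interpreter binds the class itself to `cls` positionally.
        return super().__new__(cls)

    def __init__(self, *args, **kwargs):
        self.kwargs = kwargs


def construction_error(**kwargs):
    """Return the TypeError message from Base({}, **kwargs), or None if it works."""
    try:
        Base({}, **kwargs)
    except TypeError as err:
        return str(err)
    return None


print(construction_error(name='abc'))  # None
print(construction_error(cls='abc'))   # message ends with: got multiple values for argument 'cls'
```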

It seems dangerous to me that the current code expands every key-value pair in add_or_replace into keyword arguments to the constructor, since any key that happens to match a constructor parameter name (such as cls) can trigger bugs like this one. The safest fix seems to be to merge the two dicts explicitly into a new dict, i.e. replace line 98 with:

return type(self)({**self, **unfreeze(add_or_replace)})
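A minimal sketch of the proposed fix, assuming a hypothetical MiniFrozenDict (a tiny collections.abc.Mapping subclass, not Flax's real FrozenDict) and using dict() in place of Flax's unfreeze. Because copy() merges both mappings into a single dict that is passed positionally, keys such as 'cls' or 'self', and even non-string keys, are treated like any other key.

```python
from collections.abc import Mapping


class MiniFrozenDict(Mapping):
    """Hypothetical read-only mapping (no __setitem__), for illustration only."""

    def __init__(self, *args, **kwargs):
        self._dict = dict(*args, **kwargs)

    def __getitem__(self, key):
        return self._dict[key]

    def __iter__(self):
        return iter(self._dict)

    def __len__(self):
        return len(self._dict)

    def copy(self, add_or_replace=()):
        # Merge into one new dict and pass it positionally, so the keys of
        # add_or_replace are never turned into keyword arguments.
        # dict() stands in for Flax's unfreeze() in this sketch.
        return type(self)({**self, **dict(add_or_replace)})


base = MiniFrozenDict({'def': '123', 'cls': 22})
updated = base.copy({'cls': 'abc', 'self': 1, 0: 'zero'})

print(dict(updated))  # {'def': '123', 'cls': 'abc', 'self': 1, 0: 'zero'}
print(dict(base))     # {'def': '123', 'cls': 22}
```

With the kwargs-expansion pattern, the non-string key 0 would also fail, because keyword argument names must be strings; the explicit merge has no such restriction.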

@jheek WDYT?


Oh my Python! :)

@marcvanzee your solution looks like the easiest workaround