google / fedjax

FedJAX is a JAX-based open source library for Federated Learning simulations that emphasizes ease-of-use in research.

Support for haiku models with non-trainable state

marcociccone opened this issue

Hi!
Congrats on this great library! I started using it a few days ago and I love it!

Is there any way to use a Haiku model with non-trainable state (e.g., for batch norm)?
I couldn't find a straightforward way, but maybe I'm missing something.

Thanks a lot for your help!

Thanks for the feedback. Currently, we do not support using a Haiku model with non-trainable state. Tracking that state across federated rounds is nontrivial, and we could not find a good use case for it. If you share your use case, we would be happy to see if there is an alternative way to implement it in fedjax.
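To illustrate why this is nontrivial: with stateful layers, each client returns updated statistics alongside its parameters, and the server needs a rule for combining them. One naive option is to average the state with the same example-count weighting that FedAvg-style aggregation uses for parameters. The sketch below is hypothetical (`weighted_average` and the tree layout are not fedjax APIs) and assumes only `jax`; whether averaged running statistics are meaningful across non-IID clients is exactly the open question.

```python
import jax
import jax.numpy as jnp


def weighted_average(trees, weights):
    """Hypothetical helper: example-weighted mean of a list of pytrees."""
    weights = jnp.asarray(weights, dtype=jnp.float32)
    weights = weights / weights.sum()
    return jax.tree_util.tree_map(
        lambda *leaves: sum(w * leaf for w, leaf in zip(weights, leaves)),
        *trees)


# Per-client results after local training: (params, state, num_examples).
# "mean_ema" stands in for running batch-norm statistics (illustrative values).
client_results = [
    ({"w": jnp.array([1.0, 2.0])}, {"mean_ema": jnp.array([0.0])}, 10),
    ({"w": jnp.array([3.0, 4.0])}, {"mean_ema": jnp.array([2.0])}, 30),
]

params_list, state_list, counts = zip(*client_results)
# Params are averaged as usual; the state needs its own aggregation rule,
# and naive averaging is only one (debatable) choice.
server_params = weighted_average(params_list, counts)
server_state = weighted_average(state_list, counts)
```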

Thanks @stheertha! I agree that tracking statistics is nontrivial in FL. For now, I've worked around the issue by replacing batch norm with group norm, and it seems to be working fine. There might be cases, though, where one wants to use only client-specific statistics.