google / fedjax

FedJAX is a JAX-based open source library for Federated Learning simulations that emphasizes ease-of-use in research.

Full EMNIST example does not exhibit parallelization

gaseln opened this issue

Hi! I am facing an issue with parallelizing the example code provided by the developers across multiple GPUs.

  • My local workstation contains two GPUs.
  • I installed FedJAX in a conda environment.
  • I downloaded the "emnist_fed_avg.py" file from the "examples" folder, deleted the "fedjax.training.set_tf_cpu_only()" line, and replaced fed_avg.federated_averaging with fedjax.algorithms.fed_avg.federated_averaging on line 61.
  • Having activated the conda environment, I ran the file with python emnist_fed_avg.py. It runs correctly and prints the expected output (round numbers and train/test metrics every 10th round).
  • The nvidia-smi command shows zero percent utilization and almost zero memory usage on one of the GPUs, and ~40% utilization with maximum memory usage on the other GPU (see the device-visibility check sketched after this list).
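For reference, a quick sanity check (not part of the original script) that JAX itself can see both GPUs:

```python
import jax

# With two working GPUs this should list two GPU devices; if it lists
# only one (or only the CPU), the problem is driver/install visibility
# rather than how FedJAX distributes client work.
print(jax.devices())
print('local device count:', jax.local_device_count())
```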

Any ideas what I am doing wrong?

The solution was to replace fedjax.training.set_tf_cpu_only() with fedjax.set_for_each_client_backend('pmap').
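For concreteness, a minimal sketch of that change in emnist_fed_avg.py (only this line differs; the rest of the script is as in the examples folder):

```python
import fedjax

# Original line from the example, removed as described above:
# fedjax.training.set_tf_cpu_only()

# Replacement: use the pmap-based for_each_client backend so client
# updates are sharded across all visible GPUs.
fedjax.set_for_each_client_backend('pmap')
```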

Glad you figured this out on your own. There is a tutorial section on "for_each_client" that may also be of interest to you.
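For anyone landing here later, a toy sketch of the for_each_client pattern from that tutorial. The client_init / client_step / client_final structure and the (client_id, batches, client_input) format follow the FedJAX for_each_client documentation; the tiny dataset and the 'scale'/'offset' names here are made up for illustration.

```python
import jax.numpy as jnp
import fedjax

# Each client sums the 'x' values in its batches, scaled by a shared
# multiplier and starting from a per-client offset.
def client_init(shared_input, client_input):
  # Per-client state; shared_input is only visible in init and final,
  # so the shared scale is carried along in the state.
  return {'total': client_input['offset'], 'scale': shared_input['scale']}

def client_step(state, batch):
  return {'total': state['total'] + state['scale'] * jnp.sum(batch['x']),
          'scale': state['scale']}

def client_final(shared_input, state):
  del shared_input  # Unused in this toy example.
  return state['total']

func = fedjax.for_each_client(client_init, client_step, client_final)

shared_input = {'scale': jnp.array(2.0)}
clients = [
    (b'client_0', [{'x': jnp.array([1., 2.])}], {'offset': jnp.array(0.)}),
    (b'client_1', [{'x': jnp.array([3.])}], {'offset': jnp.array(10.)}),
]
for client_id, output in func(shared_input, clients):
  print(client_id, output)
```

With fedjax.set_for_each_client_backend('pmap'), the same three functions are run under jax.pmap so different clients are processed on different devices, which is what restores multi-GPU utilization in the EMNIST example.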