Full EMNIST example does not exhibit parallelization
gaseln opened this issue
Elnur Gasanov commented
Hi! I am facing an issue with parallelizing the example code provided by the developers.
- My local workstation contains two GPUs.
- I installed FedJax in a conda environment.
- I downloaded the `emnist_fed_avg.py` file from the `examples` folder, deleted the `fedjax.training.set_tf_cpu_only()` line, and replaced `fed_avg.federated_averaging` with `fedjax.algorithms.fed_avg.federated_averaging` on line 61 (see the sketch after this list).
- Having activated the conda environment, I ran the file with `python emnist_fed_avg.py`. The script runs correctly and prints the expected output (round numbers and train/test metrics every 10th round).
- `nvidia-smi` shows zero percent utilization and almost zero memory usage on one of the GPUs, and ~40% utilization with near-maximum memory usage on the other.
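For reference, the modified script looked roughly like this. This is a condensed sketch from memory, not the exact example file; the loss/optimizer setup and hyperparameter values are stand-ins for the real code around line 61.

```python
import jax
import jax.numpy as jnp
import fedjax

# fedjax.training.set_tf_cpu_only()  # <- deleted this line

# Federated EMNIST data and the conv model from the FedJax model zoo.
train_fd, test_fd = fedjax.datasets.emnist.load_data(only_digits=False)
model = fedjax.models.emnist.create_conv_model(only_digits=False)

def loss(params, batch, rng):
  preds = model.apply_for_train(params, batch, rng)
  return jnp.mean(model.train_loss(batch, preds))

grad_fn = jax.jit(jax.grad(loss))

# Originally `fed_avg.federated_averaging(...)`; replaced with the fully
# qualified name below. Optimizer values here are placeholders, not the
# example's actual hyperparameters.
algorithm = fedjax.algorithms.fed_avg.federated_averaging(
    grad_fn=grad_fn,
    client_optimizer=fedjax.optimizers.sgd(learning_rate=0.03),
    server_optimizer=fedjax.optimizers.sgd(learning_rate=1.0),
    client_batch_hparams=fedjax.ShuffleRepeatBatchHParams(batch_size=20))
```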
Any ideas what I am doing wrong?
Elnur Gasanov commented
The solution was to replace `fedjax.training.set_tf_cpu_only()` with `fedjax.set_for_each_client_backend('pmap')`.
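In context, the only change relative to the sketch above is at the top of the script (the device-count print is just a quick way to confirm JAX sees both GPUs):

```python
import jax
import fedjax

# Shard per-client computation across all local devices instead of
# pinning dataset processing to CPU.
fedjax.set_for_each_client_backend('pmap')

print(jax.local_device_count())  # expect 2 on this two-GPU workstation
```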
Wu, Ke commented
Glad you figured this out on your own. There is a tutorial section on "for_each_client" that may also be of interest to you.
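For other readers, the pattern from that tutorial looks roughly like the sketch below. The counting task is purely illustrative, and the exact `for_each_client` call structure reflects my reading of the docs rather than something verified against the current release.

```python
import jax.numpy as jnp
import fedjax

# With the pmap backend, each client's steps below run on its own device.
fedjax.set_for_each_client_backend('pmap')

def client_init(shared_input, client_input):
  # Per-client state: the shared threshold plus a running count.
  return {'limit': shared_input['limit'], 'count': jnp.array(0.)}

def client_step(state, batch):
  # Count how many values in this batch exceed the threshold.
  return {'limit': state['limit'],
          'count': state['count'] + jnp.sum(batch['x'] > state['limit'])}

def client_final(shared_input, state):
  return state['count']

run = fedjax.for_each_client(client_init, client_step, client_final)

shared_input = {'limit': jnp.array(2.)}
# Each client is (client_id, client_batches, client_input).
clients = [
    (b'client_0', [{'x': jnp.array([1., 2., 3., 4.])}], {}),
    (b'client_1', [{'x': jnp.array([3., 3., 3., 3.])}], {}),
]
for client_id, count in run(shared_input, clients):
  print(client_id, count)  # e.g. b'client_0' 2.0, b'client_1' 4.0
```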