hanglearning / VBFL

device's local accuracy

Hongchenglong opened this issue

In main.py, line 817 (device.accuracy_this_round = device.validate_model_weights()), I found that validate_model_weights() actually computes the global accuracy, not the device's local accuracy. It should be self.validate_model_weights(self.net.state_dict()).

So in your logs (accuracy_comm_*.txt), every device's accuracy is actually the global accuracy.
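A minimal, framework-free sketch of the difference, assuming (as the issue describes) that validate_model_weights() evaluates the global weights when it is called without an argument. ToyDevice, global_parameters, local_parameters, and the one-parameter threshold classifier are hypothetical stand-ins for a VBFL Device, its global model, self.net.state_dict(), and the real network; this is not the repository's code.

```python
from typing import Dict, List, Optional, Tuple

Weights = Dict[str, float]


class ToyDevice:
    """Hypothetical stand-in for a VBFL Device (not the real class)."""

    def __init__(self, global_parameters: Weights, local_parameters: Weights,
                 test_dl: List[Tuple[float, int]]) -> None:
        self.global_parameters = global_parameters  # aggregated global model
        self.local_parameters = local_parameters    # stand-in for self.net.state_dict()
        self.test_dl = test_dl                      # (feature, label) pairs

    def validate_model_weights(self, weights_to_eval: Optional[Weights] = None) -> float:
        # Assumed behavior: with no argument, the *global* weights are evaluated.
        weights = self.global_parameters if weights_to_eval is None else weights_to_eval
        # Toy one-parameter classifier: predict 1 when x >= threshold, else 0.
        correct = sum(1 for x, y in self.test_dl if int(x >= weights["threshold"]) == y)
        return correct / len(self.test_dl)


test_dl = [(-2.0, 0), (-1.0, 0), (0.5, 1), (1.0, 1), (2.0, 1)]
device = ToyDevice({"threshold": 0.0}, {"threshold": 1.5}, test_dl)

global_acc = device.validate_model_weights()                        # analogous to main.py line 817
local_acc = device.validate_model_weights(device.local_parameters)  # the local accuracy asked about
print(global_acc, local_acc)  # 1.0 0.6
```

The only difference between the two calls is the argument, which mirrors the fix proposed in the issue.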

Sorry for the late reply. I was a bit busy last week.

The "global accuracy" in this code base can actually mean two things:

  1. global_accuracy = the global model tested on the entire MNIST test set (like the accuracy reported in Figures 1 and 5 of the paper).
    This corresponds to using -st 0 or --shard_test_data 0 (the default value) in main.py. In this case, the test_dl of each client contains the entire MNIST test set (Device.py, lines 1414 - 1417).
  2. global_accuracy = the global model tested on the local test set of a device.
    This corresponds to using -st 1 or --shard_test_data 1 in main.py. In this case, the test_dl of each client contains a partial, non-overlapping portion of the MNIST test set (Device.py, lines 1419 - 1423, 1442 - 1451).
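The two cases can be sketched as follows, assuming a plain shuffle-and-split for the sharded case. build_test_sets and its shard parameter are hypothetical (shard=True meaning the test data is sharded, which is what --shard_test_data controls); the actual Device.py code works on MNIST tensors and DataLoaders and may partition the data differently.

```python
import random
from typing import List, Sequence, TypeVar

T = TypeVar("T")


def build_test_sets(test_set: Sequence[T], num_devices: int,
                    shard: bool, seed: int = 0) -> List[List[T]]:
    """Hypothetical helper: return one test set per device.

    shard=False -> every device evaluates on the entire test set (case 1).
    shard=True  -> the test set is shuffled and split into non-overlapping,
                   equally sized parts, one per device (case 2).
    """
    if not shard:
        return [list(test_set) for _ in range(num_devices)]
    indices = list(range(len(test_set)))
    random.Random(seed).shuffle(indices)
    shard_size = len(test_set) // num_devices  # leftover samples are dropped
    return [
        [test_set[i] for i in indices[d * shard_size:(d + 1) * shard_size]]
        for d in range(num_devices)
    ]


samples = list(range(10))  # stand-in for 10 test samples
full = build_test_sets(samples, num_devices=3, shard=False)
sharded = build_test_sets(samples, num_devices=3, shard=True)
print([len(s) for s in full])     # [10, 10, 10]
print([len(s) for s in sharded])  # [3, 3, 3]
```

In the unsharded case every device scores the global model on identical data, so all devices report the same number; in the sharded case each device's score depends on its own slice.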

But since you mentioned that the code should be self.validate_model_weights(self.net.state_dict()), I think neither of these two situations is your concern. My understanding is that you want to report the accuracy (i.e., the local accuracy you mentioned) of a client's trained local model tested on that client's test_dl.

While it is interesting to see how a client's final local model performs on its own dataset (especially when the client's training data is polluted, or its distribution differs greatly from the others' (non-IID), etc.), in my humble opinion the goal of federated learning is to jointly learn a global model, with the expectation that the global model surpasses the local models in accuracy. Therefore, a local model is really an intermediate product of FL, and reporting each client's local model accuracy may not be very meaningful.

However, if reporting clients' local model accuracies is important for your research, then it should definitely be self.validate_model_weights(self.net.state_dict()).

Please let me know if this makes sense. This is a good question, and thanks for pointing it out!

Since you asked this question, I also wish to highlight three more things for other readers:

  1. In my paper, I used --shard_test_data 0. In this case, I could report a single accuracy metric across all the devices in a single figure (i.e., Figures 1 and 5 in the paper) for an easy explanation of VBFL's validation algorithm against poisoned models, and I figured that a global accuracy is a viable solution. In this case, the values in accuracy_comm_*.txt from all the devices should be the same, and I think I created Figures 1 and 5 from a random device's accuracy_comm_*.txt.

  2. However, when using --shard_test_data 0, the model validation mechanism defined in validator_update_model_by_one_epoch_and_validate_local_accuracy() of Device.py makes every validator use the same validation set (i.e., the entire MNIST test set), even though each validator's one-epoch model is still trained on the validator's own dataset. I later realized this may be bad practice, because
    a. it may not be practical, and
    b. some clients may intentionally train on this shared reference validation set instead of their local data sets, so as to always obtain high accuracy and fraudulently collect rewards.
    Still, please note that the one-epoch validator model is trained on the validator's own training data (Device.py, lines 1259 - 1265), so it differs across validators.

  3. I later figured out the root cause of the sharp accuracy drop of a client's local model in Figure 3: each client was assigned only 2 of the 10 labels in the MNIST dataset. I adopted the vanilla FL code from https://github.com/WHDY/FedAvg, and by default the Non-IID sharding mechanism in that code base distributes 2 labels to each client. I was immature in research and coding at that time and failed to notice this feature. I later changed the Non-IID sharding to assign 5 labels to each client, and the local accuracy then drops to about 50%. If all 10 labels are sharded to each client, only a minor drop would occur. Therefore, two things can be inferred:
    a. Only one epoch of local training starting from a global model can be very sensitive to a client's local dataset; the resulting intermediate local model can lose the global model's generalization ability.
    b. The validation mechanism proposed in this paper has a limitation: it may only work if
    (1) each client has the same number of labels and the sample count of each label is more or less the same (verified), or
    (2) I guess, each client simply has more or less the same total sample count (not verified).
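The 2-labels-per-client effect can be reproduced with a simplified label-sorted shard split of the kind commonly used for FedAvg-style non-IID experiments: sort the samples by label, cut them into equal shards, and hand each client a fixed number of randomly chosen shards. This is a sketch under that assumption, not a copy of the WHDY/FedAvg or VBFL code; non_iid_split, labels_per_client, and shards_per_client are hypothetical names, and the labels are synthetic (10 classes with 100 samples each).

```python
import random
from typing import List


def non_iid_split(labels: List[int], num_clients: int,
                  shards_per_client: int, seed: int = 0) -> List[List[int]]:
    """Hypothetical label-sorted shard split; returns sample indices per client."""
    # Sort sample indices by label so that each contiguous shard covers few labels.
    order = sorted(range(len(labels)), key=lambda i: labels[i])
    num_shards = num_clients * shards_per_client
    shard_size = len(labels) // num_shards  # leftover samples are dropped
    shard_ids = list(range(num_shards))
    random.Random(seed).shuffle(shard_ids)  # random shard-to-client assignment
    clients = []
    for c in range(num_clients):
        mine = shard_ids[c * shards_per_client:(c + 1) * shards_per_client]
        clients.append([i for s in mine for i in order[s * shard_size:(s + 1) * shard_size]])
    return clients


def labels_per_client(labels: List[int], clients: List[List[int]]) -> List[int]:
    """Number of distinct labels each client ends up holding."""
    return [len({labels[i] for i in idx}) for idx in clients]


# Synthetic, balanced "dataset": 10 labels x 100 samples each.
labels = [y for y in range(10) for _ in range(100)]
for k in (2, 5):
    clients = non_iid_split(labels, num_clients=10, shards_per_client=k)
    print(k, labels_per_client(labels, clients))
```

Here every shard falls inside a single label, so with shards_per_client=2 no client can hold more than 2 distinct labels, and raising it to 5 widens each client's coverage to at most 5.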