Yutong-Dai / FedNH

Code release for "Tackling Data Heterogeneity in Federated Learning with Class Prototypes", which appeared at AAAI 2023.

Question about how to reload trained models

butcher1226 opened this issue

Could you please share the code for reloading the trained models (saved as .pkl files)?

Since I don't know your use case, I can only give a simple example.
Assume you have saved a server object, for example, using the code here. As you can see, I only used `torch.save`.
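If the file was written directly by `torch.save`, calling `torch.load` with `map_location` may already be enough to reload it on a CPU-only machine. The sketch below round-trips a hypothetical stand-in for the server object (a `types.SimpleNamespace` carrying a `server_model_state_dict` attribute); the real FedNH server class is not reproduced, and passing `weights_only=False` assumes PyTorch 1.13 or later.

```python
import os
import tempfile
from types import SimpleNamespace

import torch
import torch.nn as nn

# Hypothetical stand-in for the FedNH server object; only the attribute name
# server_model_state_dict is taken from this thread.
global_model = nn.Linear(4, 2)
server = SimpleNamespace(server_model_state_dict=global_model.state_dict())

path = os.path.join(tempfile.mkdtemp(), "server.pkl")
torch.save(server, path)  # pickles the whole object, tensors included

# map_location="cpu" moves any CUDA tensors to the CPU while loading.
# weights_only=False is required on recent PyTorch to unpickle arbitrary
# (non-tensor) objects; only use it on files you trust.
restored = torch.load(path, map_location="cpu", weights_only=False)
print(restored.server_model_state_dict["weight"].shape)  # torch.Size([2, 4])
```

If this raises an error about the file format, the file was most likely written with plain `pickle`, and the custom unpickler in the snippet below is the way to go.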

Could you please try the following code snippet? I no longer have access to the server, so I cannot test it on my end.

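import io
import pickle

import torch

# pickle must be able to import the classes of the saved server object,
# so run this from the FedNH repository root (where src/ lives).


class CPU_Unpickler(pickle.Unpickler):
    """Unpickler that maps tensors saved on a GPU onto the CPU."""

    def find_class(self, module, name):
        # Tensor storages are pickled through torch.storage._load_from_bytes;
        # swap that hook for a torch.load call that forces the CPU.
        if module == 'torch.storage' and name == '_load_from_bytes':
            return lambda b: torch.load(io.BytesIO(b), map_location='cpu')
        return super().find_class(module, name)


def load_from_pkl(path):
    # Defined here instead of importing it from src.utils, so no GPU is needed.
    with open(path, 'rb') as file:
        data = CPU_Unpickler(file).load()
    return data


# assume that you have saved the server object to "your_pkl_file"
server_object = load_from_pkl("your_pkl_file")
# the global model, stored as a state_dict
global_state_dict = server_object.server_model_state_dict

# Once you have the state_dict, build a model with the same architecture and
# load the weights with `model.load_state_dict(global_state_dict)`.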
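Loading the recovered state_dict into a model can be sketched as follows. `build_model` is a hypothetical placeholder: in practice it must construct exactly the architecture FedNH trained, otherwise `load_state_dict` raises an error about missing, unexpected, or mis-sized keys. A second model instance stands in for `server_object.server_model_state_dict`.

```python
import torch
import torch.nn as nn


def build_model():
    # Hypothetical architecture; replace it with the network FedNH trained.
    return nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 3))


# Stand-in for server_object.server_model_state_dict.
trained = build_model()
state_dict = trained.state_dict()

model = build_model()  # fresh, randomly initialised weights
model.load_state_dict(state_dict)  # copy the trained weights in place
model.eval()  # evaluation mode (matters for dropout / batch norm layers)

with torch.no_grad():
    logits = model(torch.randn(5, 8))
print(logits.shape)  # torch.Size([5, 3])
```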

To get the client-side models, I think you would need to start from the saved final global model and fine-tune it on each client's local data for a few steps. I don't think I saved the local models explicitly.
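The fine-tuning route can be sketched as below, assuming a plain classifier trained with cross-entropy and full-batch SGD. `finetune_local`, `num_steps`, and `lr` are hypothetical names, the client data is synthetic, and FedNH's own local objective (which involves class prototypes) may differ from this.

```python
import copy

import torch
import torch.nn as nn
import torch.nn.functional as F


def finetune_local(global_model, inputs, labels, num_steps=5, lr=0.01):
    """Hypothetical helper: adapt a copy of the global model to one client's data."""
    local_model = copy.deepcopy(global_model)  # leave the global model untouched
    local_model.train()
    optimizer = torch.optim.SGD(local_model.parameters(), lr=lr)
    for _ in range(num_steps):
        optimizer.zero_grad()
        loss = F.cross_entropy(local_model(inputs), labels)
        loss.backward()
        optimizer.step()
    return local_model


# Synthetic stand-in for one client's local data (8 features, 3 classes).
torch.manual_seed(0)
global_model = nn.Linear(8, 3)
x_local = torch.randn(32, 8)
y_local = torch.randint(0, 3, (32,))

local_model = finetune_local(global_model, x_local, y_local)
```

In practice you would iterate over the client's `DataLoader` rather than a single batch, and start from the global model reloaded with the snippet above.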