Question about how to reload trained models
butcher1226 opened this issue
Could you please share the code for reloading the trained models (saved as .pkl files)?
Since I don't know your use case, I can only give a simple example below.
Assume you have saved a `server` object, for example, using the code here. As you can see, I only used `torch.save`.
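If the `.pkl` file was written with `torch.save`, a minimal sketch of the matching load is `torch.load` with `map_location="cpu"`, which remaps tensors saved on a GPU onto the CPU. The `Server` class and the file name below are hypothetical stand-ins for the repository's real server object; in practice the real class must be importable when loading, because unpickling looks it up by module and name. `weights_only=False` is needed to load an arbitrary Python object on recent PyTorch releases, which default to `weights_only=True` (the argument exists from PyTorch 1.13), and should only be used on files you trust.

```python
import os
import tempfile

import torch
import torch.nn as nn


class Server:
    """Hypothetical stand-in for the repository's server object."""

    def __init__(self, model: nn.Module):
        # Only the global model's weights are kept, as a state_dict.
        self.server_model_state_dict = model.state_dict()


# Save the whole server object with torch.save, as the training code did.
path = os.path.join(tempfile.mkdtemp(), "server.pkl")
torch.save(Server(nn.Linear(4, 2)), path)

# Reload it; map_location="cpu" moves any CUDA tensors onto the CPU.
# weights_only=False is required to unpickle a custom class (trusted files only).
server_object = torch.load(path, map_location="cpu", weights_only=False)
global_state_dict = server_object.server_model_state_dict
print(sorted(global_state_dict.keys()))  # ['bias', 'weight']
```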
Could you please try the following code snippet? Since I no longer have access to the server, I cannot test it on my end.
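import io
import pickle

import torch

# Unpickling requires every class stored in the file to be importable, e.g. by
# running this from the repository root. This load_from_pkl replaces the one in
# src.utils, so that one is not imported here.
class CPU_Unpickler(pickle.Unpickler):
    """Unpickler that remaps tensors saved on a GPU onto the CPU."""
    def find_class(self, module, name):
        # Storages are pickled through torch.storage._load_from_bytes; intercept
        # that call and reload the raw bytes with map_location='cpu'.
        if module == 'torch.storage' and name == '_load_from_bytes':
            return lambda b: torch.load(io.BytesIO(b), map_location='cpu')
        return super().find_class(module, name)

def load_from_pkl(path):
    with open(path, 'rb') as file:
        data = CPU_Unpickler(file).load()
    return data

# assume that you have saved the server object to "your_pkl_file"
server_object = load_from_pkl("your_pkl_file")
# the global model, stored as a state_dict
global_state_dict = server_object.server_model_state_dict
# A state_dict holds only the weights: rebuild the same model architecture,
# then call `model.load_state_dict(global_state_dict)`.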
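Because `server_model_state_dict` holds only the weights, the model itself has to be rebuilt before they can be used. A minimal sketch, assuming a hypothetical `build_model()` that recreates the architecture used in training (a small `nn.Sequential` here): the weights go in through `load_state_dict`, and if you want a standalone checkpoint you can write the state_dict with `torch.save` and read it back with `torch.load`.

```python
import os
import tempfile

import torch
import torch.nn as nn


def build_model() -> nn.Module:
    # Hypothetical architecture; replace with the model class used in training.
    return nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))


# Stand-in for server_object.server_model_state_dict.
global_state_dict = build_model().state_dict()

# Optionally store just the weights as a plain checkpoint ...
path = os.path.join(tempfile.mkdtemp(), "global_model.pt")
torch.save(global_state_dict, path)

# ... and later restore them into a freshly built model.
model = build_model()
model.load_state_dict(torch.load(path, map_location="cpu"))
model.eval()  # switch to inference mode before evaluation
```

The architecture must match the saved one exactly; otherwise `load_state_dict` raises an error listing the missing or unexpected keys.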
To get the client-side models, I think you may need to start from the saved final global model and fine-tune it on the local data for a few steps. I don't think I saved the local models explicitly.
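A minimal sketch of that suggestion, assuming PyTorch: copy the global model, then take a few optimizer steps on one client's data. The architecture, the synthetic `local_x`/`local_y` tensors, the MSE loss, the SGD learning rate, and the step count are all hypothetical placeholders; the original training code may have used different ones.

```python
import copy

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset


def fine_tune_local(global_model: nn.Module, loader: DataLoader,
                    steps: int = 5, lr: float = 0.01) -> nn.Module:
    """Return a client model: a copy of the global model tuned for a few steps."""
    if len(loader) == 0:
        raise ValueError("loader yields no batches")
    local_model = copy.deepcopy(global_model)  # leave the global model untouched
    local_model.train()
    optimizer = torch.optim.SGD(local_model.parameters(), lr=lr)
    loss_fn = nn.MSELoss()  # hypothetical; use the task's actual loss
    done = 0
    while done < steps:  # loop over the data again if steps > batches per epoch
        for x, y in loader:
            optimizer.zero_grad()
            loss = loss_fn(local_model(x), y)
            loss.backward()
            optimizer.step()
            done += 1
            if done == steps:
                break
    return local_model


# Hypothetical client data and global model.
torch.manual_seed(0)
local_x, local_y = torch.randn(32, 4), torch.randn(32, 2)
loader = DataLoader(TensorDataset(local_x, local_y), batch_size=8, shuffle=True)
global_model = nn.Linear(4, 2)
client_model = fine_tune_local(global_model, loader, steps=5)
```

Working on a deep copy means every client starts from the same global weights, so you can repeat this per client without reloading the checkpoint.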