A simple ckpt to pt model converter
jiaxue-ai opened this issue · comments
Hi,
I wrote a simple converter from a .ckpt checkpoint to a TorchScript .pt model, in case anyone needs it:
import os

import torch
import yaml
from omegaconf import OmegaConf

from saicinpainting.training.trainers import load_checkpoint
# Convert a LaMa training checkpoint (.ckpt) into a TorchScript file (.pt) by
# tracing the model's generator with a dummy input tensor.

# Directory expected to contain `config.yaml` and `models/best.ckpt`;
# the traced model is written back into this same directory.
lama_model_path = '/LaMa_models/lama-places/lama-fourier/'

train_config_path = os.path.join(lama_model_path, 'config.yaml')
with open(train_config_path, 'r') as f:
    train_config = OmegaConf.create(yaml.safe_load(f))

# Override the training config before loading the checkpoint.
# NOTE(review): presumably these switch off training-only components
# (losses, visualisation) so only inference parts are built -- confirm
# against saicinpainting.
train_config.training_model.predict_only = True
train_config.visualizer.kind = 'noop'

checkpoint_path = os.path.join(lama_model_path, 'models', 'best.ckpt')
model = load_checkpoint(train_config, checkpoint_path, strict=False, map_location='cpu')
model.freeze()

with torch.no_grad():
    # Example input used only to drive the trace.
    # NOTE(review): 4 channels presumably means RGB image + mask -- confirm.
    # torch.jit.trace records the ops executed for this specific input, so
    # data-dependent control flow in the generator would not be captured.
    typical_input = torch.zeros([1, 4, 512, 512])
    # Pass the example inputs as an explicit one-element tuple.
    traced_cell = torch.jit.trace(model.generator, (typical_input,))

torch.jit.save(traced_cell, os.path.join(lama_model_path, 'lama-model-best.pt'))