[Question] How to convert pkl to pt file?
HuaZheLei opened this issue · comments
Thanks for your excellent work!
Describe the problem
I'd like to learn how you convert a pkl file to a pt file.
I use the pt file you provide to generate images. The code is as follows:
def get_random_image(generator: Generator, truncation_psi: float, seed):
    """Generate one image from ``generator`` for a given random seed.

    Args:
        generator: Network exposing ``mapping`` and ``synthesis`` sub-modules.
        truncation_psi: Forwarded unchanged to ``generator.mapping``.
        seed: Seed for ``np.random.RandomState``; the same seed always
            produces the same latent ``z``.

    Returns:
        A ``(res_image, w)`` pair: the output of ``tensor2im`` for the first
        synthesized sample, and the ``w`` produced by the mapping network.
    """
    with torch.no_grad():
        # The latent width (512) and the target device ('cuda') are hard-coded.
        rng = np.random.RandomState(seed)
        z = torch.from_numpy(rng.randn(1, 512).astype('float32')).to('cuda')
        if hasattr(generator.synthesis, 'input'):
            # Only networks whose synthesis module has an ``input`` layer carry
            # a user transform; set it to the inverse of a transform built
            # with no translation and no rotation.
            m = make_transform(translate=(0, 0), angle=0)
            m = np.linalg.inv(m)
            generator.synthesis.input.transform.copy_(torch.from_numpy(m))
        w = generator.mapping(z, None, truncation_psi=truncation_psi)
        img = generator.synthesis(w, noise_mode='const')
        res_image = tensor2im(img[0])
    return res_image, w
And it works well.
But when I convert the pkl to pt myself, several errors appear.
The conversion code I used is as follows:
import pickle
import sys
from enum import Enum
from pathlib import Path
from typing import Optional
import torch
# Convert a StyleGAN3 training pickle into a plain PyTorch state_dict file.
checkpoint_path = "pretrained_models/stylegan3-t-ffhq-1024x1024.pkl"
# Write the result next to the source pickle, swapping only the extension.
output_path = Path(checkpoint_path).with_suffix(".pt")

print(f"Loading StyleGAN3 generator from path: {checkpoint_path}")
# SECURITY: pickle.load executes arbitrary code embedded in the file; only
# run this on checkpoints obtained from a trusted source.
with open(checkpoint_path, "rb") as f:
    # 'G_ema' is the entry of the pickled dict that holds the generator.
    # NOTE(review): .cuda() is probably not needed just to export weights.
    decoder = pickle.load(f)['G_ema'].cuda()
print('Loading done!')

# The saved keys and tensor shapes mirror the pickled network exactly, so the
# model that later loads this file must be constructed with the same
# architecture/config; otherwise load_state_dict reports missing/unexpected
# keys and size mismatches.
state_dict = decoder.state_dict()
torch.save(state_dict, output_path)
print('Converting done!')
Then I use stylegan3-t-ffhq-1024x1024.pt to generate images, and the errors are as follows:
Loading StyleGAN3 generator from path: pretrained_models/stylegan3-t-ffhq-1024x1024.pt
Traceback (most recent call last):
File "/sam/models/stylegan3/model.py", line 61, in _load_checkpoint
self.decoder.load_state_dict(torch.load(checkpoint_path), strict=True)
File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1223, in load_state_dict
raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for Generator:
Missing key(s) in state_dict: "synthesis.L0_36_1024.weight", "synthesis.L0_36_1024.bias", "synthesis.L0_36_1024.magnitude_ema", "synthesis.L0_36_1024.up_filter", "synthesis.L0_36_1024.down_filter", "synthesis.L0_36_1024.affine.weight", "synthesis.L0_36_1024.affine.bias", "synthesis.L1_36_1024.weight", "synthesis.L1_36_1024.bias", "synthesis.L1_36_1024.magnitude_ema", "synthesis.L1_36_1024.up_filter", "synthesis.L1_36_1024.down_filter", "synthesis.L1_36_1024.affine.weight", "synthesis.L1_36_1024.affine.bias", "synthesis.L2_52_1024.weight", "synthesis.L2_52_1024.bias", "synthesis.L2_52_1024.magnitude_ema", "synthesis.L2_52_1024.up_filter", "synthesis.L2_52_1024.down_filter", "synthesis.L2_52_1024.affine.weight", "synthesis.L2_52_1024.affine.bias", "synthesis.L3_52_1024.weight", "synthesis.L3_52_1024.bias", "synthesis.L3_52_1024.magnitude_ema", "synthesis.L3_52_1024.up_filter", "synthesis.L3_52_1024.down_filter", "synthesis.L3_52_1024.affine.weight", "synthesis.L3_52_1024.affine.bias", "synthesis.L4_84_1024.weight", "synthesis.L4_84_1024.bias", "synthesis.L4_84_1024.magnitude_ema", "synthesis.L4_84_1024.up_filter", "synthesis.L4_84_1024.down_filter", "synthesis.L4_84_1024.affine.weight", "synthesis.L4_84_1024.affine.bias", "synthesis.L5_148_1024.weight", "synthesis.L5_148_1024.bias", "synthesis.L5_148_1024.magnitude_ema", "synthesis.L5_148_1024.up_filter", "synthesis.L5_148_1024.down_filter", "synthesis.L5_148_1024.affine.weight", "synthesis.L5_148_1024.affine.bias", "synthesis.L6_148_1024.weight", "synthesis.L6_148_1024.bias", "synthesis.L6_148_1024.magnitude_ema", "synthesis.L6_148_1024.up_filter", "synthesis.L6_148_1024.down_filter", "synthesis.L6_148_1024.affine.weight", "synthesis.L6_148_1024.affine.bias", "synthesis.L7_276_645.weight", "synthesis.L7_276_645.bias", "synthesis.L7_276_645.magnitude_ema", "synthesis.L7_276_645.up_filter", "synthesis.L7_276_645.down_filter", "synthesis.L7_276_645.affine.weight", "synthesis.L7_276_645.affine.bias", 
"synthesis.L8_276_406.weight", "synthesis.L8_276_406.bias", "synthesis.L8_276_406.magnitude_ema", "synthesis.L8_276_406.up_filter", "synthesis.L8_276_406.down_filter", "synthesis.L8_276_406.affine.weight", "synthesis.L8_276_406.affine.bias", "synthesis.L9_532_256.weight", "synthesis.L9_532_256.bias", "synthesis.L9_532_256.magnitude_ema", "synthesis.L9_532_256.up_filter", "synthesis.L9_532_256.down_filter", "synthesis.L9_532_256.affine.weight", "synthesis.L9_532_256.affine.bias", "synthesis.L10_1044_161.weight", "synthesis.L10_1044_161.bias", "synthesis.L10_1044_161.magnitude_ema", "synthesis.L10_1044_161.up_filter", "synthesis.L10_1044_161.down_filter", "synthesis.L10_1044_161.affine.weight", "synthesis.L10_1044_161.affine.bias", "synthesis.L11_1044_102.weight", "synthesis.L11_1044_102.bias", "synthesis.L11_1044_102.magnitude_ema", "synthesis.L11_1044_102.up_filter", "synthesis.L11_1044_102.down_filter", "synthesis.L11_1044_102.affine.weight", "synthesis.L11_1044_102.affine.bias", "synthesis.L12_1044_64.weight", "synthesis.L12_1044_64.bias", "synthesis.L12_1044_64.magnitude_ema", "synthesis.L12_1044_64.up_filter", "synthesis.L12_1044_64.down_filter", "synthesis.L12_1044_64.affine.weight", "synthesis.L12_1044_64.affine.bias", "synthesis.L13_1024_64.weight", "synthesis.L13_1024_64.bias", "synthesis.L13_1024_64.magnitude_ema", "synthesis.L13_1024_64.up_filter", "synthesis.L13_1024_64.down_filter", "synthesis.L13_1024_64.affine.weight", "synthesis.L13_1024_64.affine.bias".
Unexpected key(s) in state_dict: "synthesis.L0_36_512.weight", "synthesis.L0_36_512.bias", "synthesis.L0_36_512.magnitude_ema", "synthesis.L0_36_512.up_filter", "synthesis.L0_36_512.down_filter", "synthesis.L0_36_512.affine.weight", "synthesis.L0_36_512.affine.bias", "synthesis.L1_36_512.weight", "synthesis.L1_36_512.bias", "synthesis.L1_36_512.magnitude_ema", "synthesis.L1_36_512.up_filter", "synthesis.L1_36_512.down_filter", "synthesis.L1_36_512.affine.weight", "synthesis.L1_36_512.affine.bias", "synthesis.L2_52_512.weight", "synthesis.L2_52_512.bias", "synthesis.L2_52_512.magnitude_ema", "synthesis.L2_52_512.up_filter", "synthesis.L2_52_512.down_filter", "synthesis.L2_52_512.affine.weight", "synthesis.L2_52_512.affine.bias", "synthesis.L3_52_512.weight", "synthesis.L3_52_512.bias", "synthesis.L3_52_512.magnitude_ema", "synthesis.L3_52_512.up_filter", "synthesis.L3_52_512.down_filter", "synthesis.L3_52_512.affine.weight", "synthesis.L3_52_512.affine.bias", "synthesis.L4_84_512.weight", "synthesis.L4_84_512.bias", "synthesis.L4_84_512.magnitude_ema", "synthesis.L4_84_512.up_filter", "synthesis.L4_84_512.down_filter", "synthesis.L4_84_512.affine.weight", "synthesis.L4_84_512.affine.bias", "synthesis.L5_148_512.weight", "synthesis.L5_148_512.bias", "synthesis.L5_148_512.magnitude_ema", "synthesis.L5_148_512.up_filter", "synthesis.L5_148_512.down_filter", "synthesis.L5_148_512.affine.weight", "synthesis.L5_148_512.affine.bias", "synthesis.L6_148_512.weight", "synthesis.L6_148_512.bias", "synthesis.L6_148_512.magnitude_ema", "synthesis.L6_148_512.up_filter", "synthesis.L6_148_512.down_filter", "synthesis.L6_148_512.affine.weight", "synthesis.L6_148_512.affine.bias", "synthesis.L7_276_323.weight", "synthesis.L7_276_323.bias", "synthesis.L7_276_323.magnitude_ema", "synthesis.L7_276_323.up_filter", "synthesis.L7_276_323.down_filter", "synthesis.L7_276_323.affine.weight", "synthesis.L7_276_323.affine.bias", "synthesis.L8_276_203.weight", "synthesis.L8_276_203.bias", 
"synthesis.L8_276_203.magnitude_ema", "synthesis.L8_276_203.up_filter", "synthesis.L8_276_203.down_filter", "synthesis.L8_276_203.affine.weight", "synthesis.L8_276_203.affine.bias", "synthesis.L9_532_128.weight", "synthesis.L9_532_128.bias", "synthesis.L9_532_128.magnitude_ema", "synthesis.L9_532_128.up_filter", "synthesis.L9_532_128.down_filter", "synthesis.L9_532_128.affine.weight", "synthesis.L9_532_128.affine.bias", "synthesis.L10_1044_81.weight", "synthesis.L10_1044_81.bias", "synthesis.L10_1044_81.magnitude_ema", "synthesis.L10_1044_81.up_filter", "synthesis.L10_1044_81.down_filter", "synthesis.L10_1044_81.affine.weight", "synthesis.L10_1044_81.affine.bias", "synthesis.L11_1044_51.weight", "synthesis.L11_1044_51.bias", "synthesis.L11_1044_51.magnitude_ema", "synthesis.L11_1044_51.up_filter", "synthesis.L11_1044_51.down_filter", "synthesis.L11_1044_51.affine.weight", "synthesis.L11_1044_51.affine.bias", "synthesis.L12_1044_32.weight", "synthesis.L12_1044_32.bias", "synthesis.L12_1044_32.magnitude_ema", "synthesis.L12_1044_32.up_filter", "synthesis.L12_1044_32.down_filter", "synthesis.L12_1044_32.affine.weight", "synthesis.L12_1044_32.affine.bias", "synthesis.L13_1024_32.weight", "synthesis.L13_1024_32.bias", "synthesis.L13_1024_32.magnitude_ema", "synthesis.L13_1024_32.up_filter", "synthesis.L13_1024_32.down_filter", "synthesis.L13_1024_32.affine.weight", "synthesis.L13_1024_32.affine.bias".
size mismatch for synthesis.input.weight: copying a param with shape torch.Size([512, 512]) from checkpoint, the shape in current model is torch.Size([1024, 1024]).
size mismatch for synthesis.input.freqs: copying a param with shape torch.Size([512, 2]) from checkpoint, the shape in current model is torch.Size([1024, 2]).
size mismatch for synthesis.input.phases: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([1024]).
size mismatch for synthesis.L14_1024_3.weight: copying a param with shape torch.Size([3, 32, 1, 1]) from checkpoint, the shape in current model is torch.Size([3, 64, 1, 1]).
size mismatch for synthesis.L14_1024_3.affine.weight: copying a param with shape torch.Size([32, 512]) from checkpoint, the shape in current model is torch.Size([64, 512]).
size mismatch for synthesis.L14_1024_3.affine.bias: copying a param with shape torch.Size([32]) from checkpoint, the shape in current model is torch.Size([64]).
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "gen_images_using_pt.py", line 79, in <module>
main()
File "gen_images_using_pt.py", line 47, in main
generator = SG3Generator(checkpoint_path=args.generator_path).decoder
File "/sam/models/stylegan3/model.py", line 56, in __init__
self._load_checkpoint(checkpoint_path)
File "/sam/models/stylegan3/model.py", line 65, in _load_checkpoint
self.decoder.load_state_dict(ckpt, strict=False)
File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1223, in load_state_dict
raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for Generator:
size mismatch for synthesis.input.weight: copying a param with shape torch.Size([512, 512]) from checkpoint, the shape in current model is torch.Size([1024, 1024]).
size mismatch for synthesis.input.freqs: copying a param with shape torch.Size([512, 2]) from checkpoint, the shape in current model is torch.Size([1024, 2]).
size mismatch for synthesis.input.phases: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([1024]).
size mismatch for synthesis.L14_1024_3.weight: copying a param with shape torch.Size([3, 32, 1, 1]) from checkpoint, the shape in current model is torch.Size([3, 64, 1, 1]).
size mismatch for synthesis.L14_1024_3.affine.weight: copying a param with shape torch.Size([32, 512]) from checkpoint, the shape in current model is torch.Size([64, 512]).
size mismatch for synthesis.L14_1024_3.affine.bias: copying a param with shape torch.Size([32]) from checkpoint, the shape in current model is torch.Size([64]).
I am wondering if the error you got is because you are using the t config and the generators we defined are for the r config.
Does the same error happen when you try to load a pt of the r config?
Cool! I did not realize they are different. I will add t config generator in my code. Thanks again.
Hi HuaZheLei, I am trying to generate images from a .pt model, but I am not sure how to load the model. How can I load the .pt model? Thanks!!
Thanks for your excellent work!
Describe the problem
I'd like to learn how you convert a pkl file to a pt file. I use the pt file you provide to generate images. The code is as follows:
def get_random_image(generator: Generator, truncation_psi: float, seed):
    """Generate one image from ``generator`` for a given random seed.

    Args:
        generator: Network exposing ``mapping`` and ``synthesis`` sub-modules.
        truncation_psi: Forwarded unchanged to ``generator.mapping``.
        seed: Seed for ``np.random.RandomState``; the same seed always
            produces the same latent ``z``.

    Returns:
        A ``(res_image, w)`` pair: the output of ``tensor2im`` for the first
        synthesized sample, and the ``w`` produced by the mapping network.
    """
    with torch.no_grad():
        # The latent width (512) and the target device ('cuda') are hard-coded.
        rng = np.random.RandomState(seed)
        z = torch.from_numpy(rng.randn(1, 512).astype('float32')).to('cuda')
        if hasattr(generator.synthesis, 'input'):
            # Only networks whose synthesis module has an ``input`` layer carry
            # a user transform; set it to the inverse of a transform built
            # with no translation and no rotation.
            m = make_transform(translate=(0, 0), angle=0)
            m = np.linalg.inv(m)
            generator.synthesis.input.transform.copy_(torch.from_numpy(m))
        w = generator.mapping(z, None, truncation_psi=truncation_psi)
        img = generator.synthesis(w, noise_mode='const')
        res_image = tensor2im(img[0])
    return res_image, w
And it works well. But when I convert the pkl to pt myself, several errors appear. The conversion code I used is as follows:
import pickle
import sys
from enum import Enum
from pathlib import Path
from typing import Optional

import torch

# Convert a StyleGAN3 training pickle into a plain PyTorch state_dict file.
checkpoint_path = "pretrained_models/stylegan3-t-ffhq-1024x1024.pkl"
# Write the result next to the source pickle, swapping only the extension.
output_path = Path(checkpoint_path).with_suffix(".pt")

print(f"Loading StyleGAN3 generator from path: {checkpoint_path}")
# SECURITY: pickle.load executes arbitrary code embedded in the file; only
# run this on checkpoints obtained from a trusted source.
with open(checkpoint_path, "rb") as f:
    # 'G_ema' is the entry of the pickled dict that holds the generator.
    # NOTE(review): .cuda() is probably not needed just to export weights.
    decoder = pickle.load(f)['G_ema'].cuda()
print('Loading done!')

# The saved keys and tensor shapes mirror the pickled network exactly, so the
# model that later loads this file must be constructed with the same
# architecture/config; otherwise load_state_dict reports missing/unexpected
# keys and size mismatches.
state_dict = decoder.state_dict()
torch.save(state_dict, output_path)
print('Converting done!')
Then I use stylegan3-t-ffhq-1024x1024.pt to generate images, and the errors are as follows:
Loading StyleGAN3 generator from path: pretrained_models/stylegan3-t-ffhq-1024x1024.pt Traceback (most recent call last): File "/sam/models/stylegan3/model.py", line 61, in _load_checkpoint self.decoder.load_state_dict(torch.load(checkpoint_path), strict=True) File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1223, in load_state_dict raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format( RuntimeError: Error(s) in loading state_dict for Generator: Missing key(s) in state_dict: "synthesis.L0_36_1024.weight", "synthesis.L0_36_1024.bias", "synthesis.L0_36_1024.magnitude_ema", "synthesis.L0_36_1024.up_filter", "synthesis.L0_36_1024.down_filter", "synthesis.L0_36_1024.affine.weight", "synthesis.L0_36_1024.affine.bias", "synthesis.L1_36_1024.weight", "synthesis.L1_36_1024.bias", "synthesis.L1_36_1024.magnitude_ema", "synthesis.L1_36_1024.up_filter", "synthesis.L1_36_1024.down_filter", "synthesis.L1_36_1024.affine.weight", "synthesis.L1_36_1024.affine.bias", "synthesis.L2_52_1024.weight", "synthesis.L2_52_1024.bias", "synthesis.L2_52_1024.magnitude_ema", "synthesis.L2_52_1024.up_filter", "synthesis.L2_52_1024.down_filter", "synthesis.L2_52_1024.affine.weight", "synthesis.L2_52_1024.affine.bias", "synthesis.L3_52_1024.weight", "synthesis.L3_52_1024.bias", "synthesis.L3_52_1024.magnitude_ema", "synthesis.L3_52_1024.up_filter", "synthesis.L3_52_1024.down_filter", "synthesis.L3_52_1024.affine.weight", "synthesis.L3_52_1024.affine.bias", "synthesis.L4_84_1024.weight", "synthesis.L4_84_1024.bias", "synthesis.L4_84_1024.magnitude_ema", "synthesis.L4_84_1024.up_filter", "synthesis.L4_84_1024.down_filter", "synthesis.L4_84_1024.affine.weight", "synthesis.L4_84_1024.affine.bias", "synthesis.L5_148_1024.weight", "synthesis.L5_148_1024.bias", "synthesis.L5_148_1024.magnitude_ema", "synthesis.L5_148_1024.up_filter", "synthesis.L5_148_1024.down_filter", "synthesis.L5_148_1024.affine.weight", "synthesis.L5_148_1024.affine.bias", 
"synthesis.L6_148_1024.weight", "synthesis.L6_148_1024.bias", "synthesis.L6_148_1024.magnitude_ema", "synthesis.L6_148_1024.up_filter", "synthesis.L6_148_1024.down_filter", "synthesis.L6_148_1024.affine.weight", "synthesis.L6_148_1024.affine.bias", "synthesis.L7_276_645.weight", "synthesis.L7_276_645.bias", "synthesis.L7_276_645.magnitude_ema", "synthesis.L7_276_645.up_filter", "synthesis.L7_276_645.down_filter", "synthesis.L7_276_645.affine.weight", "synthesis.L7_276_645.affine.bias", "synthesis.L8_276_406.weight", "synthesis.L8_276_406.bias", "synthesis.L8_276_406.magnitude_ema", "synthesis.L8_276_406.up_filter", "synthesis.L8_276_406.down_filter", "synthesis.L8_276_406.affine.weight", "synthesis.L8_276_406.affine.bias", "synthesis.L9_532_256.weight", "synthesis.L9_532_256.bias", "synthesis.L9_532_256.magnitude_ema", "synthesis.L9_532_256.up_filter", "synthesis.L9_532_256.down_filter", "synthesis.L9_532_256.affine.weight", "synthesis.L9_532_256.affine.bias", "synthesis.L10_1044_161.weight", "synthesis.L10_1044_161.bias", "synthesis.L10_1044_161.magnitude_ema", "synthesis.L10_1044_161.up_filter", "synthesis.L10_1044_161.down_filter", "synthesis.L10_1044_161.affine.weight", "synthesis.L10_1044_161.affine.bias", "synthesis.L11_1044_102.weight", "synthesis.L11_1044_102.bias", "synthesis.L11_1044_102.magnitude_ema", "synthesis.L11_1044_102.up_filter", "synthesis.L11_1044_102.down_filter", "synthesis.L11_1044_102.affine.weight", "synthesis.L11_1044_102.affine.bias", "synthesis.L12_1044_64.weight", "synthesis.L12_1044_64.bias", "synthesis.L12_1044_64.magnitude_ema", "synthesis.L12_1044_64.up_filter", "synthesis.L12_1044_64.down_filter", "synthesis.L12_1044_64.affine.weight", "synthesis.L12_1044_64.affine.bias", "synthesis.L13_1024_64.weight", "synthesis.L13_1024_64.bias", "synthesis.L13_1024_64.magnitude_ema", "synthesis.L13_1024_64.up_filter", "synthesis.L13_1024_64.down_filter", "synthesis.L13_1024_64.affine.weight", "synthesis.L13_1024_64.affine.bias". 
Unexpected key(s) in state_dict: "synthesis.L0_36_512.weight", "synthesis.L0_36_512.bias", "synthesis.L0_36_512.magnitude_ema", "synthesis.L0_36_512.up_filter", "synthesis.L0_36_512.down_filter", "synthesis.L0_36_512.affine.weight", "synthesis.L0_36_512.affine.bias", "synthesis.L1_36_512.weight", "synthesis.L1_36_512.bias", "synthesis.L1_36_512.magnitude_ema", "synthesis.L1_36_512.up_filter", "synthesis.L1_36_512.down_filter", "synthesis.L1_36_512.affine.weight", "synthesis.L1_36_512.affine.bias", "synthesis.L2_52_512.weight", "synthesis.L2_52_512.bias", "synthesis.L2_52_512.magnitude_ema", "synthesis.L2_52_512.up_filter", "synthesis.L2_52_512.down_filter", "synthesis.L2_52_512.affine.weight", "synthesis.L2_52_512.affine.bias", "synthesis.L3_52_512.weight", "synthesis.L3_52_512.bias", "synthesis.L3_52_512.magnitude_ema", "synthesis.L3_52_512.up_filter", "synthesis.L3_52_512.down_filter", "synthesis.L3_52_512.affine.weight", "synthesis.L3_52_512.affine.bias", "synthesis.L4_84_512.weight", "synthesis.L4_84_512.bias", "synthesis.L4_84_512.magnitude_ema", "synthesis.L4_84_512.up_filter", "synthesis.L4_84_512.down_filter", "synthesis.L4_84_512.affine.weight", "synthesis.L4_84_512.affine.bias", "synthesis.L5_148_512.weight", "synthesis.L5_148_512.bias", "synthesis.L5_148_512.magnitude_ema", "synthesis.L5_148_512.up_filter", "synthesis.L5_148_512.down_filter", "synthesis.L5_148_512.affine.weight", "synthesis.L5_148_512.affine.bias", "synthesis.L6_148_512.weight", "synthesis.L6_148_512.bias", "synthesis.L6_148_512.magnitude_ema", "synthesis.L6_148_512.up_filter", "synthesis.L6_148_512.down_filter", "synthesis.L6_148_512.affine.weight", "synthesis.L6_148_512.affine.bias", "synthesis.L7_276_323.weight", "synthesis.L7_276_323.bias", "synthesis.L7_276_323.magnitude_ema", "synthesis.L7_276_323.up_filter", "synthesis.L7_276_323.down_filter", "synthesis.L7_276_323.affine.weight", "synthesis.L7_276_323.affine.bias", "synthesis.L8_276_203.weight", "synthesis.L8_276_203.bias", 
"synthesis.L8_276_203.magnitude_ema", "synthesis.L8_276_203.up_filter", "synthesis.L8_276_203.down_filter", "synthesis.L8_276_203.affine.weight", "synthesis.L8_276_203.affine.bias", "synthesis.L9_532_128.weight", "synthesis.L9_532_128.bias", "synthesis.L9_532_128.magnitude_ema", "synthesis.L9_532_128.up_filter", "synthesis.L9_532_128.down_filter", "synthesis.L9_532_128.affine.weight", "synthesis.L9_532_128.affine.bias", "synthesis.L10_1044_81.weight", "synthesis.L10_1044_81.bias", "synthesis.L10_1044_81.magnitude_ema", "synthesis.L10_1044_81.up_filter", "synthesis.L10_1044_81.down_filter", "synthesis.L10_1044_81.affine.weight", "synthesis.L10_1044_81.affine.bias", "synthesis.L11_1044_51.weight", "synthesis.L11_1044_51.bias", "synthesis.L11_1044_51.magnitude_ema", "synthesis.L11_1044_51.up_filter", "synthesis.L11_1044_51.down_filter", "synthesis.L11_1044_51.affine.weight", "synthesis.L11_1044_51.affine.bias", "synthesis.L12_1044_32.weight", "synthesis.L12_1044_32.bias", "synthesis.L12_1044_32.magnitude_ema", "synthesis.L12_1044_32.up_filter", "synthesis.L12_1044_32.down_filter", "synthesis.L12_1044_32.affine.weight", "synthesis.L12_1044_32.affine.bias", "synthesis.L13_1024_32.weight", "synthesis.L13_1024_32.bias", "synthesis.L13_1024_32.magnitude_ema", "synthesis.L13_1024_32.up_filter", "synthesis.L13_1024_32.down_filter", "synthesis.L13_1024_32.affine.weight", "synthesis.L13_1024_32.affine.bias". size mismatch for synthesis.input.weight: copying a param with shape torch.Size([512, 512]) from checkpoint, the shape in current model is torch.Size([1024, 1024]). size mismatch for synthesis.input.freqs: copying a param with shape torch.Size([512, 2]) from checkpoint, the shape in current model is torch.Size([1024, 2]). size mismatch for synthesis.input.phases: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([1024]). 
size mismatch for synthesis.L14_1024_3.weight: copying a param with shape torch.Size([3, 32, 1, 1]) from checkpoint, the shape in current model is torch.Size([3, 64, 1, 1]). size mismatch for synthesis.L14_1024_3.affine.weight: copying a param with shape torch.Size([32, 512]) from checkpoint, the shape in current model is torch.Size([64, 512]). size mismatch for synthesis.L14_1024_3.affine.bias: copying a param with shape torch.Size([32]) from checkpoint, the shape in current model is torch.Size([64]). During handling of the above exception, another exception occurred: Traceback (most recent call last): File "gen_images_using_pt.py", line 79, in <module> main() File "gen_images_using_pt.py", line 47, in main generator = SG3Generator(checkpoint_path=args.generator_path).decoder File "/sam/models/stylegan3/model.py", line 56, in __init__ self._load_checkpoint(checkpoint_path) File "/sam/models/stylegan3/model.py", line 65, in _load_checkpoint self.decoder.load_state_dict(ckpt, strict=False) File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1223, in load_state_dict raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format( RuntimeError: Error(s) in loading state_dict for Generator: size mismatch for synthesis.input.weight: copying a param with shape torch.Size([512, 512]) from checkpoint, the shape in current model is torch.Size([1024, 1024]). size mismatch for synthesis.input.freqs: copying a param with shape torch.Size([512, 2]) from checkpoint, the shape in current model is torch.Size([1024, 2]). size mismatch for synthesis.input.phases: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([1024]). size mismatch for synthesis.L14_1024_3.weight: copying a param with shape torch.Size([3, 32, 1, 1]) from checkpoint, the shape in current model is torch.Size([3, 64, 1, 1]). 
size mismatch for synthesis.L14_1024_3.affine.weight: copying a param with shape torch.Size([32, 512]) from checkpoint, the shape in current model is torch.Size([64, 512]). size mismatch for synthesis.L14_1024_3.affine.bias: copying a param with shape torch.Size([32]) from checkpoint, the shape in current model is torch.Size([64]).
Hi HuaZheLei, I am trying to generate images from a .pt model, but I am not sure how to load the model. How can I load the .pt model? Thanks!!
Thanks for your excellent work!
Describe the problem
I'd like to learn how you convert a pkl file to a pt file. I use the pt file you provide to generate images. The code is as follows:
def get_random_image(generator: Generator, truncation_psi: float, seed):
    """Generate one image from ``generator`` for a given random seed.

    Args:
        generator: Network exposing ``mapping`` and ``synthesis`` sub-modules.
        truncation_psi: Forwarded unchanged to ``generator.mapping``.
        seed: Seed for ``np.random.RandomState``; the same seed always
            produces the same latent ``z``.

    Returns:
        A ``(res_image, w)`` pair: the output of ``tensor2im`` for the first
        synthesized sample, and the ``w`` produced by the mapping network.
    """
    with torch.no_grad():
        # The latent width (512) and the target device ('cuda') are hard-coded.
        rng = np.random.RandomState(seed)
        z = torch.from_numpy(rng.randn(1, 512).astype('float32')).to('cuda')
        if hasattr(generator.synthesis, 'input'):
            # Only networks whose synthesis module has an ``input`` layer carry
            # a user transform; set it to the inverse of a transform built
            # with no translation and no rotation.
            m = make_transform(translate=(0, 0), angle=0)
            m = np.linalg.inv(m)
            generator.synthesis.input.transform.copy_(torch.from_numpy(m))
        w = generator.mapping(z, None, truncation_psi=truncation_psi)
        img = generator.synthesis(w, noise_mode='const')
        res_image = tensor2im(img[0])
    return res_image, w
And it works well. But when I convert the pkl to pt myself, several errors appear. The conversion code I used is as follows:
import pickle
import sys
from enum import Enum
from pathlib import Path
from typing import Optional

import torch

# Convert a StyleGAN3 training pickle into a plain PyTorch state_dict file.
checkpoint_path = "pretrained_models/stylegan3-t-ffhq-1024x1024.pkl"
# Write the result next to the source pickle, swapping only the extension.
output_path = Path(checkpoint_path).with_suffix(".pt")

print(f"Loading StyleGAN3 generator from path: {checkpoint_path}")
# SECURITY: pickle.load executes arbitrary code embedded in the file; only
# run this on checkpoints obtained from a trusted source.
with open(checkpoint_path, "rb") as f:
    # 'G_ema' is the entry of the pickled dict that holds the generator.
    # NOTE(review): .cuda() is probably not needed just to export weights.
    decoder = pickle.load(f)['G_ema'].cuda()
print('Loading done!')

# The saved keys and tensor shapes mirror the pickled network exactly, so the
# model that later loads this file must be constructed with the same
# architecture/config; otherwise load_state_dict reports missing/unexpected
# keys and size mismatches.
state_dict = decoder.state_dict()
torch.save(state_dict, output_path)
print('Converting done!')
Then I use stylegan3-t-ffhq-1024x1024.pt to generate images, and the errors are as follows:
Loading StyleGAN3 generator from path: pretrained_models/stylegan3-t-ffhq-1024x1024.pt Traceback (most recent call last): File "/sam/models/stylegan3/model.py", line 61, in _load_checkpoint self.decoder.load_state_dict(torch.load(checkpoint_path), strict=True) File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1223, in load_state_dict raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format( RuntimeError: Error(s) in loading state_dict for Generator: Missing key(s) in state_dict: "synthesis.L0_36_1024.weight", "synthesis.L0_36_1024.bias", "synthesis.L0_36_1024.magnitude_ema", "synthesis.L0_36_1024.up_filter", "synthesis.L0_36_1024.down_filter", "synthesis.L0_36_1024.affine.weight", "synthesis.L0_36_1024.affine.bias", "synthesis.L1_36_1024.weight", "synthesis.L1_36_1024.bias", "synthesis.L1_36_1024.magnitude_ema", "synthesis.L1_36_1024.up_filter", "synthesis.L1_36_1024.down_filter", "synthesis.L1_36_1024.affine.weight", "synthesis.L1_36_1024.affine.bias", "synthesis.L2_52_1024.weight", "synthesis.L2_52_1024.bias", "synthesis.L2_52_1024.magnitude_ema", "synthesis.L2_52_1024.up_filter", "synthesis.L2_52_1024.down_filter", "synthesis.L2_52_1024.affine.weight", "synthesis.L2_52_1024.affine.bias", "synthesis.L3_52_1024.weight", "synthesis.L3_52_1024.bias", "synthesis.L3_52_1024.magnitude_ema", "synthesis.L3_52_1024.up_filter", "synthesis.L3_52_1024.down_filter", "synthesis.L3_52_1024.affine.weight", "synthesis.L3_52_1024.affine.bias", "synthesis.L4_84_1024.weight", "synthesis.L4_84_1024.bias", "synthesis.L4_84_1024.magnitude_ema", "synthesis.L4_84_1024.up_filter", "synthesis.L4_84_1024.down_filter", "synthesis.L4_84_1024.affine.weight", "synthesis.L4_84_1024.affine.bias", "synthesis.L5_148_1024.weight", "synthesis.L5_148_1024.bias", "synthesis.L5_148_1024.magnitude_ema", "synthesis.L5_148_1024.up_filter", "synthesis.L5_148_1024.down_filter", "synthesis.L5_148_1024.affine.weight", "synthesis.L5_148_1024.affine.bias", 
"synthesis.L6_148_1024.weight", "synthesis.L6_148_1024.bias", "synthesis.L6_148_1024.magnitude_ema", "synthesis.L6_148_1024.up_filter", "synthesis.L6_148_1024.down_filter", "synthesis.L6_148_1024.affine.weight", "synthesis.L6_148_1024.affine.bias", "synthesis.L7_276_645.weight", "synthesis.L7_276_645.bias", "synthesis.L7_276_645.magnitude_ema", "synthesis.L7_276_645.up_filter", "synthesis.L7_276_645.down_filter", "synthesis.L7_276_645.affine.weight", "synthesis.L7_276_645.affine.bias", "synthesis.L8_276_406.weight", "synthesis.L8_276_406.bias", "synthesis.L8_276_406.magnitude_ema", "synthesis.L8_276_406.up_filter", "synthesis.L8_276_406.down_filter", "synthesis.L8_276_406.affine.weight", "synthesis.L8_276_406.affine.bias", "synthesis.L9_532_256.weight", "synthesis.L9_532_256.bias", "synthesis.L9_532_256.magnitude_ema", "synthesis.L9_532_256.up_filter", "synthesis.L9_532_256.down_filter", "synthesis.L9_532_256.affine.weight", "synthesis.L9_532_256.affine.bias", "synthesis.L10_1044_161.weight", "synthesis.L10_1044_161.bias", "synthesis.L10_1044_161.magnitude_ema", "synthesis.L10_1044_161.up_filter", "synthesis.L10_1044_161.down_filter", "synthesis.L10_1044_161.affine.weight", "synthesis.L10_1044_161.affine.bias", "synthesis.L11_1044_102.weight", "synthesis.L11_1044_102.bias", "synthesis.L11_1044_102.magnitude_ema", "synthesis.L11_1044_102.up_filter", "synthesis.L11_1044_102.down_filter", "synthesis.L11_1044_102.affine.weight", "synthesis.L11_1044_102.affine.bias", "synthesis.L12_1044_64.weight", "synthesis.L12_1044_64.bias", "synthesis.L12_1044_64.magnitude_ema", "synthesis.L12_1044_64.up_filter", "synthesis.L12_1044_64.down_filter", "synthesis.L12_1044_64.affine.weight", "synthesis.L12_1044_64.affine.bias", "synthesis.L13_1024_64.weight", "synthesis.L13_1024_64.bias", "synthesis.L13_1024_64.magnitude_ema", "synthesis.L13_1024_64.up_filter", "synthesis.L13_1024_64.down_filter", "synthesis.L13_1024_64.affine.weight", "synthesis.L13_1024_64.affine.bias". 
Unexpected key(s) in state_dict: "synthesis.L0_36_512.weight", "synthesis.L0_36_512.bias", "synthesis.L0_36_512.magnitude_ema", "synthesis.L0_36_512.up_filter", "synthesis.L0_36_512.down_filter", "synthesis.L0_36_512.affine.weight", "synthesis.L0_36_512.affine.bias", "synthesis.L1_36_512.weight", "synthesis.L1_36_512.bias", "synthesis.L1_36_512.magnitude_ema", "synthesis.L1_36_512.up_filter", "synthesis.L1_36_512.down_filter", "synthesis.L1_36_512.affine.weight", "synthesis.L1_36_512.affine.bias", "synthesis.L2_52_512.weight", "synthesis.L2_52_512.bias", "synthesis.L2_52_512.magnitude_ema", "synthesis.L2_52_512.up_filter", "synthesis.L2_52_512.down_filter", "synthesis.L2_52_512.affine.weight", "synthesis.L2_52_512.affine.bias", "synthesis.L3_52_512.weight", "synthesis.L3_52_512.bias", "synthesis.L3_52_512.magnitude_ema", "synthesis.L3_52_512.up_filter", "synthesis.L3_52_512.down_filter", "synthesis.L3_52_512.affine.weight", "synthesis.L3_52_512.affine.bias", "synthesis.L4_84_512.weight", "synthesis.L4_84_512.bias", "synthesis.L4_84_512.magnitude_ema", "synthesis.L4_84_512.up_filter", "synthesis.L4_84_512.down_filter", "synthesis.L4_84_512.affine.weight", "synthesis.L4_84_512.affine.bias", "synthesis.L5_148_512.weight", "synthesis.L5_148_512.bias", "synthesis.L5_148_512.magnitude_ema", "synthesis.L5_148_512.up_filter", "synthesis.L5_148_512.down_filter", "synthesis.L5_148_512.affine.weight", "synthesis.L5_148_512.affine.bias", "synthesis.L6_148_512.weight", "synthesis.L6_148_512.bias", "synthesis.L6_148_512.magnitude_ema", "synthesis.L6_148_512.up_filter", "synthesis.L6_148_512.down_filter", "synthesis.L6_148_512.affine.weight", "synthesis.L6_148_512.affine.bias", "synthesis.L7_276_323.weight", "synthesis.L7_276_323.bias", "synthesis.L7_276_323.magnitude_ema", "synthesis.L7_276_323.up_filter", "synthesis.L7_276_323.down_filter", "synthesis.L7_276_323.affine.weight", "synthesis.L7_276_323.affine.bias", "synthesis.L8_276_203.weight", "synthesis.L8_276_203.bias", 
"synthesis.L8_276_203.magnitude_ema", "synthesis.L8_276_203.up_filter", "synthesis.L8_276_203.down_filter", "synthesis.L8_276_203.affine.weight", "synthesis.L8_276_203.affine.bias", "synthesis.L9_532_128.weight", "synthesis.L9_532_128.bias", "synthesis.L9_532_128.magnitude_ema", "synthesis.L9_532_128.up_filter", "synthesis.L9_532_128.down_filter", "synthesis.L9_532_128.affine.weight", "synthesis.L9_532_128.affine.bias", "synthesis.L10_1044_81.weight", "synthesis.L10_1044_81.bias", "synthesis.L10_1044_81.magnitude_ema", "synthesis.L10_1044_81.up_filter", "synthesis.L10_1044_81.down_filter", "synthesis.L10_1044_81.affine.weight", "synthesis.L10_1044_81.affine.bias", "synthesis.L11_1044_51.weight", "synthesis.L11_1044_51.bias", "synthesis.L11_1044_51.magnitude_ema", "synthesis.L11_1044_51.up_filter", "synthesis.L11_1044_51.down_filter", "synthesis.L11_1044_51.affine.weight", "synthesis.L11_1044_51.affine.bias", "synthesis.L12_1044_32.weight", "synthesis.L12_1044_32.bias", "synthesis.L12_1044_32.magnitude_ema", "synthesis.L12_1044_32.up_filter", "synthesis.L12_1044_32.down_filter", "synthesis.L12_1044_32.affine.weight", "synthesis.L12_1044_32.affine.bias", "synthesis.L13_1024_32.weight", "synthesis.L13_1024_32.bias", "synthesis.L13_1024_32.magnitude_ema", "synthesis.L13_1024_32.up_filter", "synthesis.L13_1024_32.down_filter", "synthesis.L13_1024_32.affine.weight", "synthesis.L13_1024_32.affine.bias". size mismatch for synthesis.input.weight: copying a param with shape torch.Size([512, 512]) from checkpoint, the shape in current model is torch.Size([1024, 1024]). size mismatch for synthesis.input.freqs: copying a param with shape torch.Size([512, 2]) from checkpoint, the shape in current model is torch.Size([1024, 2]). size mismatch for synthesis.input.phases: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([1024]). 
size mismatch for synthesis.L14_1024_3.weight: copying a param with shape torch.Size([3, 32, 1, 1]) from checkpoint, the shape in current model is torch.Size([3, 64, 1, 1]). size mismatch for synthesis.L14_1024_3.affine.weight: copying a param with shape torch.Size([32, 512]) from checkpoint, the shape in current model is torch.Size([64, 512]). size mismatch for synthesis.L14_1024_3.affine.bias: copying a param with shape torch.Size([32]) from checkpoint, the shape in current model is torch.Size([64]). During handling of the above exception, another exception occurred: Traceback (most recent call last): File "gen_images_using_pt.py", line 79, in <module> main() File "gen_images_using_pt.py", line 47, in main generator = SG3Generator(checkpoint_path=args.generator_path).decoder File "/sam/models/stylegan3/model.py", line 56, in __init__ self._load_checkpoint(checkpoint_path) File "/sam/models/stylegan3/model.py", line 65, in _load_checkpoint self.decoder.load_state_dict(ckpt, strict=False) File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1223, in load_state_dict raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format( RuntimeError: Error(s) in loading state_dict for Generator: size mismatch for synthesis.input.weight: copying a param with shape torch.Size([512, 512]) from checkpoint, the shape in current model is torch.Size([1024, 1024]). size mismatch for synthesis.input.freqs: copying a param with shape torch.Size([512, 2]) from checkpoint, the shape in current model is torch.Size([1024, 2]). size mismatch for synthesis.input.phases: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([1024]). size mismatch for synthesis.L14_1024_3.weight: copying a param with shape torch.Size([3, 32, 1, 1]) from checkpoint, the shape in current model is torch.Size([3, 64, 1, 1]). 
size mismatch for synthesis.L14_1024_3.affine.weight: copying a param with shape torch.Size([32, 512]) from checkpoint, the shape in current model is torch.Size([64, 512]). size mismatch for synthesis.L14_1024_3.affine.bias: copying a param with shape torch.Size([32]) from checkpoint, the shape in current model is torch.Size([64]).
Hi, I will share my code here.
import os
import argparse
from typing import Tuple, List, Union
import numpy as np
import torch
from models.stylegan3.model import SG3Generator
from models.stylegan3.networks_stylegan3 import Generator
from utils.common import tensor2im
def make_transform(translate: Tuple[float, float], angle: float):
    """Build a 3x3 homogeneous 2D transform from a translation and a rotation.

    Args:
        translate: ``(tx, ty)`` offsets placed in the last column of the
            first two rows.
        angle: Rotation angle in degrees (converted via ``angle / 360 * 2*pi``).

    Returns:
        A ``(3, 3)`` float64 ``numpy`` array::

            [[ cos, sin, tx],
             [-sin, cos, ty],
             [   0,   0,  1]]
    """
    radians = angle / 360.0 * np.pi * 2
    s = np.sin(radians)
    c = np.cos(radians)
    return np.array(
        [
            [c, s, translate[0]],
            [-s, c, translate[1]],
            [0.0, 0.0, 1.0],
        ],
        dtype=np.float64,
    )
def main():
    """Generate ``args.image_numbers`` images and save them as PNG files.

    Reads ``save_dir``, ``generator_path``, ``image_numbers`` and
    ``truncation_psi`` from ``parse_args()``. Seeds run from 0 to
    ``image_numbers - 1``; each image is written to
    ``<save_dir>/seedNNNN.png`` with the seed zero-padded to four digits.
    """
    args = parse_args()
    save_dir = args.save_dir
    # exist_ok=True replaces the separate exists() check and avoids the
    # check-then-create race.
    os.makedirs(save_dir, exist_ok=True)
    generator = SG3Generator(checkpoint_path=args.generator_path).decoder
    for i in range(args.image_numbers):
        print('Generating image for seed %d (%d/%d) ...' % (i, i, args.image_numbers))
        # The returned latent is not needed here; only the image is saved.
        image, _ = get_random_image(generator, truncation_psi=args.truncation_psi, seed=i)
        image.save(os.path.join(save_dir, 'seed' + str(i).zfill(4) + '.png'))
def get_random_image(generator: Generator, truncation_psi: float, seed):
    """Generate one image from ``generator`` for a given random seed.

    Args:
        generator: Network exposing ``mapping`` and ``synthesis`` sub-modules.
        truncation_psi: Forwarded unchanged to ``generator.mapping``.
        seed: Seed for ``np.random.RandomState``; the same seed always
            produces the same latent ``z``.

    Returns:
        A ``(res_image, w)`` pair: the output of ``tensor2im`` for the first
        synthesized sample, and the ``w`` produced by the mapping network.
    """
    with torch.no_grad():
        # The latent width (512) and the target device ('cuda') are hard-coded.
        rng = np.random.RandomState(seed)
        z = torch.from_numpy(rng.randn(1, 512).astype('float32')).to('cuda')
        if hasattr(generator.synthesis, 'input'):
            # Only networks whose synthesis module has an ``input`` layer carry
            # a user transform; set it to the inverse of a transform built
            # with no translation and no rotation.
            m = make_transform(translate=(0, 0), angle=0)
            m = np.linalg.inv(m)
            generator.synthesis.input.transform.copy_(torch.from_numpy(m))
        w = generator.mapping(z, None, truncation_psi=truncation_psi)
        img = generator.synthesis(w, noise_mode='const')
        res_image = tensor2im(img[0])
    return res_image, w
Hope it helps.
Thank you very much! You saved my day!