JingyunLiang / SwinIR

SwinIR: Image Restoration Using Swin Transformer (official repository)

Home Page: https://arxiv.org/abs/2108.10257

Can the GAN model be fine-tuned on my own dataset?

betterftr opened this issue

When I set the pretrained model path (003_realSR_BSRGAN_DFO_s64w8_SwinIR-M_x4_GAN.pth) in KAIR's training options file:

, "path": {
"root": "superresolution" // "denoising" | "superresolution" | "dejpeg"
, "pretrained_netG": null // path of pretrained model
, "pretrained_netD": null // path of pretrained model
, "pretrained_netE": null // path of pretrained model
}

training still starts from scratch. Copying the file from model_zoo directly into /superresolution/swinir_sr_realworld_x4_gan/models/ does not work either.
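KAIR's option files are JSON with `//` line comments, so they cannot be passed to `json.loads` as-is. The following is a minimal sketch, assuming a simplified comment rule (a `//` preceded by whitespace starts a comment), of how the `pretrained_netG` entry could be read and filled in programmatically. `load_commented_json` and `set_pretrained_g` are hypothetical helpers, not KAIR functions, and KAIR's own option parser may strip comments differently. As reported in this issue, setting the path alone did not stop training from starting from scratch.

```python
import json
import re

# Assumption: a comment starts at "//" preceded by whitespace and runs to the
# end of the line. This is a simplification: a quoted string containing " //"
# would be cut incorrectly.
_COMMENT = re.compile(r"\s+//.*$", re.MULTILINE)


def load_commented_json(text: str) -> dict:
    """Parse JSON text containing //-style line comments (hypothetical helper)."""
    return json.loads(_COMMENT.sub("", text))


def set_pretrained_g(opt: dict, checkpoint_path: str) -> dict:
    """Point the generator at a pretrained checkpoint (hypothetical helper)."""
    opt.setdefault("path", {})["pretrained_netG"] = checkpoint_path
    return opt


if __name__ == "__main__":
    snippet = """{
      "path": {
        "root": "superresolution"  // "denoising" | "superresolution" | "dejpeg"
      , "pretrained_netG": null     // path of pretrained model
      , "pretrained_netD": null     // path of pretrained model
      }
    }"""
    opt = load_commented_json(snippet)
    # Assumed location of the downloaded checkpoint.
    set_pretrained_g(opt, "model_zoo/003_realSR_BSRGAN_DFO_s64w8_SwinIR-M_x4_GAN.pth")
    print(json.dumps(opt["path"], indent=2))
```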

Thanks for your question. See cszn's answer.

You should rename 003_realSR_BSRGAN_DFO_s64w8_SwinIR-M_x4_GAN.pth to 5000_G.pth and then put it in /superresolution/swinir_sr_realworld_x4_gan/models
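The rename suggests that the training script picks up checkpoints named `<iteration>_G.pth` from the `models` directory and resumes from the highest iteration it finds, rather than using `pretrained_netG` directly. The following is a minimal sketch of that kind of lookup and of staging the pretrained file under such a name. `find_last_checkpoint_sketch` and `stage_pretrained` are hypothetical helpers written for illustration, not KAIR's own code; the iteration number 5000 is simply the one suggested above.

```python
import re
import shutil
from pathlib import Path
from typing import Optional, Tuple

# Assumed naming convention: "<iteration>_G.pth", e.g. "5000_G.pth".
_CKPT = re.compile(r"^(\d+)_G\.pth$")


def find_last_checkpoint_sketch(models_dir: Path) -> Tuple[int, Optional[Path]]:
    """Return (iteration, path) of the highest-numbered <iter>_G.pth, or (0, None)."""
    best_iter, best_path = 0, None
    for f in models_dir.glob("*_G.pth"):
        m = _CKPT.match(f.name)
        if m and int(m.group(1)) > best_iter:
            best_iter, best_path = int(m.group(1)), f
    return best_iter, best_path


def stage_pretrained(src: Path, models_dir: Path, iteration: int = 5000) -> Path:
    """Copy a pretrained generator into models_dir under a checkpoint-style name."""
    models_dir.mkdir(parents=True, exist_ok=True)
    dst = models_dir / f"{iteration}_G.pth"
    shutil.copyfile(src, dst)
    return dst
```

Under this assumption, the staged 5000_G.pth is treated as a resumed checkpoint, which is why the iteration counter would start at 5000 instead of 0.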

Yes, I tried that, but it gives a long error ending with: Unexpected key(s) in state_dict: "params_ema".

Loading model for G [D:/AI/KAIR/superresolution\swinir_sr_realworld_x4_gan\models\5000_G.pth] ...
Traceback (most recent call last):
  File "main_train_psnr.py", line 248, in <module>
    main()
  File "main_train_psnr.py", line 156, in main
    model.init_train()
  File "D:\AI\KAIR\models\model_gan.py", line 40, in init_train
    self.load() # load model
  File "D:\AI\KAIR\models\model_gan.py", line 56, in load
    self.load_network(load_path_G, self.netG, strict=self.opt_train['G_param_strict'])
  File "D:\AI\KAIR\models\model_base.py", line 160, in load_network
    network.load_state_dict(torch.load(load_path), strict=strict)
  File "D:\CONDA\envs\real\lib\site-packages\torch\nn\modules\module.py", line 1052, in load_state_dict
    self.__class__.__name__, "\n\t".join(error_msgs)))
RuntimeError: Error(s) in loading state_dict for SwinIR:
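The error indicates that the released checkpoint is a dict that wraps the weights under a "params_ema" key, while strict loading expects the parameter names at the top level. This pure-Python sketch mimics the key comparison behind strict loading (it is not PyTorch's implementation); the parameter names are made-up placeholders and the string "tensor" stands in for real tensors.

```python
from typing import Iterable, List, Mapping, Tuple


def compare_keys(expected: Iterable[str], loaded: Mapping[str, object]) -> Tuple[List[str], List[str]]:
    """Return (missing, unexpected) keys, roughly as strict loading would report them."""
    expected_set = set(expected)
    loaded_set = set(loaded)
    missing = sorted(expected_set - loaded_set)
    unexpected = sorted(loaded_set - expected_set)
    return missing, unexpected


# Hypothetical parameter names standing in for SwinIR's real ones.
model_keys = ["conv_first.weight", "conv_first.bias"]
weights = {k: "tensor" for k in model_keys}

# A checkpoint that wraps the weights under "params_ema" fails the comparison:
print(compare_keys(model_keys, {"params_ema": weights}))
# (['conv_first.bias', 'conv_first.weight'], ['params_ema'])

# The unwrapped dict matches exactly:
print(compare_keys(model_keys, {"params_ema": weights}["params_ema"]))
# ([], [])
```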

Refer to #20 (comment).

We will try to fix this inconsistency later.

Okay, can you tell me which line to change? The referenced solution is for main_test_swinir.py, not for KAIR's training code.

Change

network.load_state_dict(torch.load(load_path), strict=strict)

to

network.load_state_dict(torch.load(load_path)['params_ema'], strict=strict)

at https://github.com/cszn/KAIR/blob/7d70f91bb7c03d8795a6bed29ee17b1c6b834e4e/models/model_base.py#L160.

Note: this is only a temporary workaround and may cause problems in other experiments. We will provide a better solution later.
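Indexing ['params_ema'] unconditionally raises a KeyError for checkpoints that store their weights at the top level, which is one way the workaround above could break other experiments. A less brittle variant is to unwrap only when a known wrapper key is present. The helper below is a hypothetical sketch, not the fix the maintainers promised; the wrapper keys and their order ("params_ema" before "params") are assumptions.

```python
from typing import Any, Mapping

# Assumed wrapper keys, checked in order of preference.
WRAPPER_KEYS = ("params_ema", "params")


def unwrap_state_dict(checkpoint: Mapping[str, Any]) -> Mapping[str, Any]:
    """Return the inner state dict if the checkpoint wraps it, else the checkpoint itself."""
    for key in WRAPPER_KEYS:
        if key in checkpoint:
            return checkpoint[key]
    return checkpoint


# Possible use inside load_network (requires PyTorch):
#   state = unwrap_state_dict(torch.load(load_path, map_location="cpu"))
#   network.load_state_dict(state, strict=strict)
```

One caveat of this heuristic: a raw state dict that happened to contain a parameter literally named "params" or "params_ema" would be misread.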