MY-SR

#my train

CUDA_VISIBLE_DEVICES=6,7 python main.py --n_GPUs 2 --model resDBPN_2 --scale 2 --save resDBPN2_v5_x2_bs16_120_ps112_k6_res3_L1_RDB_SE --ext sep --base_filter 120 --n_resblock 3 \
--lr 1e-4 --batch_size 16 --patch_size 112 --epochs 1000 --loss 1*L1 --print_every 20 --test_every 300 --data_test Set5 --save_results #--reset --weight_decay 1e-4 --chop

#my train (small model)

CUDA_VISIBLE_DEVICES=0,1 python main.py --n_GPUs 2 --model resDBPN_2 --scale 2 --save resDBPN2_small_x2_bs24_64_ps128_k6_res3_L1_SE_stage4_dilated_2 --ext sep --num_stages 4 --base_filter 64 --n_resblock 3 \
--lr 1e-4 --batch_size 24 --patch_size 128 --epochs 1000 --loss 1*L1 --print_every 20 --test_every 300 --data_test Set5 --save_results #--reset --weight_decay 1e-4 --chop

#GAN

CUDA_VISIBLE_DEVICES=2,3,4 python main.py --n_GPUs 3 --model resDBPN_2 --scale 2 --save resDBPN2_v4_x2_bs20_96_ps128_k10_res3_GAN_2 --ext sep --base_filter 96 --n_resblock 3 \
--lr 1e-4 --loss 6*VGG54+0.1*GAN --batch_size 16 --patch_size 128 --epochs 1000 --print_every 20 --test_every 300 --data_test Set5 --save_results #--reset --weight_decay 1e-4 --chop --loss 1*L1
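The `--loss` string uses the EDSR-style `weight*type` terms joined by `+` (the markdown rendering had eaten the asterisks). A minimal sketch of how such a string is conventionally parsed, assuming that EDSR convention (the project's actual parser is not shown here):

import re  # not required; plain split suffices for well-formed specs

loss_spec = '6*VGG54+0.1*GAN'
terms = [term.split('*') for term in loss_spec.split('+')]
parsed = [(float(weight), name) for weight, name in terms]
print(parsed)  # [(6.0, 'VGG54'), (0.1, 'GAN')]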

#my test

CUDA_VISIBLE_DEVICES=4 python main.py --n_GPUs 1 --model resDBPN_2 --data_test Demo --scale 4 --base_filter 96 --n_resblock 5 --pre_train ../experiment/resDBPN2_v4_x4_bs20_96_ps96_k10_res5_L1/model/model_best.pt --test_only --save_results \
--dir_demo ../test/testsmall_jpg #--chop

CUDA_VISIBLE_DEVICES=2 python main.py --n_GPUs 1 --model resDBPN_2 --data_test Demo --scale 2 --num_stages 4 --base_filter 64 --n_resblock 3 --pre_train ../experiment/resDBPN2_small_x2_bs24_64_ps128_k4_res3_L1_SE_stage4_dilated/model/model_best.pt --test_only --save_results \
--dir_demo ../test/testsmall_jpg #--chop


import torch
import torch.nn as nn  # used by the reference RDB class below
import math
import torch.nn.functional as F
from model.se_module import SELayer

class sub_pixel(torch.nn.Module):
    def __init__(self, scale, act=False):
        super(sub_pixel, self).__init__()
        modules = []
        # PixelShuffle(scale) consumes scale**2 channel groups, so the conv must emit
        # 64 * scale**2 channels (the original hard-coded 64 * 64 is only correct for scale=8)
        modules.append(torch.nn.Conv2d(64, 64 * scale ** 2, 3, 1, 1))
        modules.append(torch.nn.PixelShuffle(scale))
        self.body = torch.nn.Sequential(*modules)

    def forward(self, x):
        x = self.body(x)
        return x
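A quick shape check for sub_pixel (a minimal sketch; the hard-coded 64-channel input width matches the conv above):

import torch

up = sub_pixel(scale=2)
y = up(torch.randn(1, 64, 16, 16))
print(y.shape)  # torch.Size([1, 64, 32, 32]) -- spatial x2, channels back to 64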

class Upsampler(torch.nn.Module):
    def __init__(self, num_filter, kernel_size=3, stride=1, padding=1, scale=8, bn=False, activation='prelu', bias=True):
        super(Upsampler, self).__init__()
        padding = kernel_size // 2  # overrides the argument so the conv is always 'same'
        modules = []
        self.act = False
        self.activation = activation
        if self.activation == 'relu':
            self.act = torch.nn.ReLU(True)
        elif self.activation == 'prelu':
            self.act = torch.nn.PReLU()
        elif self.activation == 'lrelu':
            self.act = torch.nn.LeakyReLU(0.2, True)
        elif self.activation == 'tanh':
            self.act = torch.nn.Tanh()
        elif self.activation == 'sigmoid':
            self.act = torch.nn.Sigmoid()

        if (scale & (scale - 1)) == 0:    # Is scale = 2^n?
            for _ in range(int(math.log(scale, 2))):
                modules.append(
                    torch.nn.Conv2d(num_filter, 4 * num_filter, kernel_size, stride, padding, bias=bias)
                )
                modules.append(torch.nn.PixelShuffle(2))
                if bn: modules.append(torch.nn.BatchNorm2d(num_filter))
                # note: the same activation module instance is shared by every pyramid level
                if self.act: modules.append(self.act)
        elif scale == 3:
            modules.append(
                torch.nn.Conv2d(num_filter, 9 * num_filter, kernel_size, stride, padding, bias=bias)
            )
            modules.append(torch.nn.PixelShuffle(3))
            if bn: modules.append(torch.nn.BatchNorm2d(num_filter))
            if self.act: modules.append(self.act)
        else:
            raise NotImplementedError

        self.body = torch.nn.Sequential(*modules)

    def forward(self, x):
        return self.body(x)
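A minimal sanity check of the upsampling pyramid (power-of-two scales stack log2(scale) PixelShuffle(2) stages; x3 uses a single PixelShuffle(3)):

import torch

for scale in (2, 3, 4):
    up = Upsampler(num_filter=64, scale=scale)
    y = up(torch.randn(1, 64, 12, 12))
    print(scale, y.shape)  # (1, 64, 12 * scale, 12 * scale)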

class ConvBlock(torch.nn.Module):
    def __init__(self, input_size, output_size, kernel_size=3, stride=1, padding=1, bias=True, activation='prelu', norm=None, dilated=False):
        super(ConvBlock, self).__init__()
        if dilated:
            # hard-coded dilated downsampling conv (kernel 4, stride 2, padding 3, dilation 2);
            # kernel_size/stride/padding arguments are ignored in this branch
            self.conv = torch.nn.Conv2d(input_size, output_size, 4, 2, 3, dilation=2)
        else:
            self.conv = torch.nn.Conv2d(input_size, output_size, kernel_size, stride, padding, bias=bias)

        self.norm = norm
        if self.norm == 'batch':
            self.bn = torch.nn.BatchNorm2d(output_size)
        elif self.norm == 'instance':
            self.bn = torch.nn.InstanceNorm2d(output_size)

        self.activation = activation
        if self.activation == 'relu':
            self.act = torch.nn.ReLU(True)
        elif self.activation == 'prelu':
            self.act = torch.nn.PReLU()
        elif self.activation == 'lrelu':
            self.act = torch.nn.LeakyReLU(0.2, True)
        elif self.activation == 'tanh':
            self.act = torch.nn.Tanh()
        elif self.activation == 'sigmoid':
            self.act = torch.nn.Sigmoid()

    def forward(self, x):
        if self.norm is not None:
            out = self.bn(self.conv(x))
        else:
            out = self.conv(x)

        if self.activation is not None:
            return self.act(out)
        else:
            return out
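One pitfall worth flagging: `bias` sits before `activation` in this signature, so a call like `ConvBlock(f, f, 3, 1, 1, activation, norm=None)` silently passes the activation string into the `bias` slot (a non-empty string is truthy, so it still runs, but the requested activation is ignored and the default 'prelu' is used). The projection blocks below therefore pass `activation` by keyword. A small demonstration:

block = ConvBlock(64, 64, 3, 1, 1, 'tanh')            # 'tanh' lands in the bias slot
print(type(block.act).__name__)                        # PReLU -- not Tanh
block = ConvBlock(64, 64, 3, 1, 1, activation='tanh')  # keyword form works as intended
print(type(block.act).__name__)                        # Tanh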

class DeconvBlock(torch.nn.Module):
    def __init__(self, input_size, output_size, kernel_size=4, stride=2, padding=1, bias=True, activation='prelu', norm=None):
        super(DeconvBlock, self).__init__()
        # self.deconv = torch.nn.ConvTranspose2d(input_size, output_size, kernel_size, stride, padding, dilation=1, output_padding=0, bias=bias)
        self.deconv = torch.nn.ConvTranspose2d(input_size, output_size, kernel_size, stride, padding, bias=bias)

        self.norm = norm
        if self.norm == 'batch':
            self.bn = torch.nn.BatchNorm2d(output_size)
        elif self.norm == 'instance':
            self.bn = torch.nn.InstanceNorm2d(output_size)

        self.activation = activation
        if self.activation == 'relu':
            self.act = torch.nn.ReLU(True)
        elif self.activation == 'prelu':
            self.act = torch.nn.PReLU()
        elif self.activation == 'lrelu':
            self.act = torch.nn.LeakyReLU(0.2, True)
        elif self.activation == 'tanh':
            self.act = torch.nn.Tanh()
        elif self.activation == 'sigmoid':
            self.act = torch.nn.Sigmoid()

    def forward(self, x):
        if self.norm is not None:
            out = self.bn(self.deconv(x))
        else:
            out = self.deconv(x)

        if self.activation is not None:
            return self.act(out)
        else:
            return out
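For reference, ConvTranspose2d output size is (in - 1) * stride - 2 * padding + kernel_size, so the geometries used by the runs above give clean integer scale factors. A quick check:

def deconv_out(n, k, s, p):
    # spatial output size of ConvTranspose2d with output_padding=0, dilation=1
    return (n - 1) * s - 2 * p + k

print(deconv_out(24, 4, 2, 1))  # 48 -> x2 with kernel 4, stride 2, padding 1
print(deconv_out(24, 6, 2, 2))  # 48 -> x2 with kernel 6, stride 2, padding 2 (the 'k6' runs)
print(deconv_out(24, 8, 4, 2))  # 96 -> x4 with the default kernel 8, stride 4, padding 2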

class ResBlock(torch.nn.Module):
    def __init__(self, num_filter, kernel_size=3, stride=1, padding=1, bias=True, activation='prelu'):
        super(ResBlock, self).__init__()
        self.conv1 = torch.nn.Conv2d(num_filter, num_filter, kernel_size, stride, padding, bias=bias)
        # self.conv2 = torch.nn.Conv2d(num_filter, num_filter, kernel_size, stride, padding, bias=bias)

        self.activation = activation
        if self.activation == 'relu':
            self.act = torch.nn.ReLU(True)
        elif self.activation == 'prelu':
            self.act = torch.nn.PReLU()
        elif self.activation == 'lrelu':
            self.act = torch.nn.LeakyReLU(0.2, True)
        elif self.activation == 'tanh':
            self.act = torch.nn.Tanh()
        elif self.activation == 'sigmoid':
            self.act = torch.nn.Sigmoid()

        self.res_scale = 1

    # note: the skip connection is currently disabled -- this forward is just
    # conv + activation; the original residual version is kept below for reference
    def forward(self, x):
        out = self.conv1(x)
        out = self.act(out)
        return out

    # def forward(self, x):
    #     residual = x
    #     out = self.conv1(x)
    #     if self.activation is not None:
    #         out = self.act(out)
    #     out = self.conv2(out)
    #     out = torch.add(out.mul(self.res_scale), residual)
    #     return out

class make_dense(torch.nn.Module):
    def __init__(self, nChannels, growthRate, kernel_size=3):
        super(make_dense, self).__init__()
        self.conv = torch.nn.Conv2d(nChannels, growthRate, kernel_size=kernel_size, padding=(kernel_size - 1) // 2, bias=False)

    def forward(self, x):
        out = F.relu(self.conv(x))
        # out = torch.cat((x, out), 1)  # dense concatenation disabled in this variant
        return out

# Residual dense block (RDB) architecture -- modified variant (plain chaining, 1x1 fusion disabled)

class RDB(torch.nn.Module):
    def __init__(self, nChannels, nDenselayer, growthRate):
        super(RDB, self).__init__()
        nChannels_ = nChannels
        modules = []
        for i in range(nDenselayer):
            nFeats = growthRate * (i + 1)
            modules.append(make_dense(nChannels_, nFeats))
            nChannels_ = nFeats
        self.dense_layers = torch.nn.Sequential(*modules)
        # self.conv_1x1 = torch.nn.Conv2d(nChannels_, nChannels, kernel_size=1, padding=0, bias=False)

    def forward(self, x):
        out = self.dense_layers(x)
        # out = self.conv_1x1(out)
        # out = out + x
        return out

# Residual dense block (RDB) architecture -- original reference implementation,
# kept for comparison. Note that it shadows the modified RDB above if both stay
# in one module, and it relies on make_dense concatenating its input (the
# torch.cat line commented out above) so that the channel counts line up.
class RDB(nn.Module):
    def __init__(self, nChannels, nDenselayer, growthRate):
        super(RDB, self).__init__()
        nChannels_ = nChannels
        modules = []
        for i in range(nDenselayer):
            modules.append(make_dense(nChannels_, growthRate))
            nChannels_ += growthRate
        self.dense_layers = nn.Sequential(*modules)
        self.conv_1x1 = nn.Conv2d(nChannels_, nChannels, kernel_size=1, padding=0, bias=False)

    def forward(self, x):
        out = self.dense_layers(x)
        out = self.conv_1x1(out)
        out = out + x
        return out
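A shape sketch for the reference RDB (this assumes make_dense with the dense concatenation re-enabled, as this version requires):

import torch

rdb = RDB(nChannels=64, nDenselayer=3, growthRate=32)
y = rdb(torch.randn(1, 64, 24, 24))
print(y.shape)  # torch.Size([1, 64, 24, 24]) -- the 1x1 conv fuses 64 + 3*32 channels back to 64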

bool_pixshuffer = False  # True: upsample via PixelShuffle (Upsampler); False: transposed conv (DeconvBlock)

class UpBlock(torch.nn.Module):
    def __init__(self, num_filter, kernel_size=8, stride=4, padding=2, bias=True, activation='prelu', n_resblock=3, norm=None):
        super(UpBlock, self).__init__()
        # self.up_conv1 = DeconvBlock(num_filter, num_filter, kernel_size, stride, padding, activation=activation, norm=None)

        # ##### body: RDB
        # nChannel = num_filter
        # nDenselayer = n_resblock
        # growthRate = num_filter // n_resblock
        # self.body = RDB(nChannel, nDenselayer, growthRate)

        ### body: without dense
        modules_body = [
            ResBlock(num_filter, 3, 1, padding=1, bias=bias, activation=activation)
            for _ in range(n_resblock)
        ]
        self.body = torch.nn.Sequential(*modules_body)

        self.se = SELayer(num_filter, 16)

        # activation is passed by keyword below; positionally it would land in the bias slot
        modules_up = []
        if bool_pixshuffer:
            modules_up.append(
                Upsampler(num_filter, 3, 1, 1, scale=stride, bias=bias, activation=activation)
            )
        else:
            modules_up.append(
                # Upsampler(num_filter, 3, 1, 1, scale=stride, bias=bias, activation=activation)
                # sub_pixel(scale=stride)
                DeconvBlock(num_filter, num_filter, kernel_size, stride, padding, activation=activation, norm=None)
            )
            modules_up.append(
                ConvBlock(num_filter, num_filter, 3, 1, 1, activation=activation, norm=None)
                # DeconvBlock(num_filter, num_filter, 3, 1, 1, activation=activation, norm=None)
            )
        self.up1 = torch.nn.Sequential(*modules_up)
        # self.up2 = torch.nn.Sequential(*modules_up)

    def forward(self, x):
        # alternative (disabled): upsample input and residual separately, then add
        # res = self.body(x)
        # x = self.up1(x)
        # res_out = self.up2(res)
        # out = res_out + x
        # return out, res

        res_out = self.body(x)      # residual features at the input resolution
        res_out = self.se(res_out)  # channel attention
        out = res_out + x
        out = self.up1(out)         # upsample to the HR grid
        return out, res_out
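The (out, res_out) contract matters downstream: out is the upsampled feature map, while res_out stays at the input resolution and is what the dense *_D blocks concatenate. A quick check with the x2 geometry (kernel 6, stride 2, padding 2, matching the 'k6' runs above; runs given the SELayer import at the top of the file):

import torch

up = UpBlock(num_filter=64, kernel_size=6, stride=2, padding=2)
out, res = up(torch.randn(1, 64, 24, 24))
print(out.shape)  # torch.Size([1, 64, 48, 48]) -- upsampled x2
print(res.shape)  # torch.Size([1, 64, 24, 24]) -- residual kept at the input scale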

class UpBlock_D(torch.nn.Module):
    def __init__(self, num_filter, kernel_size=8, stride=4, padding=2, bias=True, activation='prelu', n_resblock=3, norm=None):
        super(UpBlock_D, self).__init__()
        # 1x1 conv fusing this stage's residual with the previous stage's residual
        self.cut_channel = ConvBlock(num_filter * 2, num_filter, 1, 1, 0, activation=activation, norm=None)

        # self.up_conv1 = DeconvBlock(num_filter, num_filter, kernel_size, stride, padding, activation=activation, norm=None)

        # ##### body: RDB
        # nChannel = num_filter
        # nDenselayer = n_resblock
        # growthRate = num_filter // n_resblock
        # self.body = RDB(nChannel, nDenselayer, growthRate)

        ### body: without dense
        modules_body = [
            ResBlock(num_filter, 3, 1, padding=1, bias=bias, activation=activation)
            for _ in range(n_resblock)
        ]
        self.body = torch.nn.Sequential(*modules_body)

        self.se = SELayer(num_filter, 16)

        modules_up = []
        if bool_pixshuffer:
            modules_up.append(
                Upsampler(num_filter, 3, 1, 1, scale=stride, bias=bias, activation=activation)
            )
        else:
            modules_up.append(
                # Upsampler(num_filter, 3, 1, 1, scale=stride, bias=bias, activation=activation)
                # sub_pixel(scale=stride)
                DeconvBlock(num_filter, num_filter, kernel_size, stride, padding, activation=activation, norm=None)
            )
            modules_up.append(
                ConvBlock(num_filter, num_filter, 3, 1, 1, activation=activation, norm=None)
                # DeconvBlock(num_filter, num_filter, 3, 1, 1, activation=activation, norm=None)
            )
        self.up1 = torch.nn.Sequential(*modules_up)
        # self.up2 = torch.nn.Sequential(*modules_up)

    def forward(self, x, res_x):
        # alternative (disabled): fuse residuals first, upsample both paths, then add
        # res = self.body(x)
        # concat_res = torch.cat((res, res_x), 1)
        # concat_res = self.cut_channel(concat_res)
        # res_out = self.up1(concat_res)
        # x = self.up2(x)
        # out = res_out + x
        # return out, res

        res_out = self.body(x)
        concat_res = torch.cat((res_out, res_x), 1)  # fuse with the previous stage's residual
        res = self.cut_channel(concat_res)
        res = self.se(res)
        out = res + x
        out = self.up1(out)
        return out, res_out

class D_UpBlock(torch.nn.Module):
    def __init__(self, num_filter, kernel_size=8, stride=4, padding=2, num_stages=1, bias=True, activation='prelu', n_resblock=3, norm=None):
        super(D_UpBlock, self).__init__()
        # 1x1 convs: reduce the dense concat of previous stages, then fuse residuals
        self.cut_channel = ConvBlock(num_filter * num_stages, num_filter, 1, 1, 0, activation=activation, norm=None)
        self.cut_channel_res = ConvBlock(num_filter * (num_stages + 1), num_filter, 1, 1, 0, activation=activation, norm=None)
        # self.up_conv1 = DeconvBlock(num_filter, num_filter, kernel_size, stride, padding, activation=activation, norm=None)

        # ##### body: RDB
        # nChannel = num_filter
        # nDenselayer = n_resblock
        # growthRate = num_filter // n_resblock
        # self.body = RDB(nChannel, nDenselayer, growthRate)

        modules_body = []
        for _ in range(n_resblock):
            modules_body.append(
                ResBlock(num_filter, 3, 1, 1, bias, activation)
            )
        self.body = torch.nn.Sequential(*modules_body)

        self.se = SELayer(num_filter, 16)

        modules_up = []
        if bool_pixshuffer:
            modules_up.append(
                Upsampler(num_filter, 3, 1, 1, scale=stride, bias=bias, activation=activation)
            )
        else:
            modules_up.append(
                # Upsampler(num_filter, 3, 1, 1, scale=stride, bias=bias, activation=activation)
                # sub_pixel(scale=stride)
                DeconvBlock(num_filter, num_filter, kernel_size, stride, padding, activation=activation, norm=None)
            )
            modules_up.append(
                ConvBlock(num_filter, num_filter, 3, 1, 1, activation=activation, norm=None)
                # DeconvBlock(num_filter, num_filter, 3, 1, 1, activation=activation, norm=None)
            )
        self.up1 = torch.nn.Sequential(*modules_up)
        # self.up2 = torch.nn.Sequential(*modules_up)

    def forward(self, x, res_x):
        # alternative (disabled): upsample input and fused residual separately, then add
        # x = self.cut_channel(x)
        # res = self.body(x)
        # concat_res = torch.cat((res, res_x), 1)
        # concat_res = self.cut_channel_res(concat_res)
        # x = self.up1(x)
        # res_out = self.up2(concat_res)
        # out = res_out + x
        # return out, res

        x = self.cut_channel(x)  # reduce the dense concat to num_filter channels
        res_out = self.body(x)
        concat_res = torch.cat((res_out, res_x), 1)
        res = self.cut_channel_res(concat_res)
        res = self.se(res)
        out = res + x
        out = self.up1(out)
        return out, res_out

class DownBlock(torch.nn.Module):
    def __init__(self, num_filter, kernel_size=8, stride=4, padding=2, bias=True, activation='prelu', n_resblock=3, norm=None):
        super(DownBlock, self).__init__()
        # dilated=True selects the hard-coded dilated downsampling conv inside ConvBlock
        self.down_conv1 = ConvBlock(num_filter, num_filter, kernel_size, stride, padding, activation=activation, norm=None, dilated=True)
        # self.down_conv11 = DeconvBlock(num_filter, num_filter, 3, 1, 1, activation=activation, norm=None)
        # self.down_conv2 = ConvBlock(num_filter, num_filter, kernel_size, stride, padding, activation=activation, norm=None)
        # self.down_conv22 = DeconvBlock(num_filter, num_filter, 3, 1, 1, activation=activation, norm=None)

        # modules_body = [
        #     ResBlock(num_filter, 3, 1, padding=1, bias=bias, activation=activation)
        #     for _ in range(n_resblock)
        # ]
        # self.body = torch.nn.Sequential(*modules_body)

    def forward(self, x):
        # alternative 1 (disabled): residual body plus separate down convs
        # res = self.body(x)
        # x = self.down_conv1(x)
        # res_out = self.down_conv2(res)
        # out = res_out + x
        # return out, res

        # alternative 2 (disabled): residual body before downsampling
        # res_out = self.body(x)
        # out = res_out + x
        # out = self.down_conv1(out)

        res_out = x              # the raw input doubles as this stage's residual
        out = self.down_conv1(x)
        return out, res_out

class D_DownBlock(torch.nn.Module):
    def __init__(self, num_filter, kernel_size=8, stride=4, padding=2, num_stages=1, bias=True, activation='prelu', n_resblock=3, norm=None):
        super(D_DownBlock, self).__init__()
        self.cut_channel = ConvBlock(num_filter * num_stages, num_filter, 1, 1, 0, activation=activation, norm=None)
        self.cut_channel_res = ConvBlock(num_filter * num_stages, num_filter, 1, 1, 0, activation=activation, norm=None)
        self.down_conv1 = ConvBlock(num_filter, num_filter, kernel_size, stride, padding, activation=activation, norm=None, dilated=True)
        # self.down_conv11 = DeconvBlock(num_filter, num_filter, 3, 1, 1, activation=activation, norm=None)
        # self.down_conv2 = ConvBlock(num_filter, num_filter, kernel_size, stride, padding, activation=activation, norm=None)
        # self.down_conv22 = DeconvBlock(num_filter, num_filter, 3, 1, 1, activation=activation, norm=None)

        # modules_body = []
        # for _ in range(n_resblock):
        #     modules_body.append(
        #         ResBlock(num_filter, 3, 1, 1, bias, activation)
        #     )
        # self.body = torch.nn.Sequential(*modules_body)

    def forward(self, x, res_x):
        # alternative 1 (disabled): residual body plus separate down convs
        # x = self.cut_channel(x)
        # res = self.body(x)
        # concat_res = torch.cat((res, res_x), 1)
        # concat_res = self.cut_channel_res(concat_res)
        # l1 = self.down_conv1(x)
        # res_out = self.down_conv2(concat_res)
        # out = res_out + l1
        # return out, res

        # alternative 2 (disabled): residual body before the fusion
        # x = self.cut_channel(x)
        # res_out = self.body(x)
        # concat_res = torch.cat((res_out, res_x), 1)
        # res = self.cut_channel_res(concat_res)
        # out = res + x
        # out = self.down_conv1(out)

        x = self.cut_channel(x)  # reduce the dense concat to num_filter channels
        res_out = x              # pass through as this stage's residual
        concat_res = torch.cat((res_out, res_x), 1)
        out = self.cut_channel_res(concat_res)
        out = self.down_conv1(out)
        return out, res_out
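How these blocks fit together is not shown in this file; the following is a hypothetical wiring sketch assuming resDBPN_2 follows the usual D-DBPN pattern of alternating up-/down-projection stages with dense feature concatenation. The block names and the x2 geometry (kernel 6, stride 2, padding 2) are taken from above; the stage order itself is an assumption.

import torch

num_filter = 64
k, s, p = 6, 2, 2                        # x2 projection geometry (the 'k6' runs above)

up1 = UpBlock(num_filter, k, s, p)
down1 = DownBlock(num_filter, k, s, p)
up2 = UpBlock_D(num_filter, k, s, p)
down2 = D_DownBlock(num_filter, k, s, p, num_stages=2)

x = torch.randn(1, num_filter, 24, 24)   # LR feature map
h1, res_h1 = up1(x)                      # 48x48 HR features, residual kept at 24x24
l1, res_l1 = down1(h1)                   # back to 24x24, residual kept at 48x48
h2, _ = up2(l1, res_h1)                  # second HR stage reuses the stage-1 residual
l2, _ = down2(torch.cat((h2, h1), 1), res_l1)  # dense concat, reduced by 1x1 inside
print(h2.shape, l2.shape)                # (1, 64, 48, 48) (1, 64, 24, 24)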
