switchablenorms / CelebAMask-HQ

A large-scale face dataset for face parsing, recognition, generation and editing.

Geek Repo:Geek Repo

Github PK Tool:Github PK Tool

You can input a single image and generate a new mask. The code is based on the original 'tester.py' code.

zizhuli opened this issue · comments

import torch.nn as nn
from torchvision.utils import save_image
from torchvision import transforms
import cv2
import PIL
from My_Detector.unet import unet
from My_Detector.utils import *
from PIL import Image
import glob

def transformer(resize, totensor, normalize, centercrop, imsize):
    """Build the torchvision preprocessing pipeline for face-parsing input.

    Indentation was lost in the original paste; this restores runnable code.

    Args:
        resize: if truthy, resize to (imsize, imsize) with NEAREST
            interpolation (nearest keeps label values intact on masks).
        totensor: if truthy, convert the PIL image to a tensor.
        normalize: if truthy, map each channel to roughly [-1, 1].
        centercrop: if truthy, center-crop to 160x160 first.
        imsize: target side length used when `resize` is set.

    Returns:
        transforms.Compose applying the selected steps in the order:
        centercrop, resize, totensor, normalize.
    """
    options = []
    if centercrop:
        options.append(transforms.CenterCrop(160))
    if resize:
        options.append(transforms.Resize((imsize, imsize), interpolation=PIL.Image.NEAREST))
    if totensor:
        options.append(transforms.ToTensor())
    if normalize:
        options.append(transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)))
    transform = transforms.Compose(options)

    return transform

def trans_square(image):
    """Pad a PIL image to a square gray canvas, centered, without scaling.

    Indentation was lost in the original paste; this restores runnable code.
    The canvas side equals the longer edge; the shorter axis is padded
    equally on both ends with mid-gray (127, 127, 127), so the aspect
    ratio of the pasted image is preserved.
    """
    image = image.convert('RGB')
    w, h = image.size
    side = max(w, h)
    background = Image.new('RGB', size=(side, side), color=(127, 127, 127))
    # Centering margin along the shorter axis (the `int` is redundant for
    # `//` on ints but kept for parity with the original).
    length = int(abs(w - h) // 2)
    box = (length, 0) if w < h else (0, length)
    background.paste(image, box)
    return background

class My_Tester(object):
    """Single-image face-parsing tester.

    Loads the pretrained unet generator and, for the first ./test_img/*.jpg
    found, writes a plain label mask, a color-visualized mask, and a
    1024x1024 resized copy of the input.

    Fixes to the pasted original: indentation restored, `def init` renamed
    to `def __init__` (the original constructor never ran), inference
    wrapped in `torch.no_grad()`, and `os`/`torch` imported locally since
    the surrounding import block omits them (they may also arrive via
    `from My_Detector.utils import *` — unverifiable from here).
    """

    def __init__(self):
        self.parallel = False  # wrap G in nn.DataParallel when True
        self.test_label_path = './test_results'
        self.test_color_label_path = './test_color_visualize'

        self.build_model()

    def test(self):
        """Run the generator on one test image and save all outputs."""
        import os      # not guaranteed by the file's import block
        import torch   # `import torch.nn as nn` does not bind `torch`

        image_size = 512
        transform = transformer(True, True, True, False, image_size)
        make_folder(self.test_label_path, '')
        make_folder(self.test_color_label_path, '')
        make_folder('./samples', '')
        self.G.load_state_dict(torch.load('./models/model.pth'))
        self.G.eval()

        # Only the first matching image is processed, as in the original.
        path = glob.glob(r'./test_img/*.jpg')[0]
        img = Image.open(path)
        img = trans_square(img)
        new_img = img.resize((1024, 1024), PIL.Image.BILINEAR)

        img = transform(img)
        img = img.unsqueeze(0).cuda()

        # Inference only: skip building the autograd graph.
        with torch.no_grad():
            labels_predict = self.G(img)

        labels_predict_plain = generate_label_plain(labels_predict, image_size)
        labels_predict_color = generate_label(labels_predict, image_size)

        cv2.imwrite(os.path.join(self.test_label_path, 'test.png'), labels_predict_plain[0])
        save_image(labels_predict_color[0], os.path.join(self.test_color_label_path, 'test.png'))

        new_img.save('./samples/Image.jpg')
        cv2.imwrite('./samples/Mask.png', labels_predict_plain[0])

    def build_model(self):
        """Instantiate the unet generator on the GPU (optionally parallel)."""
        self.G = unet().cuda()
        if self.parallel:
            self.G = nn.DataParallel(self.G)

# Script entry point. The pasted original had `if name == 'main'`, which
# raises NameError; the dunder names were lost in the issue formatting.
if __name__ == '__main__':
    t = My_Tester()
    t.test()

Can you release the code as a separate .py file?
Thanks!
@helindemeng

how did you connect this to main()

just create a .py file under "face_parsing", and copy the code below:

# https://github.com/switchablenorms/CelebAMask-HQ/issues/68

import glob
import os
import re

import cv2
import PIL
import torch
import torch.nn as nn
from PIL import Image
from torchvision import transforms
from torchvision.utils import save_image
from tqdm import tqdm

from unet import unet
from utils import *

def transformer(resize, totensor, normalize, centercrop, imsize):
    """Compose the requested preprocessing steps into one transform.

    Steps are applied in a fixed order — centercrop, resize, totensor,
    normalize — each included only when its flag is truthy. NEAREST
    interpolation is used for resizing so label values survive intact
    when the pipeline is applied to masks.
    """
    candidates = [
        (centercrop, lambda: transforms.CenterCrop(160)),
        (resize, lambda: transforms.Resize((imsize, imsize), interpolation=PIL.Image.NEAREST)),
        (totensor, transforms.ToTensor),
        (normalize, lambda: transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))),
    ]
    return transforms.Compose([build() for enabled, build in candidates if enabled])

def create_dir(filepath):
    """Ensure the parent directory of *filepath* exists.

    Fixes two defects in the original exists-then-makedirs version:
    a TOCTOU race (another process creating the directory between the
    check and the call raised FileExistsError) and a crash on bare
    filenames (dirname returns '' and makedirs('') raises). With
    exist_ok=True the call is idempotent and race-free.
    """
    directory = os.path.dirname(filepath)
    if directory:  # '' means the file sits in the CWD — nothing to create
        os.makedirs(directory, exist_ok=True)

def trans_square(image):
    """Center *image* on a square mid-gray canvas sized to its longer edge.

    The image is never scaled — the shorter axis is padded equally on
    both sides with (127, 127, 127), preserving the aspect ratio.
    """
    rgb = image.convert('RGB')
    w, h = rgb.size
    side = max(w, h)
    canvas = Image.new('RGB', size=(side, side), color=(127, 127, 127))
    margin = int(abs(w - h) // 2)
    paste_at = (margin, 0) if w < h else (0, margin)
    canvas.paste(rgb, paste_at)
    return canvas

class My_Tester(object):
    """Batch face-parsing tester.

    Runs the pretrained unet generator over every .jpg in a folder and
    writes three files next to each input: a 1024x1024 resized copy
    (*_resized.jpg), a plain label mask (*_mask.png) and a color
    visualization (*_mask_color.png).
    """

    def __init__(self):
        self.parallel = False  # wrap G in nn.DataParallel when True
        self.test_label_path = './test_results'
        self.test_color_label_path = './test_color_visualize'
        self.build_model()

    def test(self):
        """Run inference over the image folder and save all outputs."""
        image_size = 512
        transform = transformer(True, True, True, False, image_size)
        make_folder(self.test_label_path, '')
        make_folder(self.test_color_label_path, '')
        make_folder('./samples', '')
        self.G.load_state_dict(torch.load('./models/parsenet/model.pth'))
        self.G.eval()

        paths = glob.glob(r'xxx/*.jpg')  # your image folders
        for path in tqdm(paths):
            img = Image.open(path)
            img = trans_square(img)
            new_img = img.resize((1024, 1024), PIL.Image.BILINEAR)
            img = transform(img)
            img = img.unsqueeze(0).cuda()

            # Inference only: no_grad skips building the autograd graph,
            # cutting GPU memory use (the original accumulated it per image).
            with torch.no_grad():
                labels_predict = self.G(img)
            labels_predict_plain = generate_label_plain(labels_predict, image_size)
            labels_predict_color = generate_label(labels_predict, image_size)

            # original image
            save_path_image = path.replace(".jpg", "_resized.jpg")
            new_img.save(save_path_image)

            # mask
            save_path_mask = path.replace(".jpg", "_mask.png")
            create_dir(save_path_mask)
            cv2.imwrite(save_path_mask, labels_predict_plain[0])

            # mask color
            save_path_mask_color = path.replace(".jpg", "_mask_color.png")
            create_dir(save_path_mask_color)
            save_image(labels_predict_color[0], save_path_mask_color)

    def build_model(self):
        """Instantiate the unet generator on the GPU (optionally parallel)."""
        self.G = unet().cuda()
        if self.parallel:
            self.G = nn.DataParallel(self.G)


if __name__ == '__main__':
    # Entry point: build the tester and run it immediately.
    My_Tester().test()