You can input a single image and generate a new mask. The code is based on the original `tester.py`.
zizhuli opened this issue · comments
import torch.nn as nn
from torchvision.utils import save_image
from torchvision import transforms
import cv2
import PIL
from My_Detector.unet import unet
from My_Detector.utils import *
from PIL import Image
import glob
def transformer(resize, totensor, normalize, centercrop, imsize):
    """Build a torchvision preprocessing pipeline from boolean switches.

    Enabled steps are applied in this fixed order:
    center-crop to 160 px -> resize to ``imsize`` x ``imsize`` (nearest
    neighbour) -> convert to tensor -> per-channel ``(x - 0.5) / 0.5``
    normalization of 3 channels.

    Returns:
        A ``transforms.Compose`` containing only the enabled steps.
    """
    # Each entry pairs a switch with a zero-argument factory, so only the
    # enabled transforms are ever constructed (same as the original ifs).
    factories = (
        (centercrop, lambda: transforms.CenterCrop(160)),
        (resize, lambda: transforms.Resize((imsize, imsize), interpolation=PIL.Image.NEAREST)),
        (totensor, transforms.ToTensor),
        (normalize, lambda: transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))),
    )
    return transforms.Compose([make() for enabled, make in factories if enabled])
def trans_square(image):
    r"""Pad a PIL image onto a square grey canvas.

    The image is converted to RGB and pasted onto a ``max(w, h)`` x
    ``max(w, h)`` canvas filled with (127, 127, 127). It is offset along the
    shorter axis by half the size difference, rounded down.

    Returns:
        The new square RGB image.
    """
    rgb = image.convert('RGB')
    width, height = rgb.size
    side = max(width, height)
    offset = int(abs(width - height) // 2)
    canvas = Image.new('RGB', size=(side, side), color=(127, 127, 127))
    if width < height:
        # Taller than wide: shift horizontally.
        canvas.paste(rgb, (offset, 0))
    else:
        # Wider than tall (or already square): shift vertically.
        canvas.paste(rgb, (0, offset))
    return canvas
class My_Tester(object):
    """Run the face-parsing U-Net on one image and save its predicted masks.

    ``test`` writes:
      * ``./test_results/test.png`` and ``./samples/Mask.png``: the output of
        ``generate_label_plain`` (presumably a per-pixel class map; see utils);
      * ``./test_color_visualize/test.png``: the output of ``generate_label``;
      * ``./samples/Image.jpg``: the square-padded input resized to 1024x1024.

    Requires a CUDA device, since the model and input are moved with ``.cuda()``.
    """

    def __init__(self):
        # The pasted snippet named this ``init`` (markdown stripped the double
        # underscores). As a result the constructor never ran and ``self.G``
        # was never built.
        self.parallel = False
        self.test_label_path = './test_results'
        self.test_color_label_path = './test_color_visualize'
        self.build_model()

    # Backward-compatible alias for code that called ``init()`` explicitly.
    init = __init__

    def test(self):
        """Segment the first (sorted) ``./test_img/*.jpg`` image.

        Loads weights from ``./models/model.pth`` and writes the outputs
        listed in the class docstring.

        Raises:
            FileNotFoundError: if ``./test_img`` contains no ``.jpg`` file.
        """
        image_size = 512
        transform = transformer(resize=True, totensor=True, normalize=True,
                                centercrop=False, imsize=image_size)
        make_folder(self.test_label_path, '')
        make_folder(self.test_color_label_path, '')
        make_folder('./samples', '')
        self.G.load_state_dict(torch.load('./models/model.pth'))
        self.G.eval()

        # glob order is arbitrary, so sort it to make "the first image" deterministic.
        candidates = sorted(glob.glob(r'./test_img/*.jpg'))
        if not candidates:
            raise FileNotFoundError("no .jpg image found in ./test_img")
        path = candidates[0]

        # Close the file handle once the padded copy has been made.
        with Image.open(path) as raw:
            img = trans_square(raw)
        new_img = img.resize((1024, 1024), PIL.Image.BILINEAR)
        img = transform(img).unsqueeze(0).cuda()

        # Inference only: don't build an autograd graph for the forward pass.
        with torch.no_grad():
            labels_predict = self.G(img)
        labels_predict_plain = generate_label_plain(labels_predict, image_size)
        labels_predict_color = generate_label(labels_predict, image_size)

        cv2.imwrite(os.path.join(self.test_label_path, 'test.png'), labels_predict_plain[0])
        save_image(labels_predict_color[0], os.path.join(self.test_color_label_path, 'test.png'))
        new_img.save('./samples/Image.jpg')
        cv2.imwrite('./samples/Mask.png', labels_predict_plain[0])

    def build_model(self):
        """Create the U-Net on the GPU, wrapping it in DataParallel if requested."""
        self.G = unet().cuda()
        if self.parallel:
            self.G = nn.DataParallel(self.G)
# Run the tester only when executed as a script. The pasted version read
# ``name == 'main'`` (markdown stripped the underscores), which raises NameError.
if __name__ == '__main__':
    t = My_Tester()
    t.test()
Could you release this as a separate Python (.py) file?
Thanks!
@helindemeng
How did you connect this to main()?
Just create a .py file under `face_parsing` and copy the code below:
# https://github.com/switchablenorms/CelebAMask-HQ/issues/68
import os
import re
import torch.nn as nn
from torchvision.utils import save_image
from torchvision import transforms
import cv2
import PIL
from tqdm import tqdm
from unet import unet
from utils import *
from PIL import Image
import glob
def transformer(resize, totensor, normalize, centercrop, imsize):
    """Assemble the image-preprocessing ``transforms.Compose`` pipeline.

    Args:
        resize: if truthy, resize to ``(imsize, imsize)`` with nearest-neighbour
            interpolation.
        totensor: if truthy, convert the PIL image to a tensor.
        normalize: if truthy, apply ``(x - 0.5) / 0.5`` to each of 3 channels.
        centercrop: if truthy, center-crop to 160 px before anything else.
        imsize: side length used by the resize step.

    Returns:
        ``transforms.Compose`` of the enabled steps, in the order crop ->
        resize -> tensor -> normalize.
    """
    pipeline = []
    if centercrop:
        pipeline += [transforms.CenterCrop(160)]
    if resize:
        square = (imsize, imsize)
        pipeline += [transforms.Resize(square, interpolation=PIL.Image.NEAREST)]
    if totensor:
        pipeline += [transforms.ToTensor()]
    if normalize:
        half = (0.5, 0.5, 0.5)
        pipeline += [transforms.Normalize(half, half)]
    return transforms.Compose(pipeline)
def create_dir(filepath):
    """Ensure the parent directory of the file path ``filepath`` exists.

    Missing intermediate directories are created as well. A bare file name
    with no directory component refers to the current directory, so nothing
    is created in that case.
    """
    directory = os.path.dirname(filepath)
    # os.makedirs('') raises FileNotFoundError, so skip the current-dir case.
    # exist_ok also avoids the race between an exists() check and makedirs().
    if directory:
        os.makedirs(directory, exist_ok=True)
def trans_square(image):
    r"""Return ``image`` letterboxed into a square RGB canvas.

    The canvas side is the larger of the image's width and height, and the
    fill colour is grey (127, 127, 127). The image is shifted along its
    shorter axis by ``abs(w - h) // 2`` pixels.
    """
    image = image.convert('RGB')
    w, h = image.size
    margin = abs(w - h) // 2
    # Offset only along the shorter axis.
    origin = (margin, 0) if w < h else (0, margin)
    square = Image.new('RGB', size=(max(w, h),) * 2, color=(127, 127, 127))
    square.paste(image, origin)
    return square
class My_Tester(object):
    """Batch face-parsing tester: segment every image matching a glob pattern.

    For each input ``<stem>.jpg`` it writes, next to the input:
      * ``<stem>_resized.jpg``: the square-padded input resized to 1024x1024;
      * ``<stem>_mask.png``: the output of ``generate_label_plain``;
      * ``<stem>_mask_color.png``: the output of ``generate_label``.

    Requires a CUDA device, since the model and inputs are moved with ``.cuda()``.
    """

    def __init__(self):
        self.parallel = False
        self.test_label_path = './test_results'
        self.test_color_label_path = './test_color_visualize'
        self.build_model()

    def test(self, image_glob=r'xxx/*.jpg'):
        """Run the model on every image matched by ``image_glob``.

        Args:
            image_glob: glob pattern selecting the input images. The default
                is the original placeholder; point it at your image folder.
        """
        image_size = 512
        transform = transformer(True, True, True, False, image_size)
        make_folder(self.test_label_path, '')
        make_folder(self.test_color_label_path, '')
        make_folder('./samples', '')
        self.G.load_state_dict(torch.load('./models/parsenet/model.pth'))
        self.G.eval()

        # "*_resized.jpg" files written by a previous run also match "*.jpg".
        # Without this filter they would be re-processed (and re-suffixed) on
        # every rerun.
        paths = [p for p in sorted(glob.glob(image_glob))
                 if not p.endswith('_resized.jpg')]
        for path in tqdm(paths):
            with Image.open(path) as raw:
                img = trans_square(raw)
            new_img = img.resize((1024, 1024), PIL.Image.BILINEAR)
            img = transform(img).unsqueeze(0).cuda()

            # Inference only: no autograd graph needed.
            with torch.no_grad():
                labels_predict = self.G(img)
            labels_predict_plain = generate_label_plain(labels_predict, image_size)
            labels_predict_color = generate_label(labels_predict, image_size)

            # Build output names from the extension-less path. str.replace(".jpg")
            # would also rewrite ".jpg" inside directory names. It would also leave
            # ".JPG" untouched, and that file can match "*.jpg" on case-insensitive
            # filesystems, so the input would be overwritten.
            stem = os.path.splitext(path)[0]

            # original image
            new_img.save(stem + "_resized.jpg")
            # mask
            save_path_mask = stem + "_mask.png"
            create_dir(save_path_mask)
            cv2.imwrite(save_path_mask, labels_predict_plain[0])
            # mask color
            save_path_mask_color = stem + "_mask_color.png"
            create_dir(save_path_mask_color)
            save_image(labels_predict_color[0], save_path_mask_color)

    def build_model(self):
        """Create the U-Net on the GPU, wrapping it in DataParallel if requested."""
        self.G = unet().cuda()
        if self.parallel:
            self.G = nn.DataParallel(self.G)
# Script entry point: build the tester and process the configured images.
if __name__ == '__main__':
    My_Tester().test()