atriumlts / subpixel

subpixel: A subpixel convnet for super resolution with Tensorflow

How to use the trained model to generate super-res images

neerajBaji opened this issue

Probably a noob question, but I have trained the model on the celebA dataset. How should I use this model to generate super-res versions of arbitrary images?

In main.py, if the training mode is False, the model is simply loaded from checkpoints. How do I run the inference stage / forward pass of the loaded model?

Will add some code for this soon, but basically you have to load the DCGAN instance, use get_image to prepare the image data, and then run it through the model in a session.
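Roughly, an untested sketch of those steps might look like the following. The checkpoint path, image list, and sizes are placeholders, and the placeholder/tensor names (`dcgan.inputs`, `dcgan.images`, `dcgan.G`) are assumptions taken from the evaluation attempt posted further down in this thread.

```python
# Minimal sketch (untested): load a trained DCGAN, prepare a batch with
# get_image, and run the generator in a session. Paths below are placeholders.
import numpy as np
import tensorflow as tf
from scipy.misc import imresize

from model import DCGAN
from utils import get_image

checkpoint_dir = "./checkpoint"         # wherever your checkpoints were written
image_paths = ["./data/test_face.jpg"]  # hypothetical test image(s)
batch_size, image_size, input_size = 64, 128, 32

with tf.Session() as sess:
    dcgan = DCGAN(sess, image_size=image_size, batch_size=batch_size,
                  dataset_name='celebA', is_crop=False,
                  checkpoint_dir=checkpoint_dir)
    if not dcgan.load(checkpoint_dir):
        raise RuntimeError("could not load a checkpoint from " + checkpoint_dir)

    # get_image returns images scaled to [-1, 1]; the model expects a full
    # batch, so pad with zeros if you have fewer images than batch_size.
    hi_res = np.zeros((batch_size, image_size, image_size, 3), dtype=np.float32)
    for i, path in enumerate(image_paths):
        hi_res[i] = get_image(path, image_size, False)

    # Downscale to the network's low-res input size, as during training.
    lo_res = np.array([imresize(((x + 1.) * 127.5).astype(np.uint8),
                                (input_size, input_size, 3))
                       for x in hi_res]).astype(np.float32)

    # Forward pass: dcgan.G holds the super-resolved batch.
    out = sess.run(dcgan.G, feed_dict={dcgan.inputs: lo_res,
                                       dcgan.images: hi_res})
```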

Thanks, I will try that out.

Please correct me if I am wrong, but I am assuming the tensor to be run in the inference stage is dcgan.G.

But this is set up to accept a batch of images, so to keep things simple for now, should I feed in a batch of images during inference as well?

So I got this to work by feeding in a batch of test images, but the results are underwhelming. I have trained the network for 25 epochs while specifying celebA as the dataset. Are there any other hyperparameters to be tuned? I am just trying to get a sense of the network on well-known datasets before using it on more custom data.

@neerajBaji Where did you set up the batch, and how?

If input_size=32, should each image be cropped to the input_size first, and then the output patches combined together?

@goldsmith Could you provide a code sample for inference?

@goldsmith and @neerajBaji I would also appreciate sample code for simple inference with a single input image. I have the network trained, but I'm a bit confused about how to use the network for inference.

I hope the code for inference with a single image can be provided as well. Thanks.

Hi, many thanks for publishing this work. I see that many people have problems using the model to create a super-res image (evaluation), so I am attaching my attempt at solving this.
I have to say, the results are very poor; in fact, when running on the test or validation set, the results are worse than the original image, so the code might not be correct. Anyway, here it is:

```python
import os
import cv2  # only needed for the commented-out preview code below
import numpy as np
from model import DCGAN
from utils import get_image, image_save, save_images
import tensorflow as tf
from scipy.misc import imresize

flags = tf.app.flags
flags.DEFINE_integer("batch_size", 64, "The size of batch images [64]")
flags.DEFINE_integer("image_size", 128, "The size of image to use")
flags.DEFINE_string("checkpoint_dir", "/home/omer/work/sub_pixel/models", "Directory name to save the checkpoints [checkpoint]")
flags.DEFINE_string("test_image_dir", "/home/omer/work/facenet/data/amp_gt/Yair", "Directory name of the images to evaluate")
flags.DEFINE_string("out_dir", "/home/omer/work/sub_pixel/out", "Directory to write the super-resolved images to")

FLAGS = flags.FLAGS


def doresize(x, shape):
    # Map the [-1, 1] float image back to uint8 and resize it to the
    # network's low-res input shape.
    x = np.copy((x + 1.) * 127.5).astype("uint8")
    y = imresize(x, shape)
    return y


def main():
    with tf.Session() as sess:
        dcgan = DCGAN(sess, image_size=FLAGS.image_size, batch_size=FLAGS.batch_size,
                      dataset_name='celebA', is_crop=False, checkpoint_dir=FLAGS.checkpoint_dir)
        res = dcgan.load(FLAGS.checkpoint_dir)
        if not res:
            print("failed loading model from path: " + FLAGS.checkpoint_dir)
            return

        i = 0
        files = []
        input_images = np.zeros(shape=(FLAGS.batch_size, 128, 128, 3))
        for f in os.listdir(FLAGS.test_image_dir):

            img_path = os.path.join(FLAGS.test_image_dir, f)
            if os.path.isdir(img_path):
                continue
            img = get_image(img_path, FLAGS.image_size, False)
            files.append(f)
            input_images[i] = img

            if i == FLAGS.batch_size - 1:
                # A full batch has been collected; run it through the network.
                # Note: leftover images that do not fill a final batch are not
                # processed.
                batch_ready(dcgan, input_images, sess, files)

                i = 0
                input_images = np.zeros(shape=(FLAGS.batch_size, 128, 128, 3))
                files = []
                print('done batch')
            else:
                i += 1


def batch_ready(dcgan, input_images, sess, files):
    # Downscale the 128x128 images to 32x32 low-res inputs and run the
    # generator on the batch.
    input_resized = [doresize(xx, (32, 32, 3)) for xx in input_images]
    sample_input_resized = np.array(input_resized).astype(np.float32)
    sample_input_images = np.array(input_images).astype(np.float32)
    output_images = sess.run(fetches=[dcgan.G],
                             feed_dict={dcgan.inputs: sample_input_resized,
                                        dcgan.images: sample_input_images})
    save_results(output_images, files)


def save_results(output_images, files):
    for k in range(0, FLAGS.batch_size):
        out_path = os.path.join(FLAGS.out_dir, files[k] + '_.png')
        out_img = output_images[0][k]

        # out_correct = ((out_img + 1) * 127.5).astype(np.uint8)
        # out_correct = cv2.cvtColor(out_correct, cv2.COLOR_RGB2BGR)
        # cv2.imshow('image', out_correct)
        # cv2.waitKey(0)

        image_save(out_img, out_path)


if __name__ == '__main__':
    main()
```
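If the script above is saved as, say, evaluate.py next to main.py (the file name is just an example), it could be run with something like:

```
python evaluate.py --checkpoint_dir=/path/to/checkpoints --test_image_dir=/path/to/test/images --out_dir=/path/to/output
```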