Visualize the similar image pairs
Yeez-lee opened this issue · comments
Yeez-lee commented
Hi,
Thanks for your efforts. I have a question about the visualizations: could you provide some code showing how to visualize the similar image pairs (like Figures 1 and 10 in your paper)?
Xiangming (Brian) Gu commented
Thank you for your question. Here is some sample code for visualising the similar image pairs. Basically, you first need to generate images, and then, for each generated image, search for its nearest image in the training dataset using a k-nearest-neighbour (KNN) search on the Euclidean distance.
import os
import numpy as np
import torch
import click
import json
import zipfile
import PIL.Image
from tqdm import tqdm
from glob import glob
try:
import pyspng
except ImportError:
pyspng = None
def file_ext(fname):
    """Return the extension of ``fname`` in lower case, leading dot included.

    A name without an extension yields the empty string.
    """
    _, extension = os.path.splitext(fname)
    return extension.lower()
def load_cifar10_zip(zip_path):
    """Load every image and its label from a zipped image dataset.

    The archive must contain a ``dataset.json`` whose ``'labels'`` entry is a
    sequence of ``(file name, label)`` pairs covering every image file in the
    archive.

    Args:
        zip_path: Path to the zip archive.

    Returns:
        A tuple ``(images, labels)``. ``images`` is one array of shape
        ``[N, C, H, W]`` with images ordered by sorted file name; ``labels``
        is the matching array, cast to ``int64`` when it is 1-D and to
        ``float32`` when it is 2-D.

    Raises:
        ValueError: If the archive contains no image files.
        KeyError: If ``dataset.json`` is missing from the archive or an image
            has no entry in the label table.
    """
    # Populate PIL.Image.EXTENSION so image files can be recognised by suffix.
    PIL.Image.init()

    images = []
    labels = []
    # The context manager closes the archive even if decoding fails part-way.
    with zipfile.ZipFile(zip_path) as zip_file:
        image_names = sorted(
            fname for fname in set(zip_file.namelist())
            if file_ext(fname) in PIL.Image.EXTENSION
        )

        # load labels
        with zip_file.open('dataset.json', 'r') as f:
            labels_dict = dict(json.load(f)['labels'])

        # load images
        for name in tqdm(image_names):
            with zip_file.open(name, 'r') as f:
                if pyspng is not None and file_ext(name) == '.png':
                    image = pyspng.load(f.read())
                else:
                    image = np.array(PIL.Image.open(f))
            if image.ndim == 2:
                image = image[:, :, np.newaxis]  # HW => HWC
            images.append(image.transpose(2, 0, 1))  # HWC => CHW
            labels.append(labels_dict[name])

    if not images:
        raise ValueError(f'no image files found in {zip_path!r}')

    images = np.stack(images, axis=0)
    labels = np.array(labels)
    # 1-D labels are class indices; 2-D labels are per-image vectors.
    labels = labels.astype({1: np.int64, 2: np.float32}[labels.ndim])
    return images, labels
def knn(seed_image, ref_images, k=1):
    """Find the ``k`` reference images closest to ``seed_image``.

    Images are flattened and compared with the Euclidean distance
    (``torch.cdist``). The distance is divided by ``sqrt(C * H * W)`` so that
    it does not grow with the image size; for 3x32x32 images this is the same
    ``sqrt(32 * 32 * 3)`` factor as before.

    Args:
        seed_image: Tensor of shape ``[C, H, W]``.
        ref_images: Tensor of shape ``[N, C, H, W]``.
        k: Number of nearest neighbours to return. If ``k`` exceeds ``N``,
            all ``N`` references are returned.

    Returns:
        A tuple ``(nearest_distance, nearest_image)`` ordered closest first:
        ``nearest_distance`` has shape ``[k]`` and ``nearest_image`` has shape
        ``[k, C, H, W]`` (taken from ``ref_images``, dtype unchanged).

    Raises:
        ValueError: If ``k`` is smaller than 1.
    """
    if k < 1:
        raise ValueError(f'k must be at least 1, got {k}')

    C, H, W = seed_image.shape
    num_values = C * H * W
    query = seed_image.reshape(1, num_values)
    candidates = ref_images.reshape(-1, num_values)
    # torch.cdist needs floating-point inputs; raw image tensors may be uint8.
    if not torch.is_floating_point(query):
        query = query.float()
    if not torch.is_floating_point(candidates):
        candidates = candidates.float()

    distance = torch.cdist(query, candidates) / np.sqrt(num_values)
    # A stable sort keeps the first of several equally-near references in
    # front, matching what ``min`` returns for k == 1.
    sorted_distance, sorted_index = torch.sort(distance, dim=1, stable=True)
    nearest_distance = sorted_distance[0, :k]
    nearest_index = sorted_index[0, :k]
    nearest_image = ref_images[nearest_index]
    return nearest_distance, nearest_image
def _read_png(image_path):
    """Decode the image file at ``image_path`` into a numpy array."""
    with open(image_path, 'rb') as f:
        if pyspng is not None:
            return pyspng.load(f.read())
        # pyspng is an optional import; fall back to PIL when it is missing.
        return np.array(PIL.Image.open(f))


def _save_image_grid(image_paths, save_path, rows, cols, gap):
    """Tile ``rows * cols`` RGB images into one grid image and save it.

    Images are placed row by row in the order given, separated by black
    strips ``gap`` pixels wide. All images must share the same size.
    """
    if rows < 1 or cols < 1:
        raise ValueError(f'rows and cols must be at least 1, got {rows}x{cols}')

    tiles = [_read_png(path) for path in image_paths]
    height, width = tiles[0].shape[:2]
    dtype = tiles[0].dtype
    col_gap = np.zeros((height, gap, 3), dtype=dtype)
    row_gap = np.zeros((gap, cols * width + gap * (cols - 1), 3), dtype=dtype)

    strips = []
    for r in range(rows):
        if r > 0:
            strips.append(row_gap)
        pieces = []
        for c in range(cols):
            if c > 0:
                pieces.append(col_gap)
            pieces.append(tiles[r * cols + c])
        strips.append(np.concatenate(pieces, axis=1))

    grid = np.concatenate(strips, axis=0)
    PIL.Image.fromarray(grid, 'RGB').save(save_path)


def plot_images(ckpt_folder, rows=3, cols=8, gap=2):
    """Save side-by-side grids of generated images and their nearest neighbours.

    The first ``rows * cols`` generated images are read from
    ``<ckpt_folder>/mem-tmp/<bucket>/<index>.png`` (``bucket`` is the index
    rounded down to a multiple of 1000) and written as
    ``<ckpt_folder>/gen_image.png``. The images with the same indices are read
    from ``<ckpt_folder>/knn-tmp/<index>.png`` and written as
    ``<ckpt_folder>/knn_image.png``. Both grids share the same layout, so the
    image at a given position in one grid pairs with the same position in the
    other.

    Args:
        ckpt_folder: Folder containing the ``mem-tmp`` and ``knn-tmp``
            directories; the two grid images are written into it.
        rows: Number of grid rows.
        cols: Number of grid columns.
        gap: Width in pixels of the black separator between images.

    Raises:
        ValueError: If ``rows`` or ``cols`` is smaller than 1.
    """
    indices = range(rows * cols)

    gen_folder = os.path.join(ckpt_folder, "mem-tmp")
    gen_paths = [
        os.path.join(gen_folder, f'{index-index%1000:06d}', f'{index:06d}.png')
        for index in indices
    ]
    _save_image_grid(
        gen_paths, os.path.join(ckpt_folder, "gen_image.png"), rows, cols, gap)

    knn_folder = os.path.join(ckpt_folder, "knn-tmp")
    knn_paths = [os.path.join(knn_folder, f'{index:06d}.png') for index in indices]
    _save_image_grid(
        knn_paths, os.path.join(ckpt_folder, "knn_image.png"), rows, cols, gap)
if __name__ == "__main__":
    # Placeholder: set this to the folder that holds the "mem-tmp" and
    # "knn-tmp" image directories before running the script.
    ckpt_folder = ""
    plot_images(ckpt_folder)