sail-sg / DiffMemorize

On Memorization in Diffusion Models

Geek Repo:Geek Repo

Github PK Tool:Github PK Tool

Visualize the similar image pairs

Yeez-lee opened this issue · comments

Hi,

Thanks for your efforts. I have a question about the visualizations. Could you provide some code showing how to visualize the similar image pairs (like Figures 1 and 10 in your paper)?

Thank you for your question. Here is some sample code for visualising the similar image pairs. Basically, you first need to generate images, and then search for the nearest image in the training dataset using the KNN distance.

import os
import numpy as np
import torch
import click
import json
import zipfile
import PIL.Image
from tqdm import tqdm
from glob import glob
try:
    import pyspng
except ImportError:
    pyspng = None

def file_ext(fname):
    """Return the lower-cased extension of ``fname``, including the leading dot.

    A name without an extension yields an empty string.
    """
    _, extension = os.path.splitext(fname)
    return extension.lower()

def load_cifar10_zip(zip_path):
    """Load every image and its label from a zipped image dataset.

    The archive must contain a ``dataset.json`` member whose ``'labels'``
    entry is a sequence of ``(file name, label)`` pairs, plus one image file
    per labelled name.

    Args:
        zip_path: Path to the zip archive.

    Returns:
        A tuple ``(images, labels)``. ``images`` is an array of shape
        ``[N, C, H, W]`` in sorted file-name order; ``labels`` is the matching
        label array, cast to ``int64`` when 1-D and ``float32`` when 2-D.

    Raises:
        KeyError: If ``dataset.json`` is missing or an image has no label.
    """
    # The context manager closes the archive even if decoding fails part-way.
    with zipfile.ZipFile(zip_path) as zip_file:
        all_names = set(zip_file.namelist())

        # PIL only fills its extension registry after init().
        PIL.Image.init()
        image_names = sorted(
            fname for fname in all_names if file_ext(fname) in PIL.Image.EXTENSION
        )

        # load labels
        with zip_file.open('dataset.json', 'r') as f:
            labels_dict = dict(json.load(f)['labels'])

        images = []
        labels = []

        # load images
        for name in tqdm(image_names):
            with zip_file.open(name, 'r') as f:
                # pyspng is an optional fast path that only handles PNG files.
                if pyspng is not None and file_ext(name) == '.png':
                    image = pyspng.load(f.read())
                else:
                    image = np.array(PIL.Image.open(f))
            if image.ndim == 2:
                image = image[:, :, np.newaxis]  # HW => HWC
            images.append(image.transpose(2, 0, 1))  # HWC => CHW
            labels.append(labels_dict[name])

    images = np.stack(images, axis=0)
    labels = np.array(labels)
    labels = labels.astype({1: np.int64, 2: np.float32}[labels.ndim])

    return images, labels

def knn(seed_image, ref_images, k=1):
    """Find the ``k`` reference images closest to ``seed_image``.

    Closeness is the Euclidean distance between flattened images, divided by
    ``sqrt(C * H * W)`` so the value does not grow with image size.

    Args:
        seed_image: Tensor of shape ``[C, H, W]``.
        ref_images: Tensor of shape ``[N, C, H, W]``.
        k: Number of neighbours to return; expected to be between 1 and ``N``.

    Returns:
        A tuple ``(nearest_distance, nearest_image)`` where
        ``nearest_distance`` has shape ``[k]`` (ascending) and
        ``nearest_image`` has shape ``[k, C, H, W]``.
    """
    # seed_image: [C, H, W]
    # ref_images: [N, C, H, W]
    C, H, W = seed_image.shape
    num_pixels = C * H * W
    # Normalise by the actual dimensionality rather than assuming 32x32x3.
    distance = torch.cdist(
        seed_image.reshape(1, num_pixels),
        ref_images.reshape(-1, num_pixels),
    ) / np.sqrt(num_pixels)

    if k == 1:
        # A single reduction yields both the value and its index.
        nearest_distance, nearest_index = distance.min(dim=1)
    else:
        values, indices = torch.topk(distance, k, dim=1, largest=False)
        # There is exactly one query row, so drop that leading axis.
        nearest_distance, nearest_index = values[0], indices[0]

    nearest_image = ref_images[nearest_index]
    return nearest_distance, nearest_image

def _read_image(image_path):
    """Load one image file as an HWC array, using pyspng when it is installed."""
    with open(image_path, 'rb') as f:
        if pyspng is not None:
            return pyspng.load(f.read())
        # pyspng is optional (see the guarded import); fall back to PIL and
        # force three channels so tiles match the RGB spacers.
        return np.array(PIL.Image.open(f).convert('RGB'))


def _save_image_grid(image_paths, save_path, num_rows, num_cols, gap=2):
    """Tile images row-major into one RGB picture with black gaps and save it."""
    grid_rows = []
    for row in range(num_rows):
        tiles = []
        for col in range(num_cols):
            image = _read_image(image_paths[row * num_cols + col])
            if col > 0:
                # Vertical spacer; the height follows the tile, not a fixed 32.
                tiles.append(np.zeros((image.shape[0], gap, 3), dtype=image.dtype))
            tiles.append(image)
        strip = np.concatenate(tiles, axis=1)
        if row > 0:
            # Horizontal spacer spanning the full strip width.
            grid_rows.append(np.zeros((gap, strip.shape[1], 3), dtype=strip.dtype))
        grid_rows.append(strip)
    grid = np.concatenate(grid_rows, axis=0)
    PIL.Image.fromarray(grid, 'RGB').save(save_path)


def plot_images(ckpt_folder, num_rows=3, num_cols=8):
    """Write side-by-side grids of generated images and their nearest neighbours.

    Reads the first ``num_rows * num_cols`` images from
    ``<ckpt_folder>/mem-tmp/<index rounded down to 1000>/<index>.png`` and from
    ``<ckpt_folder>/knn-tmp/<index>.png`` (indices zero-padded to six digits),
    and saves them as ``gen_image.png`` and ``knn_image.png`` inside
    ``ckpt_folder``.

    Args:
        ckpt_folder: Folder containing the ``mem-tmp`` and ``knn-tmp`` inputs;
            the two output grids are written here as well.
        num_rows: Number of grid rows.
        num_cols: Number of images per row.
    """
    indices = range(num_rows * num_cols)

    # Generated images are sharded into sub-folders of 1000 images each.
    gen_folder = os.path.join(ckpt_folder, "mem-tmp")
    gen_paths = [
        os.path.join(gen_folder, f'{index-index%1000:06d}', f'{index:06d}.png')
        for index in indices
    ]
    _save_image_grid(
        gen_paths, os.path.join(ckpt_folder, "gen_image.png"), num_rows, num_cols
    )

    # Nearest-neighbour images sit directly in one flat folder.
    knn_folder = os.path.join(ckpt_folder, "knn-tmp")
    knn_paths = [os.path.join(knn_folder, f'{index:06d}.png') for index in indices]
    _save_image_grid(
        knn_paths, os.path.join(ckpt_folder, "knn_image.png"), num_rows, num_cols
    )

if __name__ == "__main__":
    import sys

    # The checkpoint folder can be passed as the first command-line argument.
    # Without one, the empty-string placeholder is kept, so paths resolve
    # relative to the current working directory.
    ckpt_folder = sys.argv[1] if len(sys.argv) > 1 else ""
    plot_images(ckpt_folder)