CompVis / latent-diffusion

High-Resolution Image Synthesis with Latent Diffusion Models

Geek Repo:Geek Repo

Github PK Tool:Github PK Tool

Bug in the Semantic Map-to-Image task. I really need assistance because I don't have many resources to run experiments.

RoCanHet opened this issue · comments

I am retraining the Semantic-Map-to-Image task with a configuration tuned similarly to Table 14 on page 25 of the paper. Below is the configuration and dataloader file with the CelebA-HQ mask dataset. However, I am unable to train successfully, even though I have run it up to around 283k iterations. I am unsure where the error is originating from, and I greatly appreciate the assistance from everyone.
Config:

# Latent-diffusion training config for semantic-map-to-image synthesis.
# `model` configures the diffusion model and its two sub-models; `data`
# configures the train/validation datasets.
model:
  base_learning_rate: 4.8e-05
  target: ldm.models.diffusion.ddpm.LatentDiffusion
  params:
    linear_start: 0.0015
    linear_end: 0.0205
    log_every_t: 100
    timesteps: 1000
    loss_type: l1
    # These two keys match the dict entries produced by the dataloader's
    # __getitem__ ("image" and "segmentation").
    first_stage_key: image
    cond_stage_key: segmentation
    # NOTE(review): presumably the latent resolution, i.e. the 256 px crop
    # divided by the first-stage downsampling factor -- confirm against the
    # autoencoder checkpoint actually used.
    image_size: 64
    channels: 3
    concat_mode: true
    cond_stage_trainable: true
    unet_config:
      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        image_size: 64
        # NOTE(review): 6 looks like `channels` (3) plus the conditioning
        # encoder's `out_channels` (3), concatenated because concat_mode is
        # true -- confirm.
        in_channels: 6
        out_channels: 3
        model_channels: 128
        attention_resolutions:
        - 32
        - 16
        - 8
        num_res_blocks: 2
        channel_mult:
        - 1
        - 4
        - 8
        num_heads: 8
    first_stage_config:
      target: ldm.models.autoencoder.VQModelInterface
      params:
        embed_dim: 3
        n_embed: 8192
        ddconfig:
          double_z: false
          z_channels: 3
          # Same value as the datasets' crop_size below.
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult:
          - 1
          - 2
          - 4
          num_res_blocks: 2
          attn_resolutions: []
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity
    cond_stage_config:
      target: ldm.modules.encoders.modules.SpatialRescaler
      params:
        # NOTE(review): presumably each stage halves the spatial size
        # (256 -> 64 for two stages) -- confirm against SpatialRescaler.
        n_stages: 2
        # Equals the dataloader's default n_labels (one-hot depth of the mask).
        in_channels: 19
        out_channels: 3
data:
  target: main.DataModuleFromConfig
  params:
    batch_size: 48
    wrap: false
    num_workers: 10
    train:
      target: ldm.data.i2i.SegmentationBase
      params:
        data_root: '/content/512/img'
        segmentation_root: '/content/512/mask'
        # Shorter side is rescaled to `size`, then a crop_size x crop_size
        # crop is taken (see SegmentationBase).
        size: 384
        crop_size: 256
    validation:
      target: ldm.data.i2i.SegmentationBase
      params:
        data_root: '/content/512/val/img'
        segmentation_root: '/content/512/val/mask'
        size: 384
        crop_size: 256

Dataloader:

import os
import numpy as np
import cv2
import albumentations
from PIL import Image
from torch.utils.data import Dataset, DataLoader
join = os.path.join

class SegmentationBase(Dataset):
    """Dataset of images paired with per-pixel label masks.

    Every file in ``data_root`` is treated as an image; its mask is looked up
    in ``segmentation_root`` under the same base name with a ``.png``
    extension. Each item is a dict with:

    * ``"image"``: float32 array scaled to ``[-1, 1]``,
    * ``"mask"``: uint8 label array (after optional rescale and crop),
    * ``"segmentation"``: one-hot encoding of ``"mask"`` with ``n_labels``
      entries on a new trailing axis.

    Args:
        data_root: Directory containing the images.
        segmentation_root: Directory containing the ``.png`` masks.
        size: If given and positive, the shorter side of image and mask is
            rescaled to this value before cropping. ``None`` or a
            non-positive value disables rescaling and cropping.
        crop_size: Side length of the square crop applied after rescaling.
            Falls back to ``size`` when omitted.
        random_crop: Use a random crop instead of a center crop.
        interpolation: Name of the interpolation used for the image rescale
            (``"nearest"``, ``"bilinear"``, ``"bicubic"``, ``"area"`` or
            ``"lanczos"``). Masks are always rescaled with nearest-neighbour.
        n_labels: Number of classes, i.e. the depth of the one-hot encoding.

    Raises:
        KeyError: If ``interpolation`` is not one of the supported names
            (only checked when ``size`` is set).
    """

    def __init__(self, data_root, segmentation_root,
                 size=None, crop_size=None, random_crop=False, interpolation="bicubic",
                 n_labels=19
                 ):
        self.n_labels = n_labels
        self.data_root = data_root
        self.segmentation_root = segmentation_root

        # Sort both listings so their order is deterministic and aligned;
        # os.listdir returns entries in arbitrary order.
        self.data = sorted(os.listdir(self.data_root))
        self.segmentation = sorted(os.listdir(self.segmentation_root))
        self._length = len(self.data)

        # A non-positive size means "no rescaling / cropping".
        size = None if size is not None and size <= 0 else size
        self.size = size
        # Without an explicit crop size, crop to a size x size square rather
        # than handing None to the cropper.
        if crop_size is None and size is not None:
            crop_size = size
        self.crop_size = crop_size

        if self.size is not None:
            self.interpolation = interpolation
            self.interpolation = {
                "nearest": cv2.INTER_NEAREST,
                "bilinear": cv2.INTER_LINEAR,
                "bicubic": cv2.INTER_CUBIC,
                "area": cv2.INTER_AREA,
                "lanczos": cv2.INTER_LANCZOS4}[self.interpolation]
            self.image_rescaler = albumentations.SmallestMaxSize(max_size=self.size,
                                                                 interpolation=self.interpolation)
            # Nearest-neighbour keeps mask values as valid class ids.
            self.segmentation_rescaler = albumentations.SmallestMaxSize(max_size=self.size,
                                                                        interpolation=cv2.INTER_NEAREST)
            self.center_crop = not random_crop
            if self.center_crop:
                self.cropper = albumentations.CenterCrop(height=self.crop_size, width=self.crop_size)
            else:
                self.cropper = albumentations.RandomCrop(height=self.crop_size, width=self.crop_size)
            self.preprocessor = self.cropper

    def __len__(self):
        """Return the number of images found in ``data_root``."""
        return self._length

    def _mask_path(self, image_name):
        """Return the mask path for ``image_name`` (same stem, ``.png``)."""
        # splitext handles extensions of any length (".jpg", ".jpeg", ...),
        # unlike slicing off a fixed three characters.
        stem, _ = os.path.splitext(image_name)
        return os.path.join(self.segmentation_root, stem + ".png")

    def _one_hot(self, labels):
        """One-hot encode an integer label array along a new last axis.

        Raises:
            IndexError: If a label is not smaller than ``n_labels``.
        """
        if labels.size:
            highest = int(labels.max())
            if highest >= self.n_labels:
                raise IndexError(
                    "segmentation mask contains label {} but n_labels is {}".format(
                        highest, self.n_labels))
        return np.eye(self.n_labels)[labels]

    def __getitem__(self, i):
        """Load, preprocess and return the ``i``-th image/mask pair as a dict."""
        image_name = self.data[i]

        # Context managers release the file handles as soon as the pixel
        # data has been copied into numpy arrays.
        with Image.open(os.path.join(self.data_root, image_name)) as image_file:
            if not image_file.mode == "RGB":
                image_file = image_file.convert("RGB")
            image = np.array(image_file).astype(np.uint8)
        with Image.open(self._mask_path(image_name)) as mask_file:
            segmentation = np.array(mask_file).astype(np.uint8)

        if self.size is not None:
            image = self.image_rescaler(image=image)["image"]
            segmentation = self.segmentation_rescaler(image=segmentation)["image"]
            # Image and mask go through the cropper in a single call so they
            # receive the same crop.
            processed = self.preprocessor(image=image,
                                          mask=segmentation
                                          )
        else:
            processed = {"image": image,
                         "mask": segmentation
                         }

        # Map uint8 pixel values from [0, 255] to [-1, 1].
        processed["image"] = (processed["image"]/127.5 - 1.0).astype(np.float32)
        processed["segmentation"] = self._one_hot(processed["mask"])

        return processed

And here is the result of the Diffusion Row. It remains consistent from iteration 0 up to the current iteration, which is the 236,500th iteration:
image