Bug in the Semantic-Map-to-Image task: I really need help because I don't have many resources to run experiments.
RoCanHet opened this issue
DucToan commented
I am retraining the Semantic-Map-to-Image task with a configuration set up to match Table 14 (page 25 of the paper) as closely as possible. Below are the config and the dataloader I use with the CelebA-HQ mask dataset. However, training does not succeed, even though I have run it for around 283k iterations. I can't tell where the problem originates, and I would greatly appreciate any help.
Config:
model:
  base_learning_rate: 4.8e-05
  target: ldm.models.diffusion.ddpm.LatentDiffusion
  params:
    linear_start: 0.0015
    linear_end: 0.0205
    log_every_t: 100
    timesteps: 1000
    loss_type: l1
    first_stage_key: image
    cond_stage_key: segmentation
    image_size: 64
    channels: 3
    concat_mode: true
    cond_stage_trainable: true
    unet_config:
      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        image_size: 64
        in_channels: 6
        out_channels: 3
        model_channels: 128
        attention_resolutions:
        - 32
        - 16
        - 8
        num_res_blocks: 2
        channel_mult:
        - 1
        - 4
        - 8
        num_heads: 8
    first_stage_config:
      target: ldm.models.autoencoder.VQModelInterface
      params:
        embed_dim: 3
        n_embed: 8192
        ddconfig:
          double_z: false
          z_channels: 3
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult:
          - 1
          - 2
          - 4
          num_res_blocks: 2
          attn_resolutions: []
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity
    cond_stage_config:
      target: ldm.modules.encoders.modules.SpatialRescaler
      params:
        n_stages: 2
        in_channels: 19
        out_channels: 3
data:
  target: main.DataModuleFromConfig
  params:
    batch_size: 48
    wrap: false
    num_workers: 10
    train:
      target: ldm.data.i2i.SegmentationBase
      params:
        data_root: '/content/512/img'
        segmentation_root: '/content/512/mask'
        size: 384
        crop_size: 256
    validation:
      target: ldm.data.i2i.SegmentationBase
      params:
        data_root: '/content/512/val/img'
        segmentation_root: '/content/512/val/mask'
        size: 384
        crop_size: 256
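To make the shape bookkeeping in this config explicit, here is a small Python sketch that re-derives it from a hand-copied subset of the values above. The `cfg` dict and its `has_ckpt_path` flag are hypothetical stand-ins, not loaded from the YAML. The checks assume the behavior of the CompVis latent-diffusion code: the VQ encoder downsamples by 2 between consecutive `ch_mult` levels, `SpatialRescaler` halves the resolution once per stage (default multiplier 0.5), `UNetModel` adds attention at a level only when its running downsample factor `ds` (1, 2, 4, ...) appears in `attention_resolutions`, and the autoencoder weights are loaded only when `ckpt_path` is given. It is a consistency check, not a diagnosis.

```python
# Sketch: re-derive the shape arithmetic implied by the LDM config above.
# The dict is hand-copied (only the fields used here); "has_ckpt_path" is a
# hypothetical flag recording that first_stage_config.params has no ckpt_path.
cfg = {
    "image_size": 64,
    "crop_size": 256,
    "unet": {"in_channels": 6, "channel_mult": [1, 4, 8],
             "attention_resolutions": [32, 16, 8]},
    "first_stage": {"z_channels": 3, "ch_mult": [1, 2, 4], "has_ckpt_path": False},
    "cond_stage": {"n_stages": 2, "out_channels": 3},
}


def check_config(cfg):
    """Return a list of findings; an empty list means nothing looked off."""
    findings = []
    # Assumed: the VQ encoder halves the resolution between ch_mult levels.
    f = 2 ** (len(cfg["first_stage"]["ch_mult"]) - 1)
    latent = cfg["crop_size"] // f
    if latent != cfg["image_size"]:
        findings.append(f"latent size {latent} != image_size {cfg['image_size']}")
    # Assumed: SpatialRescaler halves the mask resolution once per stage.
    cond = cfg["crop_size"] // 2 ** cfg["cond_stage"]["n_stages"]
    if cond != latent:
        findings.append(f"conditioning size {cond} != latent size {latent}")
    # Concat conditioning stacks latent and conditioning channels.
    expected_in = cfg["first_stage"]["z_channels"] + cfg["cond_stage"]["out_channels"]
    if cfg["unet"]["in_channels"] != expected_in:
        findings.append(f"unet in_channels should be {expected_in}")
    # Downsample factors the UNet visits: 1, 2, 4, ... (one per channel_mult entry).
    ds_levels = [2 ** i for i in range(len(cfg["unet"]["channel_mult"]))]
    unused = [r for r in cfg["unet"]["attention_resolutions"] if r not in ds_levels]
    if unused:
        findings.append(f"attention_resolutions {unused} never match ds levels {ds_levels}")
    if not cfg["first_stage"]["has_ckpt_path"]:
        findings.append("first_stage_config.params has no ckpt_path "
                        "(autoencoder weights not loaded from a checkpoint)")
    return findings


for finding in check_config(cfg):
    print(finding)
```

With these values the latent size (64), the conditioning size (64), and the UNet input channels (6) all agree. The sketch reports two findings: the listed `attention_resolutions` do not match any level of a three-level UNet, and no `ckpt_path` is set for the first stage. Whether either one explains this run still has to be verified against the repository code.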
Dataloader:
import os
import numpy as np
import cv2
import albumentations
from PIL import Image
from torch.utils.data import Dataset, DataLoader
join = os.path.join
class SegmentationBase(Dataset):
    def __init__(self, data_root, segmentation_root,
                 size=None, crop_size=None, random_crop=False, interpolation="bicubic",
                 n_labels=19
                 ):
        self.n_labels = n_labels
        self.data_root = data_root
        self.segmentation_root = segmentation_root
        self.data = os.listdir(self.data_root)
        self.segmentation = os.listdir(self.segmentation_root)
        self.data.sort()
        self._length = len(self.data)
        self.crop_size = crop_size
        size = None if size is not None and size <= 0 else size
        self.size = size
        if self.size is not None:
            self.interpolation = interpolation
            self.interpolation = {
                "nearest": cv2.INTER_NEAREST,
                "bilinear": cv2.INTER_LINEAR,
                "bicubic": cv2.INTER_CUBIC,
                "area": cv2.INTER_AREA,
                "lanczos": cv2.INTER_LANCZOS4}[self.interpolation]
            self.image_rescaler = albumentations.SmallestMaxSize(max_size=self.size,
                                                                 interpolation=self.interpolation)
            self.segmentation_rescaler = albumentations.SmallestMaxSize(max_size=self.size,
                                                                        interpolation=cv2.INTER_NEAREST)
            self.center_crop = not random_crop
            if self.center_crop:
                self.cropper = albumentations.CenterCrop(height=self.crop_size, width=self.crop_size)
            else:
                self.cropper = albumentations.RandomCrop(height=self.crop_size, width=self.crop_size)
            self.preprocessor = self.cropper
    def __len__(self):
        return self._length
    def __getitem__(self, i):
        image = Image.open(join(self.data_root, self.data[i]))
        if not image.mode == "RGB":
            image = image.convert("RGB")
        image = np.array(image).astype(np.uint8)
        if self.size is not None:
            image = self.image_rescaler(image=image)["image"]
        segmentation = Image.open(join(self.segmentation_root, self.data[i][:-3] + 'png'))
        segmentation = np.array(segmentation).astype(np.uint8)
        if self.size is not None:
            segmentation = self.segmentation_rescaler(image=segmentation)["image"]
        if self.size is not None:
            processed = self.preprocessor(image=image,
                                          mask=segmentation
                                          )
        else:
            processed = {"image": image,
                         "mask": segmentation
                         }
        processed["image"] = (processed["image"]/127.5 - 1.0).astype(np.float32)
        # processed["segmentation"] = (processed["mask"]/127.5 - 1.0).astype(np.float32)
        segmentation = processed["mask"]
        onehot = np.eye(self.n_labels)[segmentation]
        processed["segmentation"] = onehot
        return processed
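A quick way to rule out the label maps is to check them before the `np.eye(self.n_labels)[segmentation]` step. The numpy-only sketch below uses hypothetical helper names (`check_mask`, `one_hot`). It assumes each mask is a single-channel PNG whose pixel values are class indices 0 to 18, so an RGB or color-coded mask, or one that uses 255 for unlabeled pixels, fails here with a clear message rather than somewhere deeper in training. The channels-first transpose at the end is an assumption about what the model sees after `LatentDiffusion.get_input` moves the channel axis.

```python
import numpy as np


def check_mask(mask, n_labels=19):
    """Validate a label map before it is one-hot encoded with np.eye(n_labels)[mask]."""
    if mask.ndim != 2:
        raise ValueError(f"expected a 2-D label map (H, W), got shape {mask.shape}")
    lo, hi = int(mask.min()), int(mask.max())
    if lo < 0 or hi >= n_labels:
        raise ValueError(f"label values must lie in [0, {n_labels}), got [{lo}, {hi}]")
    return np.unique(mask)  # the classes present in this mask


def one_hot(mask, n_labels=19):
    """(H, W) integer labels -> (H, W, n_labels) float32, same indexing trick as the dataloader."""
    return np.eye(n_labels, dtype=np.float32)[mask]


# Toy 4x4 mask standing in for one label map.
mask = np.array([[0, 0, 1, 1],
                 [0, 2, 2, 1],
                 [3, 3, 2, 1],
                 [3, 3, 0, 0]], dtype=np.uint8)
labels = check_mask(mask)
onehot = one_hot(mask)
chw = onehot.transpose(2, 0, 1)  # channels-first (C, H, W)
print(labels)
print(onehot.shape)  # (4, 4, 19)
print(chw.shape)     # (19, 4, 4)
```

To run it on real data, load each mask with `np.array(Image.open(path))` as the dataloader does and pass the result to `check_mask`; printing the returned classes for a few files also shows whether the masks actually contain more than one label.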
And here is the diffusion row logged during training. It has looked the same from iteration 0 up to the current iteration, 236,500: