kazuto1011 / deeplab-pytorch

PyTorch re-implementation of DeepLab v2 on COCO-Stuff / PASCAL VOC datasets

custom dataset

soans1994 opened this issue · comments

hello author,

How can I build VOC-like data from a custom dataset of raw images and PNG segmentation masks?

thank you

The following explains how to create a custom dataset class that inherits from libs.datasets.base._BaseDataset. The base class leaves _set_files() and _load_data() unimplemented, so you need to implement both for your case.

    def _set_files(self):
        """
        Create a file path/image id list.
        """
        raise NotImplementedError()

    def _load_data(self, image_id):
        """
        Load the image and label in numpy.ndarray
        """
        raise NotImplementedError()

  1. Assuming each image file and its label file share a unique ID, store all the IDs in the list self.files in _set_files(). For instance, CocoStuff164k extracts the IDs from the image paths as follows.

    def _set_files(self):
        # Create data list by parsing the "images" folder
        if self.split in ["train2017", "val2017"]:
            file_list = sorted(glob(osp.join(self.root, "images", self.split, "*.jpg")))
            assert len(file_list) > 0, "{} has no image".format(
                osp.join(self.root, "images", self.split)
            )
            file_list = [f.split("/")[-1].replace(".jpg", "") for f in file_list]
            self.files = file_list
        else:
            raise ValueError("Invalid split name: {}".format(self.split))

  2. Next, implement _load_data(), which reads and returns the image and label data for the ID at the given index. Again, CocoStuff164k may help you.

    def _load_data(self, index):
        # Set paths
        image_id = self.files[index]
        image_path = osp.join(self.root, "images", self.split, image_id + ".jpg")
        label_path = osp.join(self.root, "annotations", self.split, image_id + ".png")
        # Load an image and label
        image = cv2.imread(image_path, cv2.IMREAD_COLOR).astype(np.float32)
        label = cv2.imread(label_path, cv2.IMREAD_GRAYSCALE)
        return image_id, image, label
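
For a custom dataset, the two methods could be combined roughly as in the sketch below. It is not code from the repository: the `_BaseDataset` here is a simplified stand-in so the example runs on its own, and the class name `CustomDataset`, the folder names `images`, `labels` and `splits`, and the one-ID-per-line split files are all assumptions chosen to resemble a VOC-like layout. The masks are assumed to be single-channel PNGs whose pixel values are class indices.

```python
import os.path as osp


class _BaseDataset:
    """Simplified stand-in for libs.datasets.base._BaseDataset (hypothetical).

    In the real project, import the base class from the repository instead.
    """

    def __init__(self, root, split):
        self.root = root
        self.split = split
        self.files = []
        self._set_files()

    def __len__(self):
        return len(self.files)


class CustomDataset(_BaseDataset):
    """Hypothetical dataset with a VOC-like layout (all names are assumptions):

    root/images/<id>.jpg       raw images
    root/labels/<id>.png       single-channel masks holding class indices
    root/splits/<split>.txt    one image ID per line
    """

    def _set_files(self):
        list_path = osp.join(self.root, "splits", self.split + ".txt")
        if not osp.isfile(list_path):
            raise ValueError("Invalid split name: {}".format(self.split))
        with open(list_path) as f:
            # Strip newlines and skip blank lines
            self.files = [line.strip() for line in f if line.strip()]

    def _load_data(self, index):
        # Imported here so that listing the files does not require OpenCV
        import cv2

        image_id = self.files[index]
        image_path = osp.join(self.root, "images", image_id + ".jpg")
        label_path = osp.join(self.root, "labels", image_id + ".png")
        image = cv2.imread(image_path, cv2.IMREAD_COLOR)
        label = cv2.imread(label_path, cv2.IMREAD_GRAYSCALE)
        # cv2.imread returns None instead of raising when a file is unreadable
        if image is None or label is None:
            raise FileNotFoundError("Cannot read data for ID: " + image_id)
        return image_id, image.astype("float32"), label
```

With this layout, `CustomDataset(root, "train")` would pick up the IDs listed in `root/splits/train.txt`. The explicit `None` check is a design choice: a missing mask then fails with a readable error rather than an attribute error on `None`.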

That's all. The custom dataset class can now fetch and preprocess the data through __getitem__(), which is already implemented in the superclass.

    def __getitem__(self, index):
        image_id, image, label = self._load_data(index)
        if self.augment:
            image, label = self._augmentation(image, label)
        # Mean subtraction
        image -= self.mean_bgr
        # HWC -> CHW
        image = image.transpose(2, 0, 1)
        return image_id, image.astype(np.float32), label.astype(np.int64)
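
As for getting raw images and PNG masks into a VOC-like shape in the first place, one possible approach is to write split files that list the IDs, which _set_files() can then read. The sketch below is only an illustration under assumptions: the function name `write_split_files`, the `train.txt`/`val.txt` file names, and the 80/20 default split are hypothetical, and it assumes each image `<id>.jpg` has its mask stored as `<id>.png`.

```python
import os
import random


def write_split_files(image_dir, label_dir, out_dir, val_ratio=0.2, seed=0):
    """Write train.txt and val.txt listing IDs that have both image and mask."""
    image_ids = {os.path.splitext(f)[0] for f in os.listdir(image_dir)
                 if f.endswith(".jpg")}
    label_ids = {os.path.splitext(f)[0] for f in os.listdir(label_dir)
                 if f.endswith(".png")}
    # Keep only IDs for which both files exist
    ids = sorted(image_ids & label_ids)
    if not ids:
        raise ValueError("No image/mask pairs found")

    # Shuffle with a fixed seed so the split is reproducible
    random.Random(seed).shuffle(ids)
    n_val = int(len(ids) * val_ratio)
    splits = {"val": sorted(ids[:n_val]), "train": sorted(ids[n_val:])}

    os.makedirs(out_dir, exist_ok=True)
    for name, split_ids in splits.items():
        with open(os.path.join(out_dir, name + ".txt"), "w") as f:
            # One ID per line
            f.write("".join(id_ + "\n" for id_ in split_ids))
    return splits
```

Images without a mask (and masks without an image) are silently dropped here; logging them instead may be preferable when the dataset is still being cleaned.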

@kazuto1011

thank you very much for the detailed explanation.