rasbt / deeplearning-models

A collection of various deep learning architectures, models, and tips


multi-label

jS5t3r opened this issue

Hello,

I appreciate your work. I would like to extend this code from binary labels to multiple labels.

```python
import pandas as pd
from torch.utils.data import Dataset


class CelebaDataset(Dataset):
    """Custom Dataset for loading CelebA face images"""

    def __init__(self, csv_path, img_dir, transform=None):

        df = pd.read_csv(csv_path, index_col=0)
        self.img_dir = img_dir
        self.csv_path = csv_path
        self.img_names = df.index.values
        self.y = df['Male'].values  # <-- this needs to be changed with other labels...
        self.transform = transform
```
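For the multi-label case, the flagged `df['Male'].values` line could select several attribute columns at once, giving one multi-hot vector per image. The sketch below is only an illustration: `read_multilabel_csv` and `label_cols` are hypothetical names, the attributes (`Male`, `Smiling`, `Eyeglasses`) are just example CelebA columns, and it assumes each column stores 1 for "present" and 0 or -1 for "absent".

```python
import io

import numpy as np
import pandas as pd


def read_multilabel_csv(csv_path, label_cols):
    """Return image names and a multi-hot label matrix (hypothetical helper).

    Assumes each column in `label_cols` is a binary attribute encoded as
    1 (present) and 0 or -1 (absent).
    """
    df = pd.read_csv(csv_path, index_col=0)
    img_names = df.index.values
    # 1 -> 1.0, anything else (0 or -1) -> 0.0; float32 matches the
    # target dtype typically used with torch.nn.BCEWithLogitsLoss.
    y = (df[label_cols].values == 1).astype(np.float32)
    return img_names, y


# Tiny in-memory CSV standing in for the real attribute file.
csv_text = """image,Male,Smiling,Eyeglasses
000001.jpg,0,1,0
000002.jpg,1,0,1
000003.jpg,1,1,0
"""
names, y = read_multilabel_csv(io.StringIO(csv_text), ["Male", "Smiling", "Eyeglasses"])
print(y.shape)  # (3, 3)
```

Inside `CelebaDataset.__init__`, the same idea would amount to `self.y = (df[label_cols].values == 1).astype(np.float32)`, with `label_cols` passed in as an extra (hypothetical) constructor argument. The model would then need one output per attribute, and `torch.nn.BCEWithLogitsLoss` is a common loss for such multi-hot targets.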

What should I do if I want to change it to several classes?

`Male` is a column that was in the CSV file. You could create a CSV column that contains all the classes of interest, and that should do the trick.
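One way to build such a column is to collapse several mutually exclusive binary attributes into a single integer class index. This is a minimal sketch under stated assumptions: the hair-color attributes are only an example choice of classes, `add_class_column` and `class_label` are hypothetical names, and rows where not exactly one attribute is active are dropped because they have no unambiguous class.

```python
import io

import pandas as pd

# Example choice of mutually exclusive CelebA attributes to merge into one
# multi-class target; the class index is the position in this list.
HAIR_COLS = ["Black_Hair", "Blond_Hair", "Brown_Hair", "Gray_Hair"]


def add_class_column(df, cols, new_col="class_label"):
    """Collapse one-hot attribute columns into one integer class column."""
    positive = df[cols] == 1          # works for both 0/1 and -1/1 encodings
    keep = positive.sum(axis=1) == 1  # keep rows with exactly one active attribute
    out = df.loc[keep].copy()
    # argmax over each boolean row gives the index of the single True entry
    out[new_col] = positive.loc[keep].values.argmax(axis=1)
    return out


csv_text = """image,Black_Hair,Blond_Hair,Brown_Hair,Gray_Hair
000001.jpg,-1,-1,1,-1
000002.jpg,-1,1,-1,-1
000003.jpg,1,-1,1,-1
000004.jpg,-1,-1,-1,1
"""
df = pd.read_csv(io.StringIO(csv_text), index_col=0)
labeled = add_class_column(df, HAIR_COLS)
print(labeled["class_label"].tolist())  # [2, 1, 3]
```

The result can be saved with `labeled.to_csv(...)`, after which the dataset would read `self.y = df['class_label'].values` instead of the `Male` column. With a single integer label per image, the model needs one output per class, and `torch.nn.CrossEntropyLoss` accepts such class indices as targets.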