ecs-vlc / FMix

Official implementation of 'FMix: Enhancing Mixed Sample Data Augmentation'

Home Page: https://arxiv.org/abs/2002.12047

Question: Is there any example notebook available with an FMix implementation?

IamSparky opened this issue

I understood the concepts here, but I am unable to figure out how I should use this GitHub repo with my PyTorch dataset, even after going through the Colab implementations.

I am using this classification dataset.

import io

import albumentations
import cv2
import numpy as np
import torch
from PIL import Image
from torch.utils.data import Dataset


class FlowerDataset(Dataset):
    def __init__(self, id, classes, image, img_height, img_width, mean, std, is_valid):
        self.id = id
        self.classes = classes
        self.image = image
        self.is_valid = is_valid
        if self.is_valid == 1:
            # validation: only resize and normalise
            self.aug = albumentations.Compose([
                albumentations.Resize(img_height, img_width, always_apply=True),
                albumentations.Normalize(mean, std, always_apply=True)
            ])
        else:
            # training: add a light shift/scale/rotate augmentation
            self.aug = albumentations.Compose([
                albumentations.Resize(img_height, img_width, always_apply=True),
                albumentations.Normalize(mean, std, always_apply=True),
                albumentations.ShiftScaleRotate(shift_limit=0.0625,
                                                scale_limit=0.1,
                                                rotate_limit=5,
                                                p=0.9)
            ])

    def __len__(self):
        return len(self.id)

    def __getitem__(self, index):
        # decode the stored bytes into an image array
        img = np.array(Image.open(io.BytesIO(self.image[index])))
        img = cv2.resize(img, dsize=(128, 128), interpolation=cv2.INTER_CUBIC)
        img = self.aug(image=img)['image']
        # HWC -> CHW for PyTorch
        img = np.transpose(img, (2, 0, 1)).astype(np.float32)

        return torch.tensor(img, dtype=torch.float), int(self.classes[index])

Please help.

Hi, thanks for the issue 😃

In order for the mixing to happen you need a batch of data, so the way to implement it here would be to generate masks and mix each batch using the sample_and_apply function in fmix.py. That will give you mixed images; you then also need to mix your loss function with the lambdas returned by sample_and_apply, as is done in this example. Those two steps should allow you to train with FMix. Pseudo-code would be something like:

from fmix import sample_and_apply  # fmix.py in this repo
import torch.nn.functional as F

alpha, decay_power = 1.0, 3.0

for epoch in range(max_epochs):
    for batch, target in train_loader:
        # mix the images; perm is the permutation used, lam the mixing coefficient
        batch, perm, lam = sample_and_apply(batch, alpha, decay_power, (128, 128))
        out = my_model(batch)
        # interpolate the loss between the original and permuted targets
        loss = F.cross_entropy(out, target) * lam + F.cross_entropy(out, target[perm]) * (1 - lam)
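
To make that concrete, here is a rough sketch of how the loop above could be wired around the FlowerDataset from the question. Everything apart from sample_and_apply and the dataset itself is a placeholder (ids, classes, images, mean, std, the ResNet, batch size, learning rate and epoch count), so swap those for your own setup:

import torch
import torch.nn.functional as F
import torchvision
from torch.utils.data import DataLoader

from fmix import sample_and_apply

# Placeholder data: ids, classes, images, mean and std come from your own pipeline
num_classes = len(set(classes))
train_dataset = FlowerDataset(ids, classes, images, 128, 128, mean, std, is_valid=0)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)

model = torchvision.models.resnet18(num_classes=num_classes)  # any classifier works
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

alpha, decay_power, max_epochs = 1.0, 3.0, 10

for epoch in range(max_epochs):
    model.train()
    for batch, target in train_loader:
        # the mask shape must match the spatial size of the images (128 x 128 here)
        batch, perm, lam = sample_and_apply(batch, alpha, decay_power, (128, 128))
        out = model(batch.float())  # cast back to float32 in case the numpy mask promoted the dtype
        loss = F.cross_entropy(out, target) * lam + F.cross_entropy(out, target[perm]) * (1 - lam)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()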

Another option would be to use the pytorch-lightning implementation, which just wraps the above in a class.
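
For illustration, a minimal sketch of what that route could look like. The import path and exact interface should be checked against implementations/lightning.py in the repo; the LightningModule, model and optimiser below are placeholders:

import pytorch_lightning as pl
import torch
import torchvision

from implementations.lightning import FMix  # check the exact module path in the repo


class FMixClassifier(pl.LightningModule):
    def __init__(self, num_classes):
        super().__init__()
        self.model = torchvision.models.resnet18(num_classes=num_classes)
        self.fmix = FMix()  # wraps the mask sampling and loss mixing

    def training_step(self, batch, batch_idx):
        x, y = batch
        x = self.fmix(x)               # mix the images in the batch
        out = self.model(x)
        return self.fmix.loss(out, y)  # mixed cross-entropy

    def configure_optimizers(self):
        return torch.optim.Adam(self.model.parameters(), lr=1e-3)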

Let me know if that helps and I can add a notebook or similar with it. Also, if you have any other suggestions for how this could be made easier then they would be much appreciated 👍

Please add the sample notebook, which will help us understand the classification problem with FMix in a better manner.