xinghaochen / TinySAM

Official PyTorch implementation of "TinySAM: Pushing the Envelope for Efficient Segment Anything Model"

Finetune TinySAM on custom dataset

Riley-livingston opened this issue

Hello, I'm trying to fine-tune the mask decoder of TinySAM on a custom dataset while freezing the weights of the image_encoder and prompt_encoder. I'm having an issue in my training loop: Sam.forward() requires a "multimask_output" argument, but MaskDecoder.forward() doesn't accept a "multimask_output" argument.

I'm not an ML engineer, so I don't know much about the underlying code. If anyone with more knowledge than me has some insight into how I can resolve this issue, I would appreciate it. Thanks!

Here is how I'm freezing the image encoder and prompt encoder to keep their original weights:

for name, param in sam_model.named_parameters():
  if name.startswith("image_encoder") or name.startswith("prompt_encoder"):
    param.requires_grad_(False)
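
A quick way to confirm that the freezing loop above leaves only the mask decoder trainable is to list the parameters that still require gradients and pass only those to the optimizer. The sketch below uses a hypothetical `StandInSam` module with the same three sub-module names; it is not the TinySAM model, and the Adam optimizer and learning rate are placeholders.

```python
import torch
from torch import nn


class StandInSam(nn.Module):
    """Hypothetical stand-in with the same three sub-module names as SAM."""

    def __init__(self):
        super().__init__()
        self.image_encoder = nn.Linear(4, 4)
        self.prompt_encoder = nn.Linear(4, 4)
        self.mask_decoder = nn.Linear(4, 1)


sam_model = StandInSam()

# Same freezing rule as in the post; str.startswith accepts a tuple of prefixes.
for name, param in sam_model.named_parameters():
    if name.startswith(("image_encoder", "prompt_encoder")):
        param.requires_grad_(False)

# Names of the parameters that will still receive gradients.
trainable = [n for n, p in sam_model.named_parameters() if p.requires_grad]
print(trainable)

# Give the optimizer only the trainable parameters (optimizer choice is a placeholder).
optimizer = torch.optim.Adam(
    [p for p in sam_model.parameters() if p.requires_grad], lr=1e-4
)
```

If any name outside `mask_decoder` shows up in `trainable`, the prefix check did not match the model's actual parameter names.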

I am also providing bounding box prompts as input. Here is my custom Dataset class:

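# Imports assumed from the names used below; the original post does not show them.
import numpy as np
import torch
from torch.utils.data import Dataset
from torchvision.transforms.functional import resize, to_tensor

class SAMDataset(Dataset):
    """
    Dataset class for the SAM model, serving images with associated bounding boxes and masks.
    """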
    def __init__(self, dataset, bbox_mapping, sam_model, device='cuda'):
        self.dataset = dataset
        self.bbox_mapping = bbox_mapping
        self.sam_model = sam_model
        self.device = device
        self.target_size = (1024, 1024)  # Adjusted to the expected input size of the model

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        # Assuming dataset[idx] returns a dict with 'image' and 'label' keys
        pil_image = self.dataset[idx]['image']
        pil_mask = self.dataset[idx]['label']

        image_tensor = to_tensor(np.array(pil_image)).to(self.device)
        mask_tensor = to_tensor(np.array(pil_mask)).to(self.device)

        # Resize image and mask to target size
        image_tensor = resize(image_tensor, self.target_size)
        mask_tensor = resize(mask_tensor, self.target_size)

        # Fetch bounding boxes directly without padding
        bboxes = self.bbox_mapping.get(idx + 1, [])  # Adjust index if necessary
        bboxes_tensor = torch.tensor(bboxes, dtype=torch.float, device=self.device)

        return {
            'image': image_tensor,
            'bboxes': bboxes_tensor,
            'mask': mask_tensor
        }
        
### Create a DataLoader instance for the training dataset
from torch.utils.data import DataLoader
train_dataloader = DataLoader(train_dataset, shuffle=True, drop_last=False)

image torch.Size([1, 3, 1024, 1024])
bboxes torch.Size([1, 1, 4])
mask torch.Size([1, 1, 1024, 1024])
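
One detail in the dataset class above may be worth checking: the mask is resized with the same call as the image. Assuming `resize` is `torchvision.transforms.functional.resize`, its default interpolation is bilinear, which blends values across mask edges. The sketch below shows the effect on a tiny binary mask using `torch.nn.functional.interpolate` as a stand-in for the resize step; the 4x4 mask is made up for illustration.

```python
import torch
import torch.nn.functional as F

# A tiny binary mask standing in for a ground-truth mask, shape (N, C, H, W).
mask = torch.tensor([[0.0, 0.0, 1.0, 1.0]] * 4).reshape(1, 1, 4, 4)

bilinear = F.interpolate(mask, size=(8, 8), mode="bilinear", align_corners=False)
nearest = F.interpolate(mask, size=(8, 8), mode="nearest")

# Bilinear blends across the edge; nearest keeps only the original labels.
bilinear_is_binary = bool(((bilinear == 0) | (bilinear == 1)).all())
nearest_is_binary = bool(((nearest == 0) | (nearest == 1)).all())
print(bilinear_is_binary, nearest_is_binary)
```

With torchvision's `resize`, the equivalent choice would be passing `interpolation=InterpolationMode.NEAREST` for the mask only.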
### Training Loop

num_epochs = 1
device = "cuda"
sam_model.to(device)
sam_model.train()

for epoch in range(num_epochs):
    epoch_losses = []
    for batch in tqdm(train_dataloader):
        # Preparing the batched_input according to the Tiny sam_model's expected input format
        batched_input = [{
            'image': batch['image'].squeeze(0).to(device),
            'bboxes': batch['bboxes'].squeeze(0).to(device)
        }]
        # forward pass
        outputs_list = sam_model(batched_input, multimask_output = True)

        # Assuming the outputs_list is a list of dictionaries, and you need to aggregate the masks for loss computation
        # Here, you'd need to adapt the code to match the structure of your outputs
        predicted_masks = torch.stack([output['pred_mask'] for output in outputs_list]).squeeze(0)
        ground_truth_masks = batch["mask"].float().squeeze(1).to(device)

        loss = seg_loss(predicted_masks, ground_truth_masks)

        # backward pass (compute gradients of parameters)
        optimizer.zero_grad()
        loss.backward()

        # optimize
        optimizer.step()
        epoch_losses.append(loss.item())

    print(f'EPOCH: {epoch}')
    print(f'Mean loss: {mean(epoch_losses)}')


Error when I don't provide the multimask_output argument:

TypeError                                 Traceback (most recent call last)
<ipython-input-108-f41ebba752d9> in <cell line: 12>()
     21 
     22         # forward pass
---> 23         outputs_list = sam_model(batched_input)
     24 
     25         # Assuming the outputs_list is a list of dictionaries, and you need to aggregate the masks for loss computation

2 frames
/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py in decorate_context(*args, **kwargs)
    113     def decorate_context(*args, **kwargs):
    114         with ctx_factory():
--> 115             return func(*args, **kwargs)
    116 
    117     return decorate_context

TypeError: Sam.forward() missing 1 required positional argument: 'multimask_output'

Error when I do provide the multimask_output argument:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-42-9d874c2eda3d> in <cell line: 12>()
     19     }]
     20         # forward pass
---> 21         outputs_list = sam_model(batched_input, multimask_output = True)
     22 
     23         # Assuming the outputs_list is a list of dictionaries, and you need to aggregate the masks for loss computation

5 frames
/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
   1518                 or _global_backward_pre_hooks or _global_backward_hooks
   1519                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1520             return forward_call(*args, **kwargs)
   1521 
   1522         try:

TypeError: MaskDecoder.forward() got an unexpected keyword argument 'multimask_output'
 

Hi, simply removing multimask_output from all of the code should work (1492efb). You can pull the latest code and try again.
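
The two errors in the question come from the same signature mismatch: the wrapper requires the flag and forwards it, while the decoder does not accept it. The sketch below reproduces both messages and the effect of removing the flag, using hypothetical stand-in classes (`OldMaskDecoder`, `OldSam`, `FixedSam`) rather than the actual TinySAM code.

```python
class OldMaskDecoder:
    """Hypothetical decoder whose forward takes no multimask_output."""

    def forward(self, image_embeddings):
        return f"masks for {image_embeddings}"


class OldSam:
    """Hypothetical wrapper that still requires and forwards the flag."""

    def __init__(self):
        self.mask_decoder = OldMaskDecoder()

    def forward(self, batched_input, multimask_output):
        return self.mask_decoder.forward(
            batched_input, multimask_output=multimask_output
        )


class FixedSam(OldSam):
    """Same wrapper with the flag removed from its own signature and call."""

    def forward(self, batched_input):
        return self.mask_decoder.forward(batched_input)


def error_message(call):
    """Run call() and return the TypeError text, or None if it succeeds."""
    try:
        call()
    except TypeError as exc:
        return str(exc)
    return None


missing = error_message(lambda: OldSam().forward("x"))
unexpected = error_message(lambda: OldSam().forward("x", multimask_output=True))
fixed = FixedSam().forward("x")
print(missing)
print(unexpected)
print(fixed)
```

After pulling the updated code, the call in the training loop would correspondingly drop the argument and become `sam_model(batched_input)`.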

You can refer to issue #9 for more details.
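
One further point for fine-tuning: the first traceback in the question passes through `decorate_context` in `torch/utils/_contextlib.py`, which suggests `Sam.forward` is wrapped in a context-manager decorator. In the original Segment Anything code that decorator is `@torch.no_grad()`; that TinySAM uses the same one is an assumption here. If so, outputs of `sam_model(...)` carry no gradients, and a training loop would need to call the sub-modules directly. The sketch below shows the difference on a hypothetical stand-in module, not the real model.

```python
import torch
from torch import nn


class StandInSam(nn.Module):
    """Hypothetical stand-in; only the decorator on forward matters here."""

    def __init__(self):
        super().__init__()
        self.image_encoder = nn.Linear(4, 4)
        self.mask_decoder = nn.Linear(4, 1)

    @torch.no_grad()
    def forward(self, x):
        return self.mask_decoder(self.image_encoder(x))


model = StandInSam()
x = torch.randn(2, 4)

# Going through the decorated forward: the output is detached from the graph.
via_forward = model(x)

# Calling the sub-modules directly: frozen encoder under no_grad, decoder with grad.
with torch.no_grad():
    embedding = model.image_encoder(x)
via_submodules = model.mask_decoder(embedding)

print(via_forward.requires_grad, via_submodules.requires_grad)

# Backpropagation reaches the decoder only.
loss = via_submodules.mean()
loss.backward()
```

Calling `loss.backward()` on a loss built from `via_forward` would instead raise an error, because that tensor has no gradient function.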