ashleve / lightning-hydra-template

PyTorch Lightning + Hydra. A very user-friendly template for ML experimentation. ⚡🔥⚡

Suggestion with Solution about loss function

Mai0313 opened this issue · comments

commented

I think it would be useful to separate the individual loss computations from their weighted summation in ./src/models/mnist_module.py.

The original code defines self.criterion = torch.nn.CrossEntropyLoss() as its only loss function and applies it as loss = self.criterion(logits, y).

However, I think there is a better way: if we change the loss function to loss_fns: list[torch.nn.Module], then using list[dict] gives the user more flexibility.
For example, in the model step, I do:

def model_step(
    self, batch: Tuple[torch.Tensor, torch.Tensor]
) -> Tuple[Dict[str, torch.Tensor], torch.Tensor, torch.Tensor]:
    """Perform a single model step on a batch of data.

    :param batch: A batch of data (a tuple) containing the input tensor of images and target labels.

    :return: A tuple containing (in order):
        - A dict mapping each loss tag (plus "total_loss") to its value.
        - A tensor of predictions.
        - A tensor of target labels.
    """
    x, y = batch
    logits = self.forward(x)
    preds = torch.argmax(logits, dim=1)
    losses = {}  # a dict of {loss_fn_name: loss_value}
    losses["total_loss"] = 0.0
    for loss_fn in self.loss_fns:
        # Pass the raw logits, as the original criterion call does;
        # the argmax'd preds are class indices and carry no gradient.
        losses[loss_fn.tag] = loss_fn(logits, y)
        losses["total_loss"] += losses[loss_fn.tag] * loss_fn.weight
    return losses, preds, y

This revised approach retains the functionality you designed while allowing more than one loss function to be used. Users simply add their custom loss functions to src/models/components/loss_fn.py, and the rest is taken care of.
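
To make the idea concrete, here is a minimal sketch of what src/models/components/loss_fn.py could contain. It is only an illustration: the names WeightedLoss and combine_losses are hypothetical and not part of the template. The one assumption is that each loss exposes the tag and weight attributes that model_step reads.

```python
from typing import Dict, List

import torch
from torch import nn


class WeightedLoss(nn.Module):
    """Hypothetical wrapper pairing a loss module with a tag and a weight."""

    def __init__(self, loss_fn: nn.Module, tag: str, weight: float = 1.0) -> None:
        super().__init__()
        self.loss_fn = loss_fn
        self.tag = tag        # key under which the loss value is reported
        self.weight = weight  # scalar applied when summing into the total

    def forward(self, logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        # Return the unweighted value; the caller applies self.weight.
        return self.loss_fn(logits, targets)


def combine_losses(
    loss_fns: List[WeightedLoss], logits: torch.Tensor, targets: torch.Tensor
) -> Dict[str, torch.Tensor]:
    """Evaluate every loss and store the weighted sum under "total_loss"."""
    losses: Dict[str, torch.Tensor] = {}
    total = torch.zeros((), device=logits.device, dtype=logits.dtype)
    for loss_fn in loss_fns:
        value = loss_fn(logits, targets)
        losses[loss_fn.tag] = value
        total = total + loss_fn.weight * value
    losses["total_loss"] = total
    return losses


# Example: plain cross-entropy plus a label-smoothed variant at half weight.
loss_fns = nn.ModuleList(
    [
        WeightedLoss(nn.CrossEntropyLoss(), tag="ce", weight=1.0),
        WeightedLoss(nn.CrossEntropyLoss(label_smoothing=0.1), tag="ce_smooth", weight=0.5),
    ]
)
logits = torch.randn(4, 10, requires_grad=True)
targets = torch.tensor([1, 0, 3, 9])
losses = combine_losses(list(loss_fns), logits, targets)
losses["total_loss"].backward()  # gradients flow through every weighted term
```

Keeping the losses in an nn.ModuleList means any parameters or buffers they hold move with the module across devices, and the per-tag entries in the dict can be logged separately from total_loss.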

commented

This is just an idea. I am a huge fan of the Hydra template, so I decided to make my first contribution.
If there are any issues, please feel free to tell me, since I'd like to help ✌🏽

If you think this is a good idea, I will fix the pytest part.