Suggestion with Solution about loss function
Mai0313 opened this issue
I believe it would be beneficial to separate the individual loss functions from their weighted sum in ./src/models/mnist_module.py.
The original code uses self.criterion = torch.nn.CrossEntropyLoss() as the only loss function, in loss = self.criterion(logits, y).
However, I think there is a better way: if we change the loss function to loss_fns: list[torch.nn.Module] and configure it with a list[dict], users get much more flexibility.
For example, in model_step I do:
from typing import Dict, Tuple

import torch


def model_step(
    self, batch: Tuple[torch.Tensor, torch.Tensor]
) -> Tuple[Dict[str, torch.Tensor], torch.Tensor, torch.Tensor]:
    """Perform a single model step on a batch of data.

    :param batch: A batch of data (a tuple) containing the input tensor of images and target labels.
    :return: A tuple containing (in order):
        - A dict of losses, keyed by each loss function's tag, plus "total_loss".
        - A tensor of predictions.
        - A tensor of target labels.
    """
    x, y = batch
    logits = self.forward(x)
    preds = torch.argmax(logits, dim=1)
    losses = {}  # a dict of {loss_fn_tag: loss_value}
    total_loss = logits.new_zeros(())
    # self.loss_fns should be a torch.nn.ModuleList whose entries expose `tag` and `weight`
    for loss_fn in self.loss_fns:
        # Compute each loss on the logits: argmax predictions are not differentiable
        losses[loss_fn.tag] = loss_fn(logits, y)
        total_loss = total_loss + loss_fn.weight * losses[loss_fn.tag]
    losses["total_loss"] = total_loss
    return losses, preds, y
This approach keeps the existing functionality but makes it easy to include more loss functions. Users only need to add their custom loss functions to src/models/components/loss_fn.py, and the rest is handled automatically.
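To illustrate the "add it to loss_fn.py" step, here is a sketch of what src/models/components/loss_fn.py might contain under this proposal: any nn.Module that exposes tag and weight attributes can be dropped in. CrossEntropy and BrierLoss are hypothetical example classes, and compute_losses repeats the loop from model_step on random tensors so the sketch runs on its own.

```python
import torch
import torch.nn.functional as F
from torch import nn


class CrossEntropy(nn.Module):
    """Example entry: standard cross-entropy exposing `tag` and `weight`."""

    def __init__(self, weight: float = 1.0) -> None:
        super().__init__()
        self.tag = "cross_entropy"
        self.weight = weight

    def forward(self, logits: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        return F.cross_entropy(logits, target)


class BrierLoss(nn.Module):
    """Hypothetical custom loss: squared error between softmax probabilities and one-hot targets."""

    def __init__(self, weight: float = 1.0) -> None:
        super().__init__()
        self.tag = "brier"
        self.weight = weight

    def forward(self, logits: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        probs = logits.softmax(dim=1)
        one_hot = F.one_hot(target, num_classes=logits.size(1)).to(probs.dtype)
        return ((probs - one_hot) ** 2).sum(dim=1).mean()


def compute_losses(logits: torch.Tensor, target: torch.Tensor, loss_fns: nn.ModuleList) -> dict:
    """Same loop as model_step: one entry per tag plus the weighted total."""
    losses = {}
    total = logits.new_zeros(())
    for loss_fn in loss_fns:
        losses[loss_fn.tag] = loss_fn(logits, target)
        total = total + loss_fn.weight * losses[loss_fn.tag]
    losses["total_loss"] = total
    return losses


loss_fns = nn.ModuleList([CrossEntropy(weight=1.0), BrierLoss(weight=0.5)])
logits = torch.randn(8, 10, requires_grad=True)
target = torch.randint(0, 10, (8,))
losses = compute_losses(logits, target, loss_fns)
losses["total_loss"].backward()  # works because every term is computed from the logits
```

Because every term is computed from the logits, total_loss.backward() produces gradients for the model; a loss computed from argmax predictions would not.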
This is just an idea: I am a huge fan of this Hydra template, so I decided to make my first contribution.
If there are any problems, please feel free to let me know; I'd like to help ✌🏽
If you think this is a good idea, I will also update the pytest tests.