PyTorch conversion of the excellent post on the same topic in Tensorflow. Simply an implementation of a triple loss with online mining of candidate triplets used in semi-supervised learning.
Include the triplet_loss.py
file in your project and include either batch_hard_triplet_loss
or batch_all_triplet_loss
.
Example usage:
from triplet_loss import batch_hard_triplet_loss
labels = torch.randint(5) # our five labels
embeddings = model(labels)
loss = batch_hard_triplet_loss(labels, embeddings, margin=0.2)
loss.backward()
# and so on
pip install -r requirements.txt
python3 -m pytest test_triplet_loss.py
- Triplet Loss and Online Triplet Mining in Tensorflow
- Facenet paper
- adambielski's nice implementation (unfortunately context switches between CPU / GPU)