TripletLoss will produce only NaN when using mixed bfloat16 precision
kbelenky opened this issue
kbelenky commented
If you use the TripletLoss loss function with mixed bfloat16 precision:
from tensorflow.keras import mixed_precision
mixed_precision.set_global_policy('mixed_bfloat16')
as described in the TensorFlow mixed precision guide:
https://www.tensorflow.org/guide/mixed_precision#setting_the_dtype_policy
With this policy enabled, the loss value is always NaN.
The expected behavior is a non-NaN loss value.
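The snippet below is an illustrative sketch that is not part of the original report. It shows what the policy changes: under `mixed_bfloat16` a Keras layer computes in bfloat16 while keeping float32 variables, so the embeddings a model hands to TripletLoss are bfloat16 tensors. The last step applies the float32 output layer that the mixed precision guide recommends for numeric stability. Whether that avoids the NaN with TripletLoss is an untested assumption, and the layer sizes and random inputs are arbitrary.

```python
import tensorflow as tf
from tensorflow.keras import layers, mixed_precision

# Same global policy as in the report: bfloat16 compute, float32 variables.
mixed_precision.set_global_policy('mixed_bfloat16')

# A stand-in embedding layer (arbitrary sizes, for illustration only).
dense = layers.Dense(8)
embeddings = dense(tf.random.normal((4, 16)))

# Under the policy the layer computes in bfloat16 but stores float32 weights,
# so the embeddings that would reach the loss are bfloat16.
print(dense.compute_dtype, dense.variable_dtype, embeddings.dtype)

# Possible workaround (assumption, not verified against TripletLoss):
# cast the model output back to float32 before the loss, as the guide suggests.
embeddings32 = layers.Activation('linear', dtype='float32')(embeddings)
print(embeddings32.dtype)
```

If the NaN disappears once the embeddings reaching the loss are float32, that would point at a bfloat16 precision issue inside the loss computation rather than in the model itself.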