`model.fit` makes the kernel crash when passing a `class_weight`
atn832 opened this issue
When running `model.fit` with a `class_weight`, the kernel crashes before it completes the first epoch.
Steps to reproduce:
Run the tutorial at https://www.tensorflow.org/tutorials/structured_data/imbalanced_data from the beginning of the notebook to cell 35:
```python
weighted_history = weighted_model.fit(
    train_features,
    train_labels,
    batch_size=BATCH_SIZE,
    epochs=EPOCHS,
    callbacks=[early_stopping],
    validation_data=(val_features, val_labels),
    # The class weights go here
    class_weight=class_weight)
```
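For context, the `class_weight` passed above is a dict that maps each class index to a multiplier on that class's loss. The tutorial derives it from the negative and positive example counts, scaling by `total / 2` so the sum of the weights over all examples stays equal to the number of examples. Here is a minimal sketch of that computation. The function name `balanced_class_weight` and the example counts are hypothetical and are not taken from the tutorial's dataset:

```python
def balanced_class_weight(neg: int, pos: int) -> dict:
    """Weight each class inversely to its frequency.

    Scaling by total / 2 keeps the summed weight over all examples equal
    to the number of examples, so the loss stays at a similar magnitude.
    """
    total = neg + pos
    weight_for_0 = (1 / neg) * (total / 2.0)
    weight_for_1 = (1 / pos) * (total / 2.0)
    return {0: weight_for_0, 1: weight_for_1}


# Hypothetical counts for illustration: 990 negatives, 10 positives.
class_weight = balanced_class_weight(neg=990, pos=10)
print(class_weight)
```

With these counts, each positive example weighs about 100 times as much as a negative one.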
If I comment out `class_weight=class_weight`, the training runs fine. If I leave it in, it crashes with the following training log:
```
Epoch 1/100
82/90 [==========================>...] - ETA: 0s - loss: 4.4291 - tp: 60.6707 - fp: 59.4756 - tn: 130275.5366 - fn: 165.3171 - accuracy: 0.9984 - precision: 0.5642 - recall: 0.3115 - auc: 0.7105 - prc: 0.2777
```
It crashes only on my M1 Mac mini. If I run the same notebook in Colab, it runs fine.
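One possible way to narrow this down, offered as an untested assumption rather than a confirmed workaround: Keras uses `class_weight` to weight each training example's loss by the weight of that example's class. Passing an equivalent per-example `sample_weight` array to `fit` instead could therefore show whether the crash is tied to the `class_weight` handling specifically. Here is a minimal sketch of building that array. The helper name `class_weight_to_sample_weight` is hypothetical, and the code assumes integer labels shaped `(n,)` or `(n, 1)`:

```python
import numpy as np


def class_weight_to_sample_weight(labels, class_weight):
    """Return one weight per example by looking up each label's class weight.

    Assumes `labels` holds integer class indices shaped (n,) or (n, 1).
    """
    flat = np.asarray(labels).reshape(-1).astype(int)
    return np.array([class_weight[int(label)] for label in flat], dtype=float)


# Hypothetical labels and weights, for illustration only.
demo_labels = np.array([0, 0, 1, 0])
demo_weights = class_weight_to_sample_weight(demo_labels, {0: 0.5, 1: 2.0})
print(demo_weights)
```

With the tutorial's variables, the array would be built as `sample_weight = class_weight_to_sample_weight(train_labels, class_weight)`. You would then pass it to `weighted_model.fit` as `sample_weight=sample_weight` in place of `class_weight=class_weight`.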