Improve Docs for `from_proba`-esque methods for Keras users
lambda-science opened this issue
Hey,
I'm trying to use doubtlab with my Keras model that does image classification. I have a model that is already trained and a test dataset. I want to find the samples in my test dataset that are either misclassified or correctly classified but with low confidence.
Basically, I'm running this very simple snippet:
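import tensorflow as tf
from doubtlab.ensemble import DoubtEnsemble
from doubtlab.reason import WrongPredictionReason
model = tf.keras.models.load_model('data/results/SDH16K_GPU_WITHAUG/model.h5')
doubt = DoubtEnsemble(reason=WrongPredictionReason(model=model))
indices = doubt.get_indices(test_images, test_labels)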
And get the following error:
105/105 [==============================] - 10s 50ms/step
File ~/code-project/MyoQuant-SDH-Train/.venv/lib/python3.8/site-packages/doubtlab/reason.py:232, in WrongPredictionReason.from_predict(pred, y, method)
228 raise ValueError(
229 f"Cannot use method={method} when y_true values aren't binary."
230 )
231 if method == "all":
--> 232 return (pred != y).astype(np.float16)
233 if method == "fp":
234 return ((y == 0) & (pred == 1)).astype(np.float16)
AttributeError: 'bool' object has no attribute 'astype'
I've tried wrapping my Keras model as a scikit-learn classifier (using: https://www.adriangb.com/scikeras/stable/generated/scikeras.wrappers.KerasClassifier.html)
but then I get a "not fitted" error:
sciKeras = KerasClassifier(model)
doubt = DoubtEnsemble(reason = WrongPredictionReason(model=sciKeras))
File ~/code-project/MyoQuant-SDH-Train/.venv/lib/python3.8/site-packages/scikeras/wrappers.py:993, in BaseWrapper._predict_raw(self, X, **kwargs)
991 # check if fitted
992 if not self.initialized_:
--> 993 raise NotFittedError(
994 "Estimator needs to be fit before `predict` " "can be called"
995 )
996 # basic input checks
997 X, _ = self._validate_data(X=X, y=None)
NotFittedError: Estimator needs to be fit before `predict` can be called
I guess doubtlab only supports scikit-learn models for now, but I wondered whether anybody has tried something similar.
Eventually I did it by hand with:
import numpy as np

def indices_low_conf(model, test_X, test_Y, margin=0.55):
    """Return the indices of the images where the confidence of the model is lower than the margin."""
    predictions = model.predict(test_X)
    confidence = np.max(predictions, axis=1)
    indices = np.where(confidence < margin)[0]
    return indices

def indices_wrong_class_strong_conf(model, test_X, test_Y, margin=0.95):
    """Return the indices of the images where the prediction is wrong AND the confidence of the model is higher than the margin."""
    predictions = model.predict(test_X)
    confidence = np.max(predictions, axis=1)
    predicted_class = np.argmax(predictions, axis=1)
    indices = np.where((confidence > margin) & (predicted_class != test_Y))[0]
    return indices

idx_low_conf = indices_low_conf(model, test_images, test_labels)
idx_wrong_conf = indices_wrong_class_strong_conf(model, test_images, test_labels)
Have you seen this section of the docs? You can also just use your array of predictions/probas instead of relying on a scikit-learn model.
I could make this more explicit by explaining this on the README as well. But I think that would also work for you, right?
Most of the reasons in doubtlab offer a `from_predict` or `from_proba` staticmethod that you can call if you don't want to resort to scikit-learn. The API docs shed more light on this.
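For Keras users, a minimal sketch of what this could look like follows. It assumes `model.predict` returns one row of class probabilities per sample (typical for a softmax output layer), which would also explain the traceback above: comparing a 2D probability array against 1D labels does not produce an element-wise mask. The arrays below are made-up stand-ins, and the two masks are written in plain NumPy so the example runs without doubtlab. The first mask uses the same expression as the `method == "all"` branch shown in the traceback; the low-confidence threshold of 0.55 is borrowed from the hand-written helper earlier in the thread and is not confirmed as the library's default.

```python
import numpy as np

# Stand-in for `model.predict(test_images)`: class probabilities per sample.
proba = np.array([
    [0.90, 0.05, 0.05],
    [0.40, 0.35, 0.25],
    [0.02, 0.97, 0.01],
    [0.50, 0.30, 0.20],
])
# Stand-in for `test_labels`, given as integer class ids.
y_true = np.array([0, 1, 0, 0])

# Keras gives probabilities, so derive hard labels before comparing with y.
pred = np.argmax(proba, axis=1)

# Same expression as the `method == "all"` branch in the traceback.
wrong = (pred != y_true).astype(np.float16)

# Rows whose highest class probability does not exceed the threshold.
low_conf = (proba.max(axis=1) <= 0.55).astype(np.float16)

# A row is worth a second look if any of the reasons fired.
doubt_indices = np.flatnonzero((wrong + low_conf) > 0)
print(doubt_indices)
```

With doubtlab installed, `pred` and `y_true` are what you would pass to `WrongPredictionReason.from_predict(pred, y_true, method="all")`, and `proba` is what the `from_proba` methods expect; check the API docs for the exact keyword arguments of each reason.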
Meh, you're right, I'm just blind! Sorry for this issue, have a good one! :)
I'm going to keep it open, because the fact that you didn't find it suggests that it deserves to be more at the forefront of the docs.
I'll change the topic of this issue to reflect this, it's good feedback!
So there's no way of taking a Keras model and doing stuff online/stream? I have 5000+ classes (hopefully fewer if we clean them up but still) so passing the full proba is a bit cumbersome when the data gets big. I would also argue that it only seems like entropy really needs the full proba but that's not a point worth arguing if you can wrap the model completely. Is PyTorch fine?
> So there's no way of taking a Keras model and doing stuff online/stream?
Could you elaborate on what you mean by online/stream? Many of our techniques work via `from_proba` methods too.
Oh, I meant batches outside of what the model itself does. So instead of feeding all the data to the model just give it a small slice like 10_000 rows and then do the doubtlab stuff, discard the proba, and feed it the next 10_000 (or whatever small amount of data). But I am guessing you don't do that.
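The batching described here can be done by hand around any function that returns probabilities. The sketch below is one hypothetical way to do it, not something doubtlab provides: `doubt_indices_in_batches` and `predict_fn` are names invented for illustration, and the two checks (wrong prediction, low top probability) are plain NumPy stand-ins for whichever reasons you care about. Only one slice of probabilities is held at a time, and only the indices are kept.

```python
import numpy as np

def doubt_indices_in_batches(predict_fn, X, y, batch_size=10_000, max_proba=0.55):
    """Collect indices of doubtful rows while holding one batch of probabilities at a time."""
    found = []
    for start in range(0, len(X), batch_size):
        stop = start + batch_size
        proba = np.asarray(predict_fn(X[start:stop]))  # shape: (rows, n_classes)
        wrong = proba.argmax(axis=1) != y[start:stop]
        low_conf = proba.max(axis=1) <= max_proba
        # Shift batch-local positions back to positions in the full dataset.
        found.append(np.flatnonzero(wrong | low_conf) + start)
        # `proba` is rebound on the next iteration, so the previous batch can be freed.
    return np.concatenate(found) if found else np.array([], dtype=int)

# Toy data: the "model" is the identity, so X already holds the probabilities.
X = np.array([
    [0.90, 0.05, 0.05],
    [0.40, 0.35, 0.25],
    [0.02, 0.97, 0.01],
    [0.50, 0.30, 0.20],
    [0.10, 0.10, 0.80],
])
y = np.array([0, 1, 0, 0, 2])

batched = doubt_indices_in_batches(lambda batch: batch, X, y, batch_size=2)
print(batched)
```

With a real Keras model, `predict_fn` would be something like `model.predict`. With thousands of classes, the per-batch array is `batch_size x n_classes`, so the batch size is the knob that bounds memory.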