What exactly does the LogisticRegression.predict_proba function return?

In my example I get a result like this:

array([
    [4.65761066e-03, 9.95342389e-01],
    [9.75851270e-01, 2.41487300e-02],
    [9.99983374e-01, 1.66258341e-05]
])

From other calculations using the sigmoid function, I know that the second column contains probabilities. The documentation says that the first column is n_samples, but that can't be right, because my samples are reviews, which are text and not numbers. The documentation also says that the second column is n_classes. That can't be right either, since I only have two classes (namely +1 and -1), and the function is supposed to calculate the probability that a sample belongs to a class, not the classes themselves.

What is the first column really, and why is it there?

2 Answers

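Each row sums to 1 (up to rounding), because each row is a probability distribution over the classes:

4.65761066e-03 + 9.95342389e-01 = 1
9.75851270e-01 + 2.41487300e-02 = 1
9.99983374e-01 + 1.66258341e-05 = 1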

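The first column is the probability that the entry has the -1 label and the second column is the probability that the entry has the +1 label. The columns follow the order of the labels in the fitted model's classes_ attribute. The shape (n_samples, n_classes) in the documentation describes the dimensions of the returned array (one row per sample, one column per class), not the contents of the columns.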

If you want the predicted probabilities for the positive label only, use logistic_model.predict_proba(data)[:, 1]. For the example above, this yields [9.95342389e-01, 2.41487300e-02, 1.66258341e-05].
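As a rough illustration of the column ordering, here is a minimal sketch on made-up toy data (the arrays X and y and the name model are assumptions for the example, not taken from the question). It checks that the columns line up with classes_, that each row sums to 1, and that, for a binary model with default settings, the positive-class column matches the sigmoid of decision_function, which is presumably the calculation the question refers to.

```python
import numpy as np
from sklearn.linear_model import LogisticRegression

# Toy data (hypothetical): one feature, labels -1 and +1.
X = np.array([[-2.0], [-1.5], [-1.0], [1.0], [1.5], [2.0]])
y = np.array([-1, -1, -1, 1, 1, 1])

model = LogisticRegression().fit(X, y)
proba = model.predict_proba(X)          # shape (n_samples, n_classes)

# Column i corresponds to model.classes_[i].
print(model.classes_)                   # sorted labels: [-1  1]
print(proba.shape)                      # (6, 2)

# Every row sums to 1 (up to floating-point error).
assert np.allclose(proba.sum(axis=1), 1.0)

# Look up the positive-class column by label rather than hard-coding 1.
pos_col = list(model.classes_).index(1)
p_pos = proba[:, pos_col]

# Binary case, default settings: this column is the sigmoid of the
# decision function (the linear score w.x + b).
sigmoid = 1.0 / (1.0 + np.exp(-model.decision_function(X)))
assert np.allclose(p_pos, sigmoid)
```

Looking up the column through classes_ instead of assuming index 1 keeps the code correct if the labels are something other than -1 and +1.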

As iulian explained, each row of predict_proba()'s result contains the probabilities that the observation in that row belongs to each class (with the classes ordered as they are in lr.classes_).

It is also closely tied to predict(): for each row, predict() returns the class with the highest probability. So for a fitted LogisticRegression (and for most classifiers that implement predict_proba), the following evaluates to True.

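from sklearn.linear_model import LogisticRegression
# X, y: your feature matrix and labels (not defined here)
lr = LogisticRegression().fit(X, y)
highest_probability_classes = lr.predict_proba(X).argmax(axis=1)  # column index of the most probable class per row
all(lr.predict(X) == lr.classes_[highest_probability_classes])     # True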
