Skip to content

Commit 8c40d54

Browse files
committed
Version-conditional cross_val_predict keep support for sklearn<1.4
Keep supporting scikit-learn <1.4 by picking `fit_params` or `params` in scikit-learn's `cross_va_predict`, based on the pandas version install. Will enable to keep support for Python 3.8 (which is not supported for scikit-learn>1.4) Signed-off-by: Ehud-Karavani <[email protected]>
1 parent d29b34c commit 8c40d54

File tree

1 file changed

+22
-5
lines changed

1 file changed

+22
-5
lines changed

causallib/contrib/adversarial_balancing/classifier_selection.py

+22-5
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#
1515
# Created on Oct 30, 2019
1616

17+
import sklearn
1718
from sklearn.base import clone
1819
from sklearn.model_selection import KFold, cross_val_predict, ParameterGrid
1920
from sklearn.model_selection import GridSearchCV, RandomizedSearchCV
@@ -90,12 +91,28 @@ def _select_classifier_from_list(candidates, X, A, n_splits=5, seed=None, loss_t
9091
if n_splits >= 2:
9192
cv = KFold(n_splits=n_splits, shuffle=True, random_state=seed)
9293
for model_idx, m in enumerate(candidates):
93-
if loss_type == '01':
94-
pred = cross_val_predict(m, X=X, y=A, cv=cv, params={'sample_weight': class_weight}).reshape(-1)
94+
if sklearn.__version__ >= "1.4":
95+
# TODO: at time of writing scikit-learn 1.4.0 is <1 year old.
96+
# Once matured, you may erase the deprecated `fit_params` and just use `params`.
97+
if loss_type == '01':
98+
pred = cross_val_predict(
99+
m, X=X, y=A,
100+
cv=cv, params={'sample_weight': class_weight}
101+
).reshape(-1)
102+
else:
103+
ps = cross_val_predict(
104+
m, X=X, y=A,
105+
cv=cv, params={'sample_weight': class_weight},
106+
method='predict_proba'
107+
)
108+
pred = ps[:, 1]
95109
else:
96-
ps = cross_val_predict(m, X=X, y=A, cv=cv, params={'sample_weight': class_weight},
97-
method='predict_proba')
98-
pred = ps[:, 1]
110+
if loss_type == '01':
111+
pred = cross_val_predict(m, X=X, y=A, cv=cv, fit_params={'sample_weight': class_weight}).reshape(-1)
112+
else:
113+
ps = cross_val_predict(m, X=X, y=A, cv=cv, fit_params={'sample_weight': class_weight},
114+
method='predict_proba')
115+
pred = ps[:, 1]
99116
else:
100117
for model_idx, m in enumerate(candidates):
101118
m.fit(X, A, sample_weight=class_weight)

0 commit comments

Comments
 (0)