|
14 | 14 | #
|
15 | 15 | # Created on Oct 30, 2019
|
16 | 16 |
|
| 17 | +import sklearn |
17 | 18 | from sklearn.base import clone
|
18 | 19 | from sklearn.model_selection import KFold, cross_val_predict, ParameterGrid
|
19 | 20 | from sklearn.model_selection import GridSearchCV, RandomizedSearchCV
|
@@ -90,12 +91,28 @@ def _select_classifier_from_list(candidates, X, A, n_splits=5, seed=None, loss_t
|
90 | 91 | if n_splits >= 2:
|
91 | 92 | cv = KFold(n_splits=n_splits, shuffle=True, random_state=seed)
|
92 | 93 | for model_idx, m in enumerate(candidates):
|
93 |
| - if loss_type == '01': |
94 |
| - pred = cross_val_predict(m, X=X, y=A, cv=cv, params={'sample_weight': class_weight}).reshape(-1) |
| 94 | + if sklearn.__version__ >= "1.4": |
| 95 | + # TODO: at time of writing scikit-learn 1.4.0 is <1 year old. |
| 96 | + # Once matured, you may erase the deprecated `fit_params` and just use `params`. |
| 97 | + if loss_type == '01': |
| 98 | + pred = cross_val_predict( |
| 99 | + m, X=X, y=A, |
| 100 | + cv=cv, params={'sample_weight': class_weight} |
| 101 | + ).reshape(-1) |
| 102 | + else: |
| 103 | + ps = cross_val_predict( |
| 104 | + m, X=X, y=A, |
| 105 | + cv=cv, params={'sample_weight': class_weight}, |
| 106 | + method='predict_proba' |
| 107 | + ) |
| 108 | + pred = ps[:, 1] |
95 | 109 | else:
|
96 |
| - ps = cross_val_predict(m, X=X, y=A, cv=cv, params={'sample_weight': class_weight}, |
97 |
| - method='predict_proba') |
98 |
| - pred = ps[:, 1] |
| 110 | + if loss_type == '01': |
| 111 | + pred = cross_val_predict(m, X=X, y=A, cv=cv, fit_params={'sample_weight': class_weight}).reshape(-1) |
| 112 | + else: |
| 113 | + ps = cross_val_predict(m, X=X, y=A, cv=cv, fit_params={'sample_weight': class_weight}, |
| 114 | + method='predict_proba') |
| 115 | + pred = ps[:, 1] |
99 | 116 | else:
|
100 | 117 | for model_idx, m in enumerate(candidates):
|
101 | 118 | m.fit(X, A, sample_weight=class_weight)
|
|
0 commit comments