
weka.classifiers.sklearn.ScikitLearnClassifier Maven / Gradle / Ivy
Go to download
Integration with CPython for Weka. Python version 2.7.x or higher is required. Also requires the following packages to be installed in python: numpy, pandas, matplotlib and scikit-learn. This package provides a wrapper classifier and clusterer that, between them, cover 60+ scikit-learn algorithms. It also provides a general scripting step for the Knowledge Flow along with scripting plugin environments for the Explorer and Knowledge Flow.
The newest version!
/*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
/*
* ScikitLearnClassifier.java
* Copyright (C) 2015 University of Waikato, Hamilton, New Zealand
*
*/
package weka.classifiers.sklearn;
import java.util.List;
import weka.classifiers.AbstractClassifier;
import weka.classifiers.rules.ZeroR;
import weka.core.Attribute;
import weka.core.AttributeStats;
import weka.core.BatchPredictor;
import weka.core.Capabilities;
import weka.core.CapabilitiesHandler;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.OptionMetadata;
import weka.core.SelectedTag;
import weka.core.Tag;
import weka.core.Utils;
import weka.core.WekaException;
import weka.filters.Filter;
import weka.filters.unsupervised.attribute.Remove;
import weka.filters.unsupervised.attribute.ReplaceMissingValues;
import weka.python.PythonSession;
/**
* Wrapper classifier for classifiers and regressors implemented in the
* scikit-learn Python package.
*
* @author Mark Hall (mhall{[at]}pentaho{[dot]}com)
* @version $Revision: $
*/
public class ScikitLearnClassifier extends AbstractClassifier
implements BatchPredictor, CapabilitiesHandler {
/** Variable name identifying the training data frame in the python session */
protected static final String TRAINING_DATA_ID = "scikit_classifier_training";
/** Variable name identifying the test data frame in the python session */
protected static final String TEST_DATA_ID = "scikit_classifier_test";
/** Base variable name for the learned model in the python session */
protected static final String MODEL_ID = "weka_scikit_learner";
/** For serialization */
private static final long serialVersionUID = -6212485658537766441L;
/** Holds info on the different learners available */
public static enum Learner {
// Constructor argument order for each constant:
// (module, isClassifier, isRegressor, producesProbabilities, removeModel,
// defaultParameters)
DecisionTreeClassifier("tree", true, false, true, false,
"\tclass_weight=None, criterion='gini', max_depth=None,\n"
+ "\tmax_features=None, max_leaf_nodes=None, min_samples_leaf=1,\n"
+ "\tmin_samples_split=2, min_weight_fraction_leaf=0.0,\n"
+ "\trandom_state=None, splitter='best'"),
DecisionTreeRegressor("tree", false, true, false, false,
"\tcriterion='mse', max_depth=None, max_features=None,\n"
+ "\tmax_leaf_nodes=None, min_samples_leaf=1, min_samples_split=2,\n"
+ "\tmin_weight_fraction_leaf=0.0, random_state=None,\n"
+ "\tsplitter='best'"),
GaussianNB("naive_bayes", true, false, true, false, ""),
MultinomialNB("naive_bayes", true, false, true, false,
"alpha=1.0, class_prior=None, fit_prior=True"),
BernoulliNB("naive_bayes", true, false, true, false,
"alpha=1.0, binarize=0.0, class_prior=None, fit_prior=True"),
LDA("lda", true, false, true, false,
"\tn_components=None, priors=None, shrinkage=None, solver='svd',\n"
+ "\tstore_covariance=False, tol=0.0001"),
QDA("qda", true, false, true, false, "\tpriors=None, reg_param=0.0"),
LogisticRegression("linear_model", true, false, true, false,
"\tC=1.0, class_weight=None, dual=False, fit_intercept=True,\n"
+ "\tintercept_scaling=1, max_iter=100, multi_class='ovr',\n"
+ "\tpenalty='l2', random_state=None, solver='liblinear', tol=0.0001,\n"
+ "\tverbose=0"),
LogisticRegressionCV("linear_model", true, false, true, false,
"\tCs=10, class_weight=None, cv=None, dual=False,\n"
+ "\tfit_intercept=True, intercept_scaling=1.0, max_iter=100,\n"
+ "\tmulti_class='ovr', n_jobs=1, penalty='l2', refit=True,\n"
+ "\tscoring=None, solver='lbfgs', tol=0.0001, verbose=0"),
LinearRegression("linear_model", false, true, false, false,
"\tcopy_X=True, fit_intercept=True, n_jobs=1, normalize=False"),
ARDRegression("linear_model", false, true, false, false,
"\talpha_1=1e-06, alpha_2=1e-06, compute_score=False, copy_X=True,\n"
+ "\tfit_intercept=True, lambda_1=1e-06, lambda_2=1e-06, n_iter=300,\n"
+ "\tnormalize=False, threshold_lambda=10000.0, tol=0.001, verbose=False"),
BayesianRidge("linear_model", false, true, false, false,
"\talpha_1=1e-06, alpha_2=1e-06, compute_score=False, copy_X=True,\n"
+ "\tfit_intercept=True, lambda_1=1e-06, lambda_2=1e-06, n_iter=300,\n"
+ "\tnormalize=False, tol=0.001, verbose=False"),
ElasticNet("linear_model", false, true, false, false,
"\talpha=1.0, copy_X=True, fit_intercept=True, l1_ratio=0.5,\n"
+ "\tmax_iter=1000, normalize=False, positive=False, precompute=False,\n"
+ "\trandom_state=None, selection='cyclic', tol=0.0001, warm_start=False"),
Lars("linear_model", false, true, false, false,
"\tcopy_X=True, eps=2.2204460492503131e-16, fit_intercept=True,\n"
+ "\tfit_path=True, n_nonzero_coefs=500, normalize=True, precompute='auto',\n"
+ "\tverbose=False"),
LarsCV("linear_model", false, true, false, false,
"\tcopy_X=True, cv=None, eps=2.2204460492503131e-16, fit_intercept=True,\n"
+ "\tmax_iter=500, max_n_alphas=1000, n_jobs=1, normalize=True,\n"
+ "\tprecompute='auto', verbose=False"),
Lasso("linear_model", false, true, false, false,
"\talpha=1.0, copy_X=True, fit_intercept=True, max_iter=1000,\n"
+ "\tnormalize=False, positive=False, precompute=False, random_state=None,\n"
+ "\tselection='cyclic', tol=0.0001, warm_start=False"),
LassoCV("linear_model", false, true, false, false,
"\talphas=None, copy_X=True, cv=None, eps=0.001, fit_intercept=True,\n"
+ "\tmax_iter=1000, n_alphas=100, n_jobs=1, normalize=False, positive=False,\n"
+ "\tprecompute='auto', random_state=None, selection='cyclic', tol=0.0001,\n"
+ "\tverbose=False"),
LassoLars("linear_model", false, true, false, false,
"\talpha=1.0, copy_X=True, eps=2.2204460492503131e-16,\n"
+ "\tfit_intercept=True, fit_path=True, max_iter=500, normalize=True,\n"
+ "\tprecompute='auto', verbose=False"),
LassoLarsCV("linear_model", false, true, false, false,
"\tcopy_X=True, cv=None, eps=2.2204460492503131e-16,\n"
+ "\tfit_intercept=True, max_iter=500, max_n_alphas=1000, n_jobs=1,\n"
+ "\tnormalize=True, precompute='auto', verbose=False"),
LassoLarsIC("linear_model", false, true, false, false,
"\tcopy_X=True, criterion='aic', eps=2.2204460492503131e-16,\n"
+ "\tfit_intercept=True, max_iter=500, normalize=True, precompute='auto',\n"
+ "\tverbose=False"),
MLPClassifier("neural_network", true, false, true, false,
"\thidden_layer_sizes=(100,), " + "activation='relu', solver='adam',\n"
+ "\talpha=0.0001, batch_size='auto', learning_rate='constant',\n"
+ "\tlearning_rate_init=0.001, power_t=0.5, max_iter=200,\n"
+ "\tshuffle=True, random_state=None, tol=0.0001, verbose=False,\n"
+ "\twarm_start=False, momentum=0.9, nesterovs_momentum=True,\n"
+ "\tearly_stopping=False, validation_fraction=0.1, beta_1=0.9,\n"
+ "\tbeta_2=0.999, epsilon=1e-08"),
MLPRegressor("neural_network", false, true, false, false,
"\thidden_layer_sizes=(100,), " + "activation='relu', solver='adam',\n"
+ "\talpha=0.0001, batch_size='auto', learning_rate='constant',\n"
+ "\tlearning_rate_init=0.001, power_t=0.5, max_iter=200,\n"
+ "\tshuffle=True, random_state=None, tol=0.0001, verbose=False,\n"
+ "\twarm_start=False, momentum=0.9, nesterovs_momentum=True,\n"
+ "\tearly_stopping=False, validation_fraction=0.1, beta_1=0.9,\n"
+ "\tbeta_2=0.999, epsilon=1e-08"),
OrthogonalMatchingPursuit("linear_model", false, true, false, false,
"\tfit_intercept=True, n_nonzero_coefs=None,\n"
+ "\tnormalize=True, precompute='auto', tol=None"),
OrthogonalMatchingPursuitCV("linear_model", false, true, false, false,
"\tcopy=True, cv=None, fit_intercept=True,\n"
+ "\tmax_iter=None, n_jobs=1, normalize=True, verbose=False"),
PassiveAggressiveClassifier("linear_model", true, false, false, false,
"\tC=1.0, fit_intercept=True, loss='hinge', n_iter=5,\n"
+ "\tn_jobs=1, random_state=None, shuffle=True, verbose=0,\n"
+ "\twarm_start=False"),
PassiveAggressiveRegressor("linear_model", false, true, false, false,
"\tC=1.0, class_weight=None, epsilon=0.1,\n"
+ "\tfit_intercept=True, loss='epsilon_insensitive', n_iter=5,\n"
+ "\trandom_state=None, shuffle=True, verbose=0, warm_start=False"),
Perceptron("linear_model", true, false, false, false,
"\talpha=0.0001, class_weight=None, " + "eta0=1.0, fit_intercept=True,\n"
+ "\tn_iter=5, n_jobs=1, penalty=None, random_state=0, shuffle=True,\n"
+ "\tverbose=0, warm_start=False"),
RANSACRegressor("linear_model", false, true, false, false,
"\tbase_estimator=None, is_data_valid=None, is_model_valid=None,\n"
+ "\tmax_trials=100, min_samples=None, random_state=None,\n"
+ "\tresidual_metric=None, residual_threshold=None, stop_n_inliers=inf,\n"
+ "\tstop_probability=0.99, stop_score=inf"),
Ridge("linear_model", false, true, false, false,
"\talpha=1.0, copy_X=True, fit_intercept=True, max_iter=None,\n"
+ "\tnormalize=False, solver='auto', tol=0.001"),
RidgeClassifier("linear_model", true, false, false, false,
"\talpha=1.0, class_weight=None, copy_X=True, fit_intercept=True,\n"
+ "\tmax_iter=None, normalize=False, solver='auto', tol=0.001"),
RidgeClassifierCV("linear_model", true, false, false, false,
"alphas=array([ 0.1, 1. , 10. ]), class_weight=None,\n"
+ "\tcv=None, fit_intercept=True, normalize=False, scoring=None"),
RidgeCV("linear_model", false, true, false, false,
"alphas=array([ 0.1, 1. , 10. ]), cv=None, fit_intercept=True,\n"
+ "\tgcv_mode=None, normalize=False, scoring=None, store_cv_values=False"),
SGDClassifier("linear_model", true, false, false, false,
"\talpha=0.0001, average=False, class_weight=None, epsilon=0.1,\n"
+ "\teta0=0.0, fit_intercept=True, l1_ratio=0.15,\n"
+ "\tlearning_rate='optimal', loss='hinge', n_iter=5, n_jobs=1,\n"
+ "\tpenalty='l2', power_t=0.5, random_state=None, shuffle=True,\n"
+ "\tverbose=0, warm_start=False") {
// SGD can only produce probabilities for the log and modified_huber losses
@Override
public boolean producesProbabilities(String params) {
return params.contains("log") || params.contains("modified_huber");
}
},
SGDRegressor("linear_model", false, true, false, false,
"\talpha=0.0001, average=False, epsilon=0.1, eta0=0.01,\n"
+ "\tfit_intercept=True, l1_ratio=0.15, learning_rate='invscaling',\n"
+ "\tloss='squared_loss', n_iter=5, penalty='l2', power_t=0.25,\n"
+ "\trandom_state=None, shuffle=True, verbose=0, warm_start=False"),
TheilSenRegressor("linear_model", false, true, false, false,
"\tcopy_X=True, fit_intercept=True, max_iter=300,\n"
+ "\tmax_subpopulation=10000, n_jobs=1, n_subsamples=None,\n"
+ "\trandom_state=None, tol=0.001, verbose=False"),
GaussianProcessRegressor("gaussian_process", false, true, false, false,
"\tkernel=None, alpha=1e-10,\n"
+ "\toptimizer='fmin_l_bfgs_b', n_restarts_optimizer=0, normalize_y=False,\n "
+ "\trandom_start=1,\n "
+ "\tnormalize=False, random_state=None"),
GaussianProcessClassifier("gaussian_process", true, false, true, false,
"\tkernel=None, optimizer='fmin_l_bfgs_b', n_restarts_optimizer=0, normalize_y=False,\n "
+ "\trandom_start=1, normalize=False, max_iter_predict=100, multi_class='one_vs_rest', random_state=None"),
KernelRidge("kernel_ridge", false, true, false, false,
"\talpha=1, coef0=1, degree=3, gamma=None, kernel='linear',\n"
+ "\tkernel_params=None"),
KNeighborsClassifier("neighbors", true, false, true, false,
"\talgorithm='auto', leaf_size=30, metric='minkowski',\n"
+ "\tmetric_params=None, n_neighbors=5, p=2, weights='uniform'"),
RadiusNeighborsClassifier("neighbors", true, false, false, false,
"\talgorithm='auto', leaf_size=30, metric='minkowski',\n"
+ "\tmetric_params=None, outlier_label=None, p=2, radius=1.0,\n"
+ "\tweights='uniform'"),
KNeighborsRegressor("neighbors", false, true, false, false,
"algorithm='auto', leaf_size=30, metric='minkowski',\n"
+ "\tmetric_params=None, n_neighbors=5, p=2, weights='uniform'"),
RadiusNeighborsRegressor("neighbors", false, true, false, false, ""),
SVC("svm", true, false, false, false,
"\tC=1.0, cache_size=200, class_weight=None, coef0=0.0, degree=3, gamma=0.0,\n"
+ "\tkernel='rbf', max_iter=-1, probability=False, random_state=None,\n"
+ "\tshrinking=True, tol=0.001, verbose=False"),
LinearSVC("svm", true, false, false, false,
"\tC=1.0, class_weight=None, dual=True, fit_intercept=True,\n"
+ "\tintercept_scaling=1, loss='squared_hinge', max_iter=1000,\n"
+ "\tmulti_class='ovr', penalty='l2', random_state=None, tol=0.0001,\n"
+ "\tverbose=0"),
NuSVC("svm", true, false, false, false,
"\tcache_size=200, coef0=0.0, degree=3, gamma=0.0, kernel='rbf',\n"
+ "\tmax_iter=-1, nu=0.5, probability=False, random_state=None,\n"
+ "\tshrinking=True, tol=0.001, verbose=False"),
SVR("svm", false, true, false, false,
"\tC=1.0, cache_size=200, coef0=0.0, degree=3, epsilon=0.1, gamma=0.0,\n"
+ "\tkernel='rbf', max_iter=-1, shrinking=True, tol=0.001, verbose=False"),
LinearSVR("svm", false, true, false, false,
"\tC=1.0, dual=True, fit_intercept=True,\n"
+ "\tintercept_scaling=1, loss='epsilon_insensitive', max_iter=1000,\n"
+ "\trandom_state=None, tol=0.0001,\n"
+ "\tverbose=0"),
NuSVR("svm", false, true, false, false,
"\tC=1.0, cache_size=200, coef0=0.0, degree=3, gamma=0.0, kernel='rbf',\n"
+ "\tmax_iter=-1, nu=0.5, shrinking=True, tol=0.001, verbose=False"),
AdaBoostClassifier("ensemble", true, false, true, false,
"\talgorithm='SAMME.R', base_estimator=None,\n"
+ "\tlearning_rate=1.0, n_estimators=50, random_state=None"),
AdaBoostRegressor("ensemble", false, true, false, false,
"\tbase_estimator=None, learning_rate=1.0, loss='linear',\n"
+ "\tn_estimators=50, random_state=None"),
BaggingClassifier("ensemble", true, false, true, false,
"\tbase_estimator=None, bootstrap=True,\n"
+ "\tbootstrap_features=False, max_features=1.0, max_samples=1.0,\n"
+ "\tn_estimators=10, n_jobs=1, oob_score=False, random_state=None,\n"
+ "\tverbose=0"),
BaggingRegressor("ensemble", false, true, false, false,
"\tbase_estimator=None, bootstrap=True,\n"
+ "\tbootstrap_features=False, max_features=1.0, max_samples=1.0,\n"
+ "\tn_estimators=10, n_jobs=1, oob_score=False, random_state=None,\n"
+ "\tverbose=0"),
ExtraTreeClassifier("tree", true, false, true, false,
"\tclass_weight=None, criterion='gini', max_depth=None,\n"
+ "\tmax_features='auto', max_leaf_nodes=None, min_samples_leaf=1,\n"
+ "\tmin_samples_split=2, min_weight_fraction_leaf=0.0,\n"
+ "\trandom_state=None, splitter='random'"),
ExtraTreeRegressor("tree", false, true, false, false,
"\tcriterion='mse', max_depth=None, max_features='auto',\n"
+ "\tmax_leaf_nodes=None, min_samples_leaf=1, min_samples_split=2,\n"
+ "\tmin_weight_fraction_leaf=0.0, random_state=None,\n"
+ "\tsplitter='random'"),
ExtraTreesClassifier("ensemble", true, false, true, false,
"\tn_estimators=100, criterion='gini', max_depth=None,\n"
+ "\tmin_samples_split=2, min_samples_leaf=1, min_weight_fraction_leaf=0.0,\n"
+ "\tmax_features='auto', max_leaf_nodes=None, min_impurity_decrease=0.0,\n"
+ "\tbootstrap=False, oob_score=False, n_jobs=None,\n"
+ "\trandom_state=None, verbose=0, warm_start=False, class_weight=None,\n"),
// NOTE(review): 'class_weight' below is not an sklearn ExtraTreesRegressor
// parameter (regressors have no class weights) - verify these defaults are
// accepted if passed verbatim to the python constructor
ExtraTreesRegressor("ensemble", false, true, false, false,
"\tn_estimators=100, criterion='mse', max_depth=None,\n"
+ "\tmin_samples_split=2, min_samples_leaf=1, min_weight_fraction_leaf=0.0,\n"
+ "\tmax_features='auto', max_leaf_nodes=None, min_impurity_decrease=0.0,\n"
+ "\tbootstrap=False, oob_score=False, n_jobs=None,\n"
+ "\trandom_state=None, verbose=0, warm_start=False, class_weight=None,\n"),
GradientBoostingClassifier("ensemble", true, false, true, false,
"\tinit=None, learning_rate=0.1, loss='deviance',\n"
+ "\tmax_depth=3, max_features=None, max_leaf_nodes=None,\n"
+ "\tmin_samples_leaf=1, min_samples_split=2,\n"
+ "\tmin_weight_fraction_leaf=0.0, n_estimators=100,\n"
+ "\trandom_state=None, subsample=1.0, verbose=0,\n"
+ "\twarm_start=False"),
GradientBoostingRegressor("ensemble", false, true, false, false,
"\talpha=0.9, init=None, learning_rate=0.1, loss='ls',\n"
+ "\tmax_depth=3, max_features=None, max_leaf_nodes=None,\n"
+ "\tmin_samples_leaf=1, min_samples_split=2,\n"
+ "\tmin_weight_fraction_leaf=0.0, n_estimators=100,\n"
+ "\trandom_state=None, subsample=1.0, verbose=0, warm_start=False"),
RandomForestClassifier("ensemble", true, false, true, false,
"\tbootstrap=True, class_weight=None, criterion='gini',\n"
+ "\tmax_depth=None, max_features='auto', max_leaf_nodes=None,\n"
+ "\tmin_samples_leaf=1, min_samples_split=2,\n"
+ "\tmin_weight_fraction_leaf=0.0, n_estimators=10, n_jobs=1,\n"
+ "\toob_score=False, random_state=None, verbose=0,\n"
+ "\twarm_start=False"),
RandomForestRegressor("ensemble", false, true, false, false,
"\tbootstrap=True, criterion='mse', max_depth=None,\n"
+ "\tmax_features='auto', max_leaf_nodes=None, min_samples_leaf=1,\n"
+ "\tmin_samples_split=2, min_weight_fraction_leaf=0.0,\n"
+ "\tn_estimators=10, n_jobs=1, oob_score=False, random_state=None,\n"
+ "\tverbose=0, warm_start=False"),
// Not actually an sklearn scheme, but does have an sklearn API
XGBClassifier("xgboost", true, false, true, true,
"\tmax_depth=3, learning_rate=0.1, n_estimators=100, silent=True,\n"
+ "\tobjective='binary:logistic', booster='gbtree', n_jobs=1,\n"
+ "\tnthread=None, gamma=0, min_child_weight=1, max_delta_step=0,\n"
+ "\tsubsample=1, colsample_bytree=1, colsample_bylevel=1,\n"
+ "\treg_alpha=0, reg_lambda=1, scale_pos_weight=1, base_score=0.5,\n"
+ "\trandom_state=0, seed=None, missing=None, tree_method='hist'"),
XGBRegressor("xgboost", false, true, false, true,
"\tmax_depth=3, learning_rate=0.1, n_estimators=100, silent=True,\n"
+ "\tobjective='reg:linear', booster='gbtree', n_jobs=1,\n"
+ "\tnthread=None, gamma=0, min_child_weight=1, max_delta_step=0,\n"
+ "\tsubsample=1, colsample_bytree=1, colsample_bylevel=1,\n"
+ "\treg_alpha=0, reg_lambda=1, scale_pos_weight=1, base_score=0.5,\n"
+ "\trandom_state=0, seed=None, missing=None, tree_method='hist'");
/** The scikit-learn module that hosts this scheme */
private String m_module;
/** True if the scheme is a classifier */
private boolean m_classification;
/** True if the scheme is a regressor */
private boolean m_regression;
/** Default probability-production capability (may be overridden per constant) */
private boolean m_producesProbabilities;
/**
* True to have the model variable cleared in python after training and
* batch prediction. This conserves memory, and might allow additional
* off-heap resources (e.g. GPU device memory) to be freed. However, there
* will be additional overhead in transferring the model back into python
* each time predictions are required.
*/
private boolean m_removeModelFromPyPostTraining;
/** Human-readable default parameter settings for this scheme */
private String m_defaultParameters;
/**
* Enum constructor
*
* @param module the scikit-learn module of the given scheme
* @param classification true if it is a classifier
* @param regression true if it is a regressor
* @param producesProbabilities true if it produces probabilities
* @param removeModel true if the model variable should be cleared in python
* after training and batch prediction
* @param defaultParameters the list of default parameter settings
*/
Learner(String module, boolean classification, boolean regression,
boolean producesProbabilities, boolean removeModel,
String defaultParameters) {
m_module = module;
m_producesProbabilities = producesProbabilities;
m_classification = classification;
m_regression = regression;
m_removeModelFromPyPostTraining = removeModel;
m_defaultParameters = defaultParameters;
}
/**
* Get the scikit-learn module of this scheme
*
* @return the scikit-learn module of this scheme
*/
public String getModule() {
return m_module;
}
/**
* Default implementation of producesProbabilities given parameter settings.
* Specific enum values can override if there are specific parameter
* settings where probabilities can or can't be produced
*
* @param params the current parameter settings for the scheme
* @return true if probabilities can be produced given the parameter
* settings
*/
public boolean producesProbabilities(String params) {
return m_producesProbabilities;
}
/**
* Return true if the variable containing the model in python should be
* cleared after training and each batch prediction operation. This can be
* advantageous for specific methods that might either consume lots of host
* memory or device (e.g. GPU) memory. The disadvantage is that the model
* will need to be transferred into python prior to each batch prediction
* call (cross-validation will be slower).
*
* @return true if the model variable should be cleared in python
*/
public boolean removeModelFromPythonPostTrainPredict() {
return m_removeModelFromPyPostTraining;
}
/**
* Return true if this scheme is a classifier
*
* @return true if this scheme is a classifier
*/
public boolean isClassifier() {
return m_classification;
}
/**
* Return true if this scheme is a regressor
*
* @return true if this scheme is a regressor
*/
public boolean isRegressor() {
return m_regression;
}
/**
* Get the default settings for parameters for this scheme
*
* @return the default parameter settings of this scheme
*/
public String getDefaultParameters() {
return m_defaultParameters;
}
};
/** The tags for the GUI drop-down for learner selection */
public static final Tag[] TAGS_LEARNER = new Tag[Learner.values().length];
static {
// Tag ID == Learner ordinal, so get/setLearner can map between the two
for (Learner l : Learner.values()) {
TAGS_LEARNER[l.ordinal()] = new Tag(l.ordinal(), l.toString());
}
}
/**
* Holds the version number of scikit-learn. API for LDA and QDA changed in
* version 0.17.0, so we need to check for this in order to adjust code for
* these methods. Negative value means "not yet established".
*/
protected double m_scikitVersion = -1;
/** Will be set to true if xgboost is available */
protected boolean m_xgboostInstalled;
/** The scikit learner to use */
protected Learner m_learner = Learner.DecisionTreeClassifier;
/** The parameters to pass to the learner */
protected String m_learnerOpts = "";
/** True if the supervised nominal to binary filter should be used */
protected boolean m_useSupervisedNominalToBinary;
/** The nominal to binary filter */
protected Filter m_nominalToBinary;
/** For replacing missing values */
protected Filter m_replaceMissing = new ReplaceMissingValues();
/** Holds the python serialized model */
protected String m_pickledModel;
/**
* If true, then the pickled model is not fetched from Python. Not fetching
* the model means that this classifier can't be serialized and used later.
* However, execution will be faster when running a cross-validation, and it
* is not necessary to fetch the model as it exists in the python environment
* for the duration of the virtual machine. In some cases, the size of the
* pickled model may exceed the current limits for transfer (Integer.MAX_VALUE
* bytes, about 2.1Gb).
*/
protected boolean m_dontFetchModelFromPython;
/** Holds the textual description of the scikit learner */
protected String m_learnerToString = "";
/** For making this model unique in python */
protected String m_modelHash;
/**
* True for nominal class labels that don't occur in the training data
*/
protected boolean[] m_nominalEmptyClassIndexes;
/**
* Fall back to Zero R if there are no instances with non-missing class or
* only the class is present in the data
*/
protected ZeroR m_zeroR;
/** Class priors for use if there are numerical problems */
protected double[] m_classPriors;
/** Default - use python in PATH */
protected String m_pyCommand = "default";
/** Default - use PATH as is for executing python */
protected String m_pyPath = "default";
/**
* Optional server name/ID by which to force a new server for this instance
* (or share amongst specific instances). The string "none" indicates no
* specific server name. Python command + server name uniquely identifies a
* server to use
*/
protected String m_serverID = "none";
/**
 * Global help info
 *
 * @return the global help info for this scheme, listing every available
 *         learner, whether it handles classification and/or regression, and
 *         its default parameter settings
 */
public String globalInfo() {
StringBuilder description = new StringBuilder();
description.append("A wrapper for classifiers implemented in the scikit-learn "
+ "python library. The following learners are available:\n\n");
// One entry per learner: name, supported task(s), default parameters
for (Learner scheme : Learner.values()) {
description.append(scheme.toString()).append("\n[");
if (scheme.isClassifier()) {
description.append(" classification ");
}
if (scheme.isRegressor()) {
description.append(" regression ");
}
description.append("]\nDefault parameters:\n");
description.append(scheme.getDefaultParameters()).append("\n");
}
return description.toString();
}
/**
* Batch prediction size (default: 100 instances). Stored as a String so that
* it can be edited as a standard Weka option.
*/
protected String m_batchPredictSize = "100";
/**
* Whether to try and continue after script execution reports output on system
* error from Python. Some schemes output warning (not error) messages to the
* sys err stream
*/
protected boolean m_continueOnSysErr;
/**
 * Get the capabilities of this learner. When doServerChecks is true a python
 * session is acquired, the scikit-learn version is established (once) and the
 * availability of xgboost is probed; a RuntimeException is thrown if the
 * python environment is unavailable or misconfigured.
 *
 * @param doServerChecks true if the python server checks should be run
 * @return the Capabilities of this learner
 */
protected Capabilities getCapabilities(boolean doServerChecks) {
Capabilities result = super.getCapabilities();
result.disableAll();
if (doServerChecks) {
boolean pythonAvailable = true;
PythonSession session = null;
try {
session = getSession();
} catch (WekaException e) {
pythonAvailable = false;
}
if (pythonAvailable) {
if (m_scikitVersion < 0) {
// try and establish scikit-learn version
try {
String script = "import sklearn\nskv = sklearn.__version__\n";
// parameterized List - the raw type would not compile against
// the String methods used below
List<String> outAndErr = session.executeScript(script, getDebug());
String versionNumber =
session.getVariableValueFromPythonAsPlainString("skv", getDebug());
if (versionNumber != null && versionNumber.length() > 0) {
// strip the last (point release) component so that the remainder
// parses as a double, e.g. "0.24.2" -> "0.24"; guard against a
// dotless version string
int lastDot = versionNumber.lastIndexOf('.');
if (lastDot > 0) {
versionNumber = versionNumber.substring(0, lastDot);
}
m_scikitVersion = Double.parseDouble(versionNumber);
}
// check for xgboost - any output on sys err (index 1) means the
// import failed
m_xgboostInstalled = true;
script = "import xgboost\n";
outAndErr = session.executeScript(script, getDebug());
if (outAndErr.get(1).length() > 0) {
m_xgboostInstalled = false;
}
} catch (WekaException e) {
m_xgboostInstalled = false;
e.printStackTrace();
} finally {
try {
releaseSession();
session = null;
} catch (WekaException e) {
e.printStackTrace();
}
}
}
enableSchemeCapabilities(result);
if (session != null) {
try {
releaseSession();
} catch (WekaException e) {
e.printStackTrace();
}
}
} else {
try {
String pyCheckResults = getPythonEnvCheckResults();
throw new RuntimeException(String.format("The python environment is either not available or is not configured correctly:\n\n%s", pyCheckResults));
} catch (WekaException e) {
throw new RuntimeException("The python environment is either not available or is not configured correctly", e);
}
}
} else {
enableSchemeCapabilities(result);
}
return result;
}
/**
 * Enable the attribute and class capabilities appropriate to the currently
 * selected learner. Factored out of getCapabilities(boolean), where this
 * logic was duplicated in both the server-check and no-check paths.
 *
 * @param result the Capabilities object to enable capabilities on
 */
private void enableSchemeCapabilities(Capabilities result) {
result.enable(Capabilities.Capability.NUMERIC_ATTRIBUTES);
result.enable(Capabilities.Capability.NOMINAL_ATTRIBUTES);
result.enable(Capabilities.Capability.DATE_ATTRIBUTES);
result.enable(Capabilities.Capability.MISSING_VALUES);
result.enable(Capabilities.Capability.MISSING_CLASS_VALUES);
if (m_learner.isClassifier()) {
result.enable(Capabilities.Capability.BINARY_CLASS);
result.enable(Capabilities.Capability.NOMINAL_CLASS);
}
if (m_learner.isRegressor()) {
result.enable(Capabilities.Capability.NUMERIC_CLASS);
}
}
/**
 * Get the capabilities of this learner. Delegates to
 * getCapabilities(boolean) with server checks disabled.
 *
 * @return the capabilities of this scheme
 */
@Override
public Capabilities getCapabilities() {
// note that we don't do the server checks from here because the GenericObjectEditor
// in the GUI calls this method for every character typed in a text field!
return getCapabilities(false);
}
/**
 * Get whether to use the supervised version of nominal to binary
 *
 * @return true if supervised nominal to binary is to be used
 */
@OptionMetadata(displayName = "Use supervised nominal to binary conversion",
description = "Use supervised nominal to binary conversion of nominal attributes.",
commandLineParamName = "S", commandLineParamSynopsis = "-S",
commandLineParamIsFlag = true, displayOrder = 3)
public boolean getUseSupervisedNominalToBinary() {
// backing field defaults to false (unsupervised conversion)
return m_useSupervisedNominalToBinary;
}
/**
 * Set whether to use the supervised version of nominal to binary
 * (rather than the unsupervised filter) when converting nominal attributes.
 *
 * @param useSupervisedNominalToBinary true if supervised nominal to binary is
 * to be used
 */
public void
setUseSupervisedNominalToBinary(boolean useSupervisedNominalToBinary) {
m_useSupervisedNominalToBinary = useSupervisedNominalToBinary;
}
/**
 * Get the scikit-learn scheme to use, wrapped as a SelectedTag whose ID is
 * the Learner's ordinal.
 *
 * @return the scikit-learn scheme to use
 */
@OptionMetadata(displayName = "Scikit-learn learner",
description = "Scikit-learn learner to use.\nAvailable learners:\nDecisionTreeClassifier"
+ ", DecisionTreeRegressor, GaussianNB, MultinomialNB,"
+ "BernoulliNB, LDA, QDA, "
+ "LogisticRegression, LogisticRegressionCV,\n"
+ "LinearRegression, ARDRegression, "
+ "BayesianRidge, ElasticNet, Lars,\nLarsCV, Lasso, LassoCV, LassoLars, "
+ "LassoLarsCV, LassoLarsIC, MLPClassifier, MLPRegressor, OrthogonalMatchingPursuit,\n"
+ "OrthogonalMatchingPursuitCV, PassiveAggressiveClassifier, "
+ "PassiveAggressiveRegressor, Perceptron, RANSACRegressor,\nRidge, "
+ "RidgeClassifier, RidgeClassifierCV, RidgeCV, SGDClassifier,\nSGDRegressor,"
+ "TheilSenRegressor, GaussianProcess, KernelRidge, KNeighborsClassifier, "
+ "\nRadiusNeighborsClassifier, KNeighborsRegressor, RadiusNeighborsRegressor, SVC,"
+ "\nLinearSVC, NuSVC, SVR, NuSVR, AdaBoostClassifier, AdaBoostRegressor,"
+ "BaggingClassifier, BaggingRegressor,\nExtraTreeClassifier, ExtraTreeRegressor,"
+ "GradientBoostingClassifier, GradientBoostingRegressor,\n"
+ "RandomForestClassifier, RandomForestRegressor, XGBClassifier, XGBRegressor."
+ "\n(default = DecisionTreeClassifier)",
commandLineParamName = "learner",
commandLineParamSynopsis = "-learner ", displayOrder = 1)
public SelectedTag getLearner() {
return new SelectedTag(m_learner.ordinal(), TAGS_LEARNER);
}
/**
 * Set the scikit-learn scheme to use.
 *
 * @param learner the scikit-learn scheme to use; the selected tag's ID is
 *          interpreted as a Learner ordinal, and an unrecognized ID leaves
 *          the current learner unchanged
 */
public void setLearner(SelectedTag learner) {
int chosenID = learner.getSelectedTag().getID();
Learner[] available = Learner.values();
// ordinal() == index into values(), so a bounds-checked lookup replaces
// the linear scan; out-of-range IDs are ignored, as before
if (chosenID >= 0 && chosenID < available.length) {
m_learner = available[chosenID];
}
}
/**
 * Get the parameters to pass to the scikit-learn scheme
 *
 * @return the parameters to use (empty string means scheme defaults)
 */
@OptionMetadata(displayName = "Learner parameters",
description = "learner parameters to use", displayOrder = 2,
commandLineParamName = "parameters",
commandLineParamSynopsis = "-parameters ")
public String getLearnerOpts() {
return m_learnerOpts;
}
/**
 * Set the parameters to pass to the scikit-learn scheme
 *
 * @param opts the parameters to use (empty string means scheme defaults)
 */
public void setLearnerOpts(String opts) {
m_learnerOpts = opts;
}
/**
 * Set the preferred number of instances to transfer to python per batch
 * prediction call.
 *
 * @param size the batch size to use
 */
@Override
public void setBatchSize(String size) {
m_batchPredictSize = size;
}
/**
 * Get the preferred number of instances to transfer to python per batch
 * prediction call.
 *
 * @return the batch size to use
 */
@OptionMetadata(displayName = "Batch size", description = "The preferred "
+ "number of instances to transfer into python for prediction\n(if operating"
+ "in batch prediction mode). More or fewer instances than this will be "
+ "accepted.", commandLineParamName = "batch",
commandLineParamSynopsis = "-batch ", displayOrder = 4)
@Override
public String getBatchSize() {
return m_batchPredictSize;
}
/**
 * Returns true, as we send entire test sets over to python for prediction
 *
 * @return true always
 */
@Override
public boolean implementsMoreEfficientBatchPrediction() {
return true;
}
/**
 * Set whether to try and continue after seeing output on the sys error
 * stream. Some schemes write warnings (rather than errors) to sys error.
 *
 * @param c true if we should try to continue after seeing output on the sys
 * error stream
 */
public void setContinueOnSysErr(boolean c) {
m_continueOnSysErr = c;
}
/**
 * Get whether to try and continue after seeing output on the sys error
 * stream. Some schemes write warnings (rather than errors) to sys error.
 *
 * @return true if we should try to continue after seeing output on the sys
 *         error stream
 */
@OptionMetadata(
displayName = "Try to continue after sys err output from script",
description = "Try to continue after sys err output from script.\nSome schemes"
+ " report warnings to the system error stream.",
displayOrder = 5, commandLineParamName = "continue-on-err",
commandLineParamSynopsis = "-continue-on-err",
commandLineParamIsFlag = true)
public boolean getContinueOnSysErr() {
return m_continueOnSysErr;
}
/**
 * If true then don't retrieve the model from python. This can speed up
 * cross-validation but will prevent the trained classifier from being used
 * after deserialization. May also be necessary if a model in python is very
 * large and exceeds the current transfer size (Integer.MAX_VALUE bytes, about
 * 2.1Gb).
 *
 * @param dontFetchModelFromPython true to not fetch the model from python
 */
public void setDontFetchModelFromPython(boolean dontFetchModelFromPython) {
m_dontFetchModelFromPython = dontFetchModelFromPython;
}
/**
 * If true then don't retrieve the model from python. This can speed up
 * cross-validation but will prevent the trained classifier from being used
 * after deserialization. May also be necessary if a model in python is very
 * large and exceeds the current transfer size (Integer.MAX_VALUE bytes, about
 * 2.1Gb).
 *
 * @return true if not fetching the model from python
 */
@OptionMetadata(displayName = "Don't retrieve model from python",
description = "Don't retrieve the model from python - speeds up "
+ "cross-validation,\nbut prevents this classifier from being "
+ "used after deserialization.\nSome models in python (e.g. large "
+ "random forests) may exceed the maximum size for transfer\n("
+ "currently Integer.MAX_VALUE bytes)",
displayOrder = 6, commandLineParamName = "dont-fetch-model",
commandLineParamSynopsis = "-dont-fetch-model",
commandLineParamIsFlag = true)
public boolean getDontFetchModelFromPython() {
return m_dontFetchModelFromPython;
}
/**
 * Set the python command to use. An empty string or "default" indicates that
 * the python present in the PATH should be used.
 *
 * @param pyCommand the path to the python executable (or empty
 *          string/"default" to use python in the PATH)
 */
public void setPythonCommand(String pyCommand) {
m_pyCommand = pyCommand;
}
/**
 * Get the python command to use. An empty string or "default" indicates that
 * the python present in the PATH should be used.
 *
 * @return the path to the python executable (or empty string/"default" to use
 *         python in the PATH)
 */
@OptionMetadata(displayName = "Python command ",
description = "Path to python executable ('default' to use python in the PATH)",
commandLineParamName = "py-command",
commandLineParamSynopsis = "-py-command ",
displayOrder = 7)
public String getPythonCommand() {
return m_pyCommand;
}
/**
 * Set optional entries to prepend to the PATH so that python can execute
 * correctly. Only applies when not using a default python.
 *
 * @param pythonPath additional entries to prepend to the PATH
 */
public void setPythonPath(String pythonPath) {
m_pyPath = pythonPath;
}
/**
 * Get optional entries to prepend to the PATH so that python can execute
 * correctly. Only applies when not using a default python.
 *
 * @return additional entries to prepend to the PATH
 */
@OptionMetadata(displayName = "Python path",
description = "Optional elements to prepend to the PATH so that python can "
+ "execute correctly ('default' to use PATH as-is)",
commandLineParamName = "py-path",
commandLineParamSynopsis = "-py-path ", displayOrder = 8)
public String getPythonPath() {
return m_pyPath;
}
/**
 * Set an optional server name by which to identify the python server to use.
 * Can be used to share a given server amongst selected instances or reserve a
 * server specifically for this classifier. Python command + server name
 * uniquely identifies the server to use.
 *
 * @param serverID the name of the server to use (or "none" for no specific
 *          server name)
 */
public void setServerID(String serverID) {
m_serverID = serverID;
}
/**
 * Get an optional server name by which to identify the python server to use.
 * Can be used to share a given server amongst selected instances or reserve a
 * server specifically for this classifier. Python command + server name
 * uniquely identifies the server to use.
 *
 * @return the name of the server to use (or "none" for no specific server
 *         name)
 */
@OptionMetadata(displayName = "Server name/ID",
description = "Optional name to identify this server, can be used to share "
+ "a given server instance - default = 'none' (i.e. no server name)",
commandLineParamName = "server-name",
commandLineParamSynopsis = "-server-name ",
displayOrder = 9)
public String getServerID() {
return m_serverID;
}
/**
 * Gets a python session object to use for interacting with python.
 *
 * <p>A user-specified python command (anything other than empty or
 * "default") selects a dedicated environment, identified by the combination
 * of python command and server name; otherwise the python found in the PATH
 * is used. The backing server is started on demand if it is not yet running.
 *
 * @return an acquired PythonSession object
 * @throws WekaException if the python environment cannot be started
 */
protected PythonSession getSession() throws WekaException {
  // Empty strings and the sentinels "default"/"none" all mean "unset".
  String pyCommand = m_pyCommand == null || m_pyCommand.length() == 0
    || m_pyCommand.equalsIgnoreCase("default") ? null : m_pyCommand;
  String pyPath = m_pyPath == null || m_pyPath.length() == 0
    || m_pyPath.equalsIgnoreCase("default") ? null : m_pyPath;
  String serverID = m_serverID == null || m_serverID.length() == 0
    || m_serverID.equalsIgnoreCase("none") ? null : m_serverID;

  if (pyCommand == null) {
    // default python from the PATH
    if (!PythonSession.pythonAvailable()) {
      // not running yet - try initializing
      if (!PythonSession.initSession("python", getDebug())) {
        String envEvalResults = PythonSession.getPythonEnvCheckResults();
        throw new WekaException(
          "Was unable to start python environment:\n\n" + envEvalResults);
      }
    }
    return PythonSession.acquireSession(this);
  }

  // user-specified python environment/server
  if (!PythonSession.pythonAvailable(pyCommand, serverID)) {
    // try to create this environment/server
    if (!PythonSession.initSession(pyCommand, serverID, pyPath, getDebug())) {
      String envEvalResults =
        PythonSession.getPythonEnvCheckResults(pyCommand, serverID);
      throw new WekaException("Was unable to start python environment ("
        + pyCommand + ")\n\n" + envEvalResults);
    }
  }
  return PythonSession.acquireSession(pyCommand, serverID, this);
}
/**
 * Release the python session previously acquired via {@link #getSession()}.
 * A no-op when the corresponding python environment is not available.
 *
 * @throws WekaException if a problem occurs
 */
protected void releaseSession() throws WekaException {
  // Fix: removed an unused local variable (PythonSession session = null)
  // that served no purpose in the original code.
  String pyCommand = m_pyCommand != null && m_pyCommand.length() > 0
    && !m_pyCommand.equalsIgnoreCase("default") ? m_pyCommand : null;
  String serverID = m_serverID != null && m_serverID.length() > 0
    && !m_serverID.equalsIgnoreCase("none") ? m_serverID : null;
  if (pyCommand != null) {
    // user-specified environment: release by command + server name
    if (PythonSession.pythonAvailable(pyCommand, serverID)) {
      PythonSession.releaseSession(pyCommand, serverID, this);
    }
  } else {
    // default python from the PATH
    if (PythonSession.pythonAvailable()) {
      PythonSession.releaseSession(this);
    }
  }
}
/**
 * Get the results of executing the environment check in python, for the
 * environment selected by the configured python command and server name.
 *
 * @return the results of executing the environment check
 * @throws WekaException if a problem occurs
 */
protected String getPythonEnvCheckResults() throws WekaException {
  String pyCommand = m_pyCommand != null && m_pyCommand.length() > 0
    && !m_pyCommand.equalsIgnoreCase("default") ? m_pyCommand : null;
  if (pyCommand == null) {
    // default python from the PATH
    return PythonSession.getPythonEnvCheckResults();
  }
  String serverID = m_serverID != null && m_serverID.length() > 0
    && !m_serverID.equalsIgnoreCase("none") ? m_serverID : null;
  return PythonSession.getPythonEnvCheckResults(pyCommand, serverID);
}
/**
 * Check to see if python is available for the user-specified environment
 * (or for the default python in the PATH when no command is configured).
 *
 * @return true if the selected python environment is available
 */
protected boolean pythonAvailable() {
  String pyCommand = m_pyCommand != null && m_pyCommand.length() > 0
    && !m_pyCommand.equalsIgnoreCase("default") ? m_pyCommand : null;
  String serverID = m_serverID != null && m_serverID.length() > 0
    && !m_serverID.equalsIgnoreCase("none") ? m_serverID : null;
  return pyCommand != null ? PythonSession.pythonAvailable(pyCommand, serverID)
    : PythonSession.pythonAvailable();
}
/**
 * Build the classifier: transfers the training data into python, constructs
 * the selected scikit-learn (or xgboost) scheme with the user-supplied
 * options, fits it, and (unless disabled) retrieves the pickled model so
 * that this classifier remains usable after serialization.
 *
 * <p>Falls back to a ZeroR model when there are no instances with a
 * non-missing class value, or when only the class attribute is present.
 *
 * @param data set of instances serving as training data
 * @throws Exception if a problem occurs
 */
@Override
public void buildClassifier(Instances data) throws Exception {
  m_pickledModel = null;
  getCapabilities(true).testWithFail(data);
  m_zeroR = null;

  PythonSession session = getSession();
  // Fix: everything after session acquisition now runs under try/finally so
  // the session is always released. The original leaked it on the early
  // ZeroR-fallback return and on the xgboost-missing error path, both of
  // which executed before the try block began.
  try {
    // Fix: check XGBRegressor as well - it also requires the xgboost package
    if ((m_learner == Learner.XGBClassifier
      || m_learner == Learner.XGBRegressor) && !m_xgboostInstalled) {
      throw new Exception(
        "xgboost does not seem to be available in your python "
          + "installation");
    }

    if (m_modelHash == null) {
      // unique suffix for this model's variable name in python
      m_modelHash = "" + hashCode();
    }

    data = new Instances(data);
    data.deleteWithMissingClass();

    // ZeroR supplies class priors and acts as a fallback model for
    // degenerate training data
    m_zeroR = new ZeroR();
    m_zeroR.buildClassifier(data);
    m_classPriors = data.numInstances() > 0
      ? m_zeroR.distributionForInstance(data.instance(0))
      : new double[data.classAttribute().numValues()];

    if (data.numInstances() == 0 || data.numAttributes() == 1) {
      if (data.numInstances() == 0) {
        System.err
          .println("No instances with non-missing class - using ZeroR model");
      } else {
        System.err.println("Only the class attribute is present in "
          + "the data - using ZeroR model");
      }
      return; // keep m_zeroR as the model
    } else {
      m_zeroR = null;
    }

    if (data.classAttribute().isNominal()) {
      // record class values with zero weight (empty classes) so predictions
      // can later be expanded to the full set of class labels
      AttributeStats stats = data.attributeStats(data.classIndex());
      m_nominalEmptyClassIndexes =
        new boolean[data.classAttribute().numValues()];
      for (int i = 0; i < stats.nominalWeights.length; i++) {
        if (stats.nominalWeights[i] == 0) {
          m_nominalEmptyClassIndexes[i] = true;
        }
      }
    }

    // scikit-learn needs numeric input: impute missing values, then
    // binarize nominal attributes
    m_replaceMissing.setInputFormat(data);
    data = Filter.useFilter(data, m_replaceMissing);
    if (getUseSupervisedNominalToBinary()) {
      m_nominalToBinary =
        new weka.filters.supervised.attribute.NominalToBinary();
    } else {
      m_nominalToBinary =
        new weka.filters.unsupervised.attribute.NominalToBinary();
    }
    m_nominalToBinary.setInputFormat(data);
    data = Filter.useFilter(data, m_nominalToBinary);

    String learnerModule = m_learner.getModule();
    String learnerMethod = m_learner.toString();
    if (learnerMethod.equalsIgnoreCase("MLPClassifier")
      || learnerMethod.equalsIgnoreCase("MLPRegressor")) {
      if (m_scikitVersion < 0.18) {
        throw new Exception(
          learnerMethod + " is not available in scikit-learn " + "version "
            + m_scikitVersion + ". Version 0.18 or higher is " + "required.");
      }
    }
    // LDA/QDA were renamed in newer scikit-learn versions
    if (m_scikitVersion > 0.18) {
      if (learnerMethod.equalsIgnoreCase("LDA")) {
        learnerMethod = "LinearDiscriminantAnalysis";
        learnerModule = "discriminant_analysis";
      } else if (learnerMethod.equalsIgnoreCase("QDA")) {
        learnerMethod = "QuadraticDiscriminantAnalysis";
        learnerModule = "discriminant_analysis";
      }
    }

    // transfer the data over to python (features X, target Y)
    session.instancesToPythonAsScikitLearn(data, TRAINING_DATA_ID,
      getDebug());

    // assemble the training script: imports, construct the estimator with
    // the user-supplied options, then fit it
    StringBuilder learnScript = new StringBuilder();
    if (m_learner == Learner.XGBClassifier
      || m_learner == Learner.XGBRegressor) {
      learnScript.append("import xgboost as xgb\n")
        .append("import numpy as np").append("\n");
      learnScript.append(MODEL_ID + m_modelHash + " = xgb." + learnerMethod
        + "(" + getLearnerOpts() + ")").append("\n");
    } else {
      learnScript.append("from sklearn import *").append("\n")
        .append("import numpy as np").append("\n");
      // Fix: dropped a redundant ternary - it evaluated to getLearnerOpts()
      // in every case
      learnScript.append(MODEL_ID + m_modelHash + " = " + learnerModule + "."
        + learnerMethod + "(" + getLearnerOpts() + ")").append("\n");
    }
    learnScript.append(MODEL_ID + m_modelHash + ".fit(X,np.ravel(Y))")
      .append("\n");

    // Fix: restored the List<String> type parameter (lost in extraction);
    // index 1 holds anything the script wrote to sys.stderr
    List<String> outAndErr =
      session.executeScript(learnScript.toString(), getDebug());
    if (outAndErr.size() == 2 && outAndErr.get(1).length() > 0) {
      if (m_continueOnSysErr) {
        System.err.println(outAndErr.get(1));
      } else {
        throw new Exception(outAndErr.get(1));
      }
    }

    // textual representation of the fitted model (used by toString())
    m_learnerToString = session.getVariableValueFromPythonAsPlainString(
      MODEL_ID + m_modelHash, getDebug());

    if (!getDontFetchModelFromPython()) {
      // retrieve the model from python so it survives serialization
      m_pickledModel = session.getVariableValueFromPythonAsPickledObject(
        MODEL_ID + m_modelHash, getDebug());
    }

    if (m_learner.removeModelFromPythonPostTrainPredict()) {
      // some schemes should not be left resident in the python environment
      String cleanUp = "del " + MODEL_ID + m_modelHash + "\n";
      outAndErr = session.executeScript(cleanUp, getDebug());
      if (outAndErr.size() == 2 && outAndErr.get(1).length() > 0) {
        if (m_continueOnSysErr) {
          System.err.println(outAndErr.get(1));
        } else {
          throw new Exception(outAndErr.get(1));
        }
      }
    }
  } finally {
    releaseSession();
  }
}
/**
 * Score a batch of instances using the fallback ZeroR model.
 *
 * @param insts the instances to score
 * @return one predicted distribution per instance
 * @throws Exception if a problem occurs
 */
private double[][] batchScoreWithZeroR(Instances insts) throws Exception {
  final int count = insts.numInstances();
  double[][] dists = new double[count][];
  for (int i = 0; i < count; i++) {
    dists[i] = m_zeroR.distributionForInstance(insts.instance(i));
  }
  return dists;
}
/**
 * Return the predicted probabilities for the supplied instance. Delegates
 * to the batch prediction method with a one-instance dataset.
 *
 * @param instance the instance to be classified
 * @return the predicted probabilities
 * @throws Exception if a problem occurs
 */
@Override
public double[] distributionForInstance(Instance instance) throws Exception {
  Instances singleton = new Instances(instance.dataset(), 0);
  singleton.add(instance);
  double[][] dists = distributionsForInstances(singleton);
  return dists[0];
}
/**
* Return the predicted probabilities for the supplied instances
*
* @param insts the instances to get predictions for
* @return the predicted probabilities for the supplied instances
* @throws Exception if a problem occurs
*/
@Override
@SuppressWarnings("unchecked")
public double[][] distributionsForInstances(Instances insts)
throws Exception {
if (m_zeroR != null) {
return batchScoreWithZeroR(insts);
}
insts = Filter.useFilter(insts, m_replaceMissing);
insts = Filter.useFilter(insts, m_nominalToBinary);
Attribute classAtt = insts.classAttribute();
// remove the class attribute
Remove r = new Remove();
r.setAttributeIndices("" + (insts.classIndex() + 1));
r.setInputFormat(insts);
insts = Filter.useFilter(insts, r);
insts.setClassIndex(-1);
double[][] results = null;
PythonSession session = null;
try {
session = getSession();
session.instancesToPythonAsScikitLearn(insts, TEST_DATA_ID, getDebug());
StringBuilder predictScript = new StringBuilder();
// check if model exists in python. If not, then transfer it over
if (!session.checkIfPythonVariableIsSet(MODEL_ID + m_modelHash,
getDebug())) {
if (m_pickledModel == null || m_pickledModel.length() == 0) {
throw new Exception("There is no model to transfer into Python!");
}
session.setPythonPickledVariableValue(MODEL_ID + m_modelHash,
m_pickledModel, getDebug());
}
String learnerModule = m_learner.getModule();
String learnerMethod = m_learner.toString();
if (m_scikitVersion > 0.18) {
if (learnerMethod.equalsIgnoreCase("LDA")) {
learnerMethod = "LinearDiscriminantAnalysis";
learnerModule = "discriminant_analysis";
} else if (learnerMethod.equalsIgnoreCase("QDA")) {
learnerMethod = "QuadraticDiscriminantAnalysis";
learnerModule = "discriminant_analysis";
}
}
if (m_learner == Learner.XGBClassifier
|| m_learner == Learner.XGBRegressor) {
predictScript.append("import xgboost as xgb\n");
} else {
predictScript
.append("from sklearn." + learnerModule + " import " + learnerMethod)
.append("\n");
}
predictScript.append("preds = " + MODEL_ID + m_modelHash + ".predict"
+ (m_learner.producesProbabilities(m_learnerOpts) ? "_proba" : "")
+ "(X)").append("\npreds = preds.tolist()\n");
List outAndErr =
session.executeScript(predictScript.toString(), getDebug());
if (outAndErr.size() == 2 && outAndErr.get(1).length() > 0) {
if (m_continueOnSysErr) {
System.err.println(outAndErr.get(1));
} else {
throw new Exception(outAndErr.get(1));
}
}
List