// hivemall.xgboost.XGBoostTrainUDTF Maven / Gradle / Ivy
// The newest version!
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package hivemall.xgboost;
import hivemall.UDTFWithOptions;
import hivemall.annotations.VisibleForTesting;
import hivemall.utils.collections.lists.FloatArrayList;
import hivemall.utils.hadoop.HadoopUtils;
import hivemall.utils.hadoop.HiveUtils;
import hivemall.utils.lang.OptionUtils;
import hivemall.utils.math.MathUtils;
import hivemall.xgboost.utils.DMatrixBuilder;
import hivemall.xgboost.utils.DenseDMatrixBuilder;
import hivemall.xgboost.utils.NativeLibLoader;
import hivemall.xgboost.utils.SparseDMatrixBuilder;
import hivemall.xgboost.utils.XGBoostUtils;
import matrix4j.utils.lang.ArrayUtils;
import matrix4j.utils.lang.Primitives;
import ml.dmlc.xgboost4j.java.Booster;
import ml.dmlc.xgboost4j.java.DMatrix;
import ml.dmlc.xgboost4j.java.XGBoostError;
import java.lang.reflect.InvocationTargetException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Random;
import javax.annotation.Nonnegative;
import javax.annotation.Nonnull;
import org.apache.commons.cli.CommandLine;
import org.apache.commons.cli.Options;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
import org.apache.hadoop.io.Text;
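/**
* UDTF for train_xgboost.
*
* Input rows (features, target) are buffered by process(); the booster is trained in close(),
* which forwards a single (model_id, model) row holding the serialized model.
*/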
//@formatter:off
@Description(name = "train_xgboost",
value = "_FUNC_(array features, target, const string options)"
+ " - Returns a relation consists of pred_model>",
extended = "SELECT \n" +
" train_xgboost(features, label, '-objective binary:logistic -iters 10') \n" +
" as (model_id, model)\n" +
"from (\n" +
" select features, label\n" +
" from xgb_input\n" +
" cluster by rand(43) -- shuffle\n" +
") shuffled;")
//@formatter:on
public class XGBoostTrainUDTF extends UDTFWithOptions {
private static final Log logger = LogFactory.getLog(XGBoostTrainUDTF.class);

// Initialize the XGBoost native library once, when this class is first loaded
static {
    NativeLibLoader.initXGBoost();
}

// Object inspectors for the input arguments (assigned in initialize)
private ListObjectInspector featureListOI;
private PrimitiveObjectInspector featureElemOI;
private PrimitiveObjectInspector targetOI;

// Buffers accumulating the training input until close()
private boolean denseInput; // true when the feature array holds numbers, false when it holds strings
private DMatrixBuilder matrixBuilder;
private FloatArrayList labels;

// XGBoost options, keyed by parameter name. Values are heterogeneous
// (Integer, Long, Double, Boolean or String), hence the Object value type.
@Nonnull
protected final Map<String, Object> params = new HashMap<>();
// Number of classes; only assigned when the -num_class option is given
protected int numClass;
// Category of the configured objective; assigned at the end of processOptions
protected ObjectiveType objectiveType = null;
/**
 * Coarse category of an XGBoost objective, derived from the text before the first ':' of the
 * objective name (e.g. "binary:logistic" is {@code binary}).
 */
public enum ObjectiveType {
    regression, binary, multiclass, rank, other;

    /**
     * Maps an objective name to its category.
     *
     * @param objective objective name such as "reg:logistic" or "multi:softmax"
     * @return the matching category, or {@code other} when the prefix is not recognized
     */
    @Nonnull
    public static ObjectiveType resolve(@Nonnull String objective) {
        final int separator = objective.indexOf(':');
        if (separator < 0) {
            // no "xxx:" prefix at all
            return other;
        }
        final String family = objective.substring(0, separator);
        switch (family) {
            case "reg":
                return regression;
            case "binary":
                return binary;
            case "multi":
                return multiclass;
            case "rank":
                return rank;
            default:
                return other;
        }
    }
}
/** No-arg constructor; does nothing. Inspectors and buffers are assigned later in initialize(). */
public XGBoostTrainUDTF() {}
/**
 * Declares every option accepted in the option string of train_xgboost.
 *
 * Each option takes one argument. Defaults quoted in the help texts are the ones applied by
 * processOptions, or XGBoost's own defaults where processOptions passes nothing.
 *
 * @return the option definitions used to parse the option string
 */
@Override
protected Options getOptions() {
    final Options opts = new Options();

    // Options driving the training loop of this UDTF
    opts.addOption("num_round", "iters", true, "Number of boosting iterations [default: 10]");
    opts.addOption("maximize_evaluation_metrics", true,
        "Maximize evaluation metrics [default: false]");
    opts.addOption("num_early_stopping_rounds", true,
        "Minimum rounds required for early stopping [default: 0]");
    // processOptions rejects 1.0, so the upper bound is exclusive
    opts.addOption("validation_ratio", true,
        "Validation ratio in range [0.0,1.0) [default: 0.2]");

    // General parameters
    opts.addOption("booster", true,
        "Set a booster to use, gbtree or gblinear or dart. [default: gbtree]");
    opts.addOption("silent", true, "Deprecated. Please use verbosity instead. "
            + "0 means printing running messages, 1 means silent mode [default: 1]");
    opts.addOption("verbosity", true, "Verbosity of printing messages. "
            + "Choices: 0 (silent), 1 (warning), 2 (info), 3 (debug). [default: 0]");
    // processOptions reads -nthread; it has to be registered here to be accepted by the parser
    opts.addOption("nthread", true,
        "Number of parallel threads used to run xgboost [default: 1]");
    opts.addOption("disable_default_eval_metric", true,
        "Flag to disable default metric. Set to >0 to disable. [default: 0]");
    opts.addOption("num_pbuffer", true,
        "Size of prediction buffer [default: set automatically by xgboost]");
    opts.addOption("num_feature", true,
        "Feature dimension used in boosting [default: set automatically by xgboost]");

    // Parameters shared among boosters
    opts.addOption("lambda", "reg_lambda", true,
        "L2 regularization term on weights. Increasing this value will make model more conservative."
                + " [default: 1.0 for gbtree, 0.0 for gblinear]");
    opts.addOption("alpha", "reg_alpha", true,
        "L1 regularization term on weights. Increasing this value will make model more conservative."
                + " [default: 0.0]");
    opts.addOption("updater", true,
        "A comma-separated string that defines the sequence of tree updaters to run. "
                + "For a full list of valid inputs, please refer to XGBoost Parameters."
                + " [default: 'grow_colmaker,prune' for gbtree, 'shotgun' for gblinear]");

    // Parameters for Tree Booster
    opts.addOption("eta", "learning_rate", true,
        "Step size shrinkage used in update to prevent overfitting [default: 0.3]");
    opts.addOption("gamma", "min_split_loss", true,
        "Minimum loss reduction required to make a further partition on a leaf node of the tree."
                + " [default: 0.0]");
    opts.addOption("max_depth", true, "Max depth of decision tree [default: 6]");
    opts.addOption("min_child_weight", true,
        "Minimum sum of instance weight (hessian) needed in a child [default: 1.0]");
    opts.addOption("max_delta_step", true,
        "Maximum delta step we allow each tree's weight estimation to be [default: 0]");
    opts.addOption("subsample", true,
        "Subsample ratio of the training instance in range (0.0,1.0] [default: 1.0]");
    opts.addOption("colsample_bytree", true,
        "Subsample ratio of columns when constructing each tree [default: 1.0]");
    opts.addOption("colsample_bylevel", true,
        "Subsample ratio of columns for each level [default: 1.0]");
    opts.addOption("colsample_bynode", true,
        "Subsample ratio of columns for each node [default: 1.0]");
    opts.addOption("tree_method", true,
        "The tree construction algorithm used in XGBoost. [default: auto, Choices: auto, exact, approx, hist]");
    opts.addOption("sketch_eps", true,
        "This roughly translates into O(1 / sketch_eps) number of bins. \n"
                + "Compared to directly select number of bins, this comes with theoretical guarantee with sketch accuracy.\n"
                + "Only used for tree_method=approx. Usually user does not have to tune this. [default: 0.03]");
    opts.addOption("scale_pos_weight", true,
        "Control the balance of positive and negative weights, useful for unbalanced classes. "
                + "A typical value to consider: sum(negative instances) / sum(positive instances)"
                + " [default: 1.0]");
    opts.addOption("refresh_leaf", true,
        "This is a parameter of the refresh updater plugin. "
                + "When this flag is 1, tree leafs as well as tree nodes’ stats are updated. "
                + "When it is 0, only node stats are updated. [default: 1]");
    opts.addOption("process_type", true,
        "A type of boosting process to run. [Choices: default, update]");
    opts.addOption("grow_policy", true,
        "Controls a way new nodes are added to the tree. Currently supported only if tree_method is set to hist."
                + " [default: depthwise, Choices: depthwise, lossguide]");
    opts.addOption("max_leaves", true,
        "Maximum number of nodes to be added. Only relevant when grow_policy=lossguide is set. [default: 0]");
    opts.addOption("max_bin", true,
        "Maximum number of discrete bins to bucket continuous features. Only used if tree_method is set to hist."
                + " [default: 256]");
    opts.addOption("num_parallel_tree", true,
        "Number of parallel trees constructed during each iteration. This option is used to support boosted random forest. "
                + "Usually no need to tune (default 1 is enough) for gradient boosting trees."
                + " [default: 1]");

    // Parameters for Dart Booster (booster=dart)
    opts.addOption("sample_type", true,
        "Type of sampling algorithm. [Choices: uniform (default), weighted]");
    opts.addOption("normalize_type", true,
        "Type of normalization algorithm. [Choices: tree (default), forest]");
    opts.addOption("rate_drop", true, "Dropout rate in range [0.0, 1.0]. [default: 0.0]");
    opts.addOption("one_drop", true,
        "When this flag is enabled, at least one tree is always dropped during the dropout. "
                + "0 or 1. [default: 0]");
    opts.addOption("skip_drop", true,
        "Probability of skipping the dropout procedure during a boosting iteration "
                + "in range [0.0, 1.0]. [default: 0.0]");

    // Parameters for Linear Booster (booster=gblinear)
    opts.addOption("lambda_bias", true, "L2 regularization term on bias [default: 0.0]");
    opts.addOption("feature_selector", true, "Feature selection and ordering method. "
            + "[Choices: cyclic (default), shuffle, random, greedy, thrifty]");
    opts.addOption("top_k", true,
        "The number of top features to select in greedy and thrifty feature selector. "
                + "The value of 0 means using all the features. [default: 0]");

    // Parameters for Tweedie Regression (objective=reg:tweedie)
    opts.addOption("tweedie_variance_power", true,
        "Parameter that controls the variance of the Tweedie distribution in range [1.0, 2.0]."
                + " [default: 1.5]");

    // Learning task parameters
    // processOptions asks the user to provide -objective when it is missing, so no default is advertised
    opts.addOption("objective", true,
        "Specifies the learning task and the corresponding learning objective. "
                + "Examples: reg:linear, reg:logistic, multi:softmax. "
                + "For a full list of valid inputs, refer to XGBoost Parameters. "
                + "[required]");
    opts.addOption("base_score", true,
        "Initial prediction score of all instances, global bias [default: 0.5]");
    opts.addOption("eval_metric", true,
        "Evaluation metrics for validation data. A default metric is assigned according to the objective:\n"
                + "- rmse: for regression\n" + "- error: for classification\n"
                + "- map: for ranking\n"
                + "For a list of valid inputs, see XGBoost Parameters.");
    opts.addOption("seed", true, "Random number seed. [default: 43]");
    opts.addOption("num_class", true, "Number of classes to classify");

    return opts;
}
/**
 * Parses the option string (3rd argument, when present) and fills {@code params} with the
 * hyperparameters used for training. Also assigns {@code numClass} (when -num_class is given)
 * and {@code objectiveType}.
 *
 * @param argOIs object inspectors of the UDTF arguments; index 2, if any, must be a constant string
 * @return the parsed command line
 * @throws UDFArgumentException if validation_ratio is outside [0.0, 1.0) or -num_class is
 *         missing for a "multi:" objective
 */
@Nonnull
@Override
protected CommandLine processOptions(ObjectInspector[] argOIs) throws UDFArgumentException {
    final CommandLine cl;
    if (argOIs.length >= 3) {
        String rawArgs = HiveUtils.getConstString(argOIs, 2);
        cl = parseOptions(rawArgs);
    } else {
        cl = parseOptions(""); // use default options
    }

    String objective = cl.getOptionValue("objective");
    if (objective == null) {
        // NOTE(review): objective is dereferenced right below, so showHelp is presumably
        // expected to throw here -- confirm against UDTFWithOptions
        showHelp("Please provide \"-objective XXX\" option in the 3rd argument.\n\n"
                + "Here is the list of supported objectives: \n"
                + " - Regression:\n {reg:squarederror, reg:logistic, reg:gamma, reg:tweedie}\n"
                + " - Binary classification: {binary:logistic, binary:logitraw, binary:hinge}\n"
                + " - Multiclass classification:\n {multi:softmax, multi:softprob}\n"
                + " - Ranking:\n {rank:pairwise, rank:ndcg, rank:map}\n"
                + " - Other:\n {count:poisson, survival:cox}");
    }
    if (objective.equals("reg:squarederror")) {
        // reg:linear is deprecated synonym of reg:squarederror
        // however, reg:squarederror is not supported in xgboost-predictor yet
        // https://github.com/dmlc/xgboost/pull/4267
        objective = "reg:linear";
    }

    final String booster = cl.getOptionValue("booster", "gbtree");

    // Options driving the training loop
    int numRound = Primitives.parseInt(cl.getOptionValue("num_round"), 10);
    params.put("num_round", numRound);
    params.put("maximize_evaluation_metrics",
        Primitives.parseBoolean(cl.getOptionValue("maximize_evaluation_metrics"), false));
    params.put("num_early_stopping_rounds",
        Primitives.parseInt(cl.getOptionValue("num_early_stopping_rounds"), 0));
    double validationRatio =
            Primitives.parseDouble(cl.getOptionValue("validation_ratio"), 0.2d);
    if (validationRatio < 0.d || validationRatio >= 1.d) {
        throw new UDFArgumentException("Invalid validation_ratio=" + validationRatio);
    }
    params.put("validation_ratio", validationRatio);

    // General parameters
    params.put("booster", booster);
    params.put("silent", Primitives.parseInt(cl.getOptionValue("silent"), 1));
    params.put("verbosity", Primitives.parseInt(cl.getOptionValue("verbosity"), 0));
    params.put("nthread", Primitives.parseInt(cl.getOptionValue("nthread"), 1));
    params.put("disable_default_eval_metric",
        Primitives.parseInt(cl.getOptionValue("disable_default_eval_metric"), 0));
    if (cl.hasOption("num_pbuffer")) {
        params.put("num_pbuffer", Integer.valueOf(cl.getOptionValue("num_pbuffer")));
    }
    if (cl.hasOption("num_feature")) {
        params.put("num_feature", Integer.valueOf(cl.getOptionValue("num_feature")));
    }

    // Parameters for Tree Booster. The DART booster inherits the gbtree parameters, so they
    // are applied for booster=dart as well; otherwise e.g. -max_depth would be silently dropped.
    if (booster.equals("gbtree") || booster.equals("dart")) {
        params.put("eta", Primitives.parseDouble(cl.getOptionValue("eta"), 0.3d));
        params.put("gamma", Primitives.parseDouble(cl.getOptionValue("gamma"), 0.d));
        params.put("max_depth", Primitives.parseInt(cl.getOptionValue("max_depth"), 6));
        params.put("min_child_weight",
            Primitives.parseDouble(cl.getOptionValue("min_child_weight"), 1.d));
        params.put("max_delta_step",
            Primitives.parseDouble(cl.getOptionValue("max_delta_step"), 0.d));
        params.put("subsample", Primitives.parseDouble(cl.getOptionValue("subsample"), 1.d));
        // Both the XGBoost parameter key and the option name are "colsample_*"
        params.put("colsample_bytree",
            Primitives.parseDouble(cl.getOptionValue("colsample_bytree"), 1.d));
        params.put("colsample_bylevel",
            Primitives.parseDouble(cl.getOptionValue("colsample_bylevel"), 1.d));
        params.put("colsample_bynode",
            Primitives.parseDouble(cl.getOptionValue("colsample_bynode"), 1.d));
        params.put("lambda", Primitives.parseDouble(cl.getOptionValue("lambda"), 1.d));
        params.put("alpha", Primitives.parseDouble(cl.getOptionValue("alpha"), 0.d));
        params.put("tree_method", cl.getOptionValue("tree_method", "auto"));
        params.put("sketch_eps",
            Primitives.parseDouble(cl.getOptionValue("sketch_eps"), 0.03d));
        params.put("scale_pos_weight",
            Primitives.parseDouble(cl.getOptionValue("scale_pos_weight"), 1.d));
        params.put("updater", cl.getOptionValue("updater", "grow_colmaker,prune"));
        params.put("refresh_leaf", Primitives.parseInt(cl.getOptionValue("refresh_leaf"), 1));
        params.put("process_type", cl.getOptionValue("process_type", "default"));
        params.put("grow_policy", cl.getOptionValue("grow_policy", "depthwise"));
        params.put("max_leaves", Primitives.parseInt(cl.getOptionValue("max_leaves"), 0));
        params.put("max_bin", Primitives.parseInt(cl.getOptionValue("max_bin"), 256));
        params.put("num_parallel_tree",
            Primitives.parseInt(cl.getOptionValue("num_parallel_tree"), 1));
    }

    // Parameters for Dart Booster (booster=dart)
    if (booster.equals("dart")) {
        params.put("sample_type", cl.getOptionValue("sample_type", "uniform"));
        params.put("normalize_type", cl.getOptionValue("normalize_type", "tree"));
        params.put("rate_drop", Primitives.parseDouble(cl.getOptionValue("rate_drop"), 0.d));
        params.put("one_drop", Primitives.parseInt(cl.getOptionValue("one_drop"), 0));
        params.put("skip_drop", Primitives.parseDouble(cl.getOptionValue("skip_drop"), 0.d));
    }

    // Parameters for Linear Booster (booster=gblinear)
    if (booster.equals("gblinear")) {
        params.put("lambda", Primitives.parseDouble(cl.getOptionValue("lambda"), 0.d));
        params.put("lambda_bias",
            Primitives.parseDouble(cl.getOptionValue("lambda_bias"), 0.d));
        params.put("alpha", Primitives.parseDouble(cl.getOptionValue("alpha"), 0.d));
        params.put("updater", cl.getOptionValue("updater", "shotgun"));
        params.put("feature_selector", cl.getOptionValue("feature_selector", "cyclic"));
        params.put("top_k", Primitives.parseInt(cl.getOptionValue("top_k"), 0));
    }

    // Parameters for Tweedie Regression (objective=reg:tweedie)
    if (objective.equals("reg:tweedie")) {
        params.put("tweedie_variance_power",
            Primitives.parseDouble(cl.getOptionValue("tweedie_variance_power"), 1.5d));
    }

    // Parameters for Poisson Regression (objective=count:poisson)
    if (objective.equals("count:poisson")) {
        // max_delta_step is set to 0.7 by default in poisson regression (used to safeguard optimization)
        params.put("max_delta_step",
            Primitives.parseDouble(cl.getOptionValue("max_delta_step"), 0.7d));
    }

    // Learning task parameters
    params.put("objective", objective);
    params.put("base_score", Primitives.parseDouble(cl.getOptionValue("base_score"), 0.5d));
    if (cl.hasOption("eval_metric")) {
        params.put("eval_metric", cl.getOptionValue("eval_metric"));
    }
    params.put("seed", Primitives.parseLong(cl.getOptionValue("seed"), 43L));
    if (cl.hasOption("num_class")) {
        this.numClass = Integer.parseInt(cl.getOptionValue("num_class"));
        params.put("num_class", numClass);
    } else {
        if (objective.startsWith("multi:")) {
            throw new UDFArgumentException(
                "-num_class is required for multiclass classification");
        }
    }

    if (logger.isInfoEnabled()) {
        logger.info("XGboost training hyperparameters: " + params.toString());
    }

    this.objectiveType = ObjectiveType.resolve(objective);
    return cl;
}
/**
 * Validates the argument types, parses the options and prepares the input buffers.
 *
 * A numeric feature array selects the dense matrix builder; a string feature array selects
 * the sparse one.
 *
 * @param argOIs inspectors for (features, target[, options])
 * @return a struct inspector with the columns model_id (Java string) and model (writable string)
 * @throws UDFArgumentException if the feature array elements are neither numbers nor strings,
 *         or if option processing fails
 */
@Override
public StructObjectInspector initialize(@Nonnull ObjectInspector[] argOIs)
        throws UDFArgumentException {
    if (argOIs.length != 2 && argOIs.length != 3) {
        showHelp("Invalid argument length=" + argOIs.length);
    }
    processOptions(argOIs);

    ListObjectInspector listOI = HiveUtils.asListOI(argOIs, 0);
    ObjectInspector elemOI = listOI.getListElementObjectInspector();
    this.featureListOI = listOI;
    if (HiveUtils.isNumberOI(elemOI)) {
        this.featureElemOI = HiveUtils.asDoubleCompatibleOI(elemOI);
        this.denseInput = true;
        this.matrixBuilder = new DenseDMatrixBuilder(8192);
    } else if (HiveUtils.isStringOI(elemOI)) {
        this.featureElemOI = HiveUtils.asStringOI(elemOI);
        this.denseInput = false;
        this.matrixBuilder = new SparseDMatrixBuilder(8192);
    } else {
        throw new UDFArgumentException(
            "train_xgboost takes array<double> or array<string> for the first argument: "
                    + listOI.getTypeName());
    }
    this.targetOI = HiveUtils.asDoubleCompatibleOI(argOIs, 1);
    this.labels = new FloatArrayList(1024);

    // Output schema: (model_id, model)
    final List<String> fieldNames = new ArrayList<>(2);
    final List<ObjectInspector> fieldOIs = new ArrayList<>(2);
    fieldNames.add("model_id");
    fieldOIs.add(PrimitiveObjectInspectorFactory.javaStringObjectInspector);
    fieldNames.add("model");
    fieldOIs.add(PrimitiveObjectInspectorFactory.writableStringObjectInspector);
    return ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs);
}
/**
 * Checks a raw target against the configured objective type and returns the label to train on.
 * Subclasses may override this method to validate the target range differently.
 *
 * For binary objectives only -1, 0 and 1 are accepted and they are mapped to 0 or 1.
 * For multiclass objectives the target must be an integer in [0, numClass).
 * Any other objective type passes the target through unchanged.
 *
 * @param target raw target value of the current row
 * @return the label to store for the row
 * @throws HiveException if the target is not valid for the objective type
 */
protected float processTargetValue(final float target) throws HiveException {
    switch (objectiveType) {
        case binary: {
            final boolean knownLabel = (target == -1 || target == 0 || target == 1);
            if (!knownLabel) {
                throw new UDFArgumentException(
                    "Invalid label value for classification: " + target);
            }
            // -1 and 0 both denote the negative class
            if (target > 0.f) {
                return 1.f;
            }
            return 0.f;
        }
        case multiclass: {
            final int classIndex = (int) target;
            if (classIndex != target) {
                // the target has a fractional part (or is not a number)
                throw new UDFArgumentException(
                    "Invalid target value for class label: " + target);
            }
            final boolean withinRange = (0 <= classIndex && classIndex < numClass);
            if (!withinRange) {
                final String lastClass = String.format("%.1f", (numClass - 1.0));
                throw new UDFArgumentException(
                    "target must be {0.0, ..., " + lastClass + "}: " + target);
            }
            return target;
        }
        default:
            return target;
    }
}
/**
 * Buffers one training row: appends its features to the matrix builder and its validated
 * target to the label list.
 *
 * @param args args[0] is the feature array, args[1] the target
 * @throws HiveException if the feature array is null or the target is rejected by
 *         processTargetValue
 */
@Override
public void process(@Nonnull Object[] args) throws HiveException {
    final Object featureArray = args[0];
    if (featureArray == null) {
        throw new HiveException("array features was null");
    }
    parseFeatures(featureArray, matrixBuilder);

    final float rawTarget = PrimitiveObjectInspectorUtils.getFloat(args[1], targetOI);
    final float label = processTargetValue(rawTarget);
    labels.add(label);
}
/**
 * Appends the elements of one feature array to the builder as a new row.
 *
 * Null elements are skipped. In dense mode each element is added as a float at its array
 * position; otherwise the element's string form is handed to the builder as is.
 *
 * @param argObj the feature array object read through featureListOI
 * @param builder the matrix builder receiving the row
 */
private void parseFeatures(@Nonnull final Object argObj,
        @Nonnull final DMatrixBuilder builder) {
    final int numElements = featureListOI.getListLength(argObj);
    for (int pos = 0; pos < numElements; pos++) {
        final Object element = featureListOI.getListElement(argObj, pos);
        if (element != null) {
            if (denseInput) {
                final float value = PrimitiveObjectInspectorUtils.getFloat(element, featureElemOI);
                builder.nextColumn(pos, value);
            } else {
                final String feature = element.toString();
                builder.nextColumn(feature);
            }
        }
    }
    builder.nextRow();
}
/**
 * Trains a booster on all rows buffered by process() and forwards one (model_id, model) row.
 *
 * When num_early_stopping_rounds is positive, the rows are shuffled with the configured seed
 * and split by validation_ratio into a validation part and a training part; otherwise all
 * rows are used for training. When no row was buffered, nothing is trained or forwarded.
 *
 * @throws HiveException if training, serialization or forwarding fails; a HiveException raised
 *         inside is propagated as is, any other failure is wrapped
 */
@Override
public void close() throws HiveException {
    DMatrix dmatrix = null;
    Booster booster = null;
    try {
        final float[] labelArray = labels.toArray(true);
        if (labelArray.length == 0) {
            // This task received no input row; a model cannot be trained from nothing
            this.matrixBuilder = null;
            this.labels = null;
            logger.warn("No training example was given to train_xgboost. Skip emitting a model");
            return;
        }

        dmatrix = matrixBuilder.buildMatrix(labelArray);
        // Drop the references to the input buffers once the matrix is built
        this.matrixBuilder = null;
        this.labels = null;

        final int round = OptionUtils.getInt(params, "num_round");
        final int earlyStoppingRounds = OptionUtils.getInt(params, "num_early_stopping_rounds");
        if (earlyStoppingRounds > 0) {
            double validationRatio = OptionUtils.getDouble(params, "validation_ratio");
            long seed = OptionUtils.getLong(params, "seed");

            // Shuffle row indices, then use the head for validation and the rest for training
            int numRows = (int) dmatrix.rowNum();
            int[] rows = MathUtils.permutation(numRows);
            ArrayUtils.shuffle(rows, new Random(seed));

            int numTest = (int) (numRows * validationRatio);
            DMatrix dtrain = null, dtest = null;
            try {
                dtest = dmatrix.slice(Arrays.copyOf(rows, numTest));
                dtrain = dmatrix.slice(Arrays.copyOfRange(rows, numTest, rows.length));
                booster = train(dtrain, dtest, round, earlyStoppingRounds, params);
            } finally {
                XGBoostUtils.close(dtrain);
                XGBoostUtils.close(dtest);
            }
        } else {
            booster = train(dmatrix, round, params);
        }
        onFinishTraining(booster);

        // Output the built model
        String modelId = generateUniqueModelId();
        Text predModel = XGBoostUtils.serializeBooster(booster);
        logger.info("model_id:" + modelId + ", size:" + predModel.getLength());
        forward(new Object[] {modelId, predModel});
    } catch (HiveException e) {
        // Already the declared exception type (e.g. raised by forward); do not wrap it again
        throw e;
    } catch (Throwable e) {
        throw new HiveException(e);
    } finally {
        XGBoostUtils.close(dmatrix);
        XGBoostUtils.close(booster);
    }
}
/**
 * Hook invoked by close() right after training and before the booster is serialized.
 * The default implementation does nothing.
 *
 * @param booster the trained booster
 */
@VisibleForTesting
protected void onFinishTraining(@Nonnull Booster booster) {}
/**
 * Trains a booster on {@code dtrain} for exactly {@code round} iterations.
 *
 * The returned booster is owned by the caller. If an iteration fails, the booster created here
 * is closed before the failure propagates, since the caller never receives a reference to it.
 *
 * @param dtrain training matrix
 * @param round number of boosting iterations
 * @param params hyperparameters handed to the booster factory
 * @return the trained booster
 */
@Nonnull
private static Booster train(@Nonnull final DMatrix dtrain, @Nonnegative final int round,
        @Nonnull final Map<String, Object> params)
        throws NoSuchMethodException, IllegalAccessException, InvocationTargetException,
        InstantiationException, XGBoostError {
    final Booster booster = XGBoostUtils.createBooster(dtrain, params);
    boolean completed = false;
    try {
        for (int iter = 0; iter < round; iter++) {
            booster.update(dtrain, iter);
        }
        completed = true;
        return booster;
    } finally {
        if (!completed) {
            try {
                XGBoostUtils.close(booster);
            } catch (Exception ignored) {
                // keep the original training failure as the reported error
            }
        }
    }
}
/**
 * Trains a booster on {@code dtrain} for at most {@code round} iterations, evaluating on
 * {@code dtest} after every iteration and stopping once {@code earlyStoppingRounds} iterations
 * have passed since the best score.
 *
 * Whether a larger or a smaller score counts as better is taken from the
 * "maximize_evaluation_metrics" entry of {@code params}; ties do not update the best score.
 * The returned booster is owned by the caller. If an iteration fails, the booster created here
 * is closed before the failure propagates, since the caller never receives a reference to it.
 *
 * @param dtrain training matrix
 * @param dtest validation matrix, evaluated under the name "test"
 * @param round maximum number of boosting iterations
 * @param earlyStoppingRounds iterations without improvement after which training stops
 * @param params hyperparameters handed to the booster factory
 * @return the booster as of the last executed iteration
 */
@Nonnull
private static Booster train(@Nonnull final DMatrix dtrain, @Nonnull final DMatrix dtest,
        @Nonnegative final int round, @Nonnegative final int earlyStoppingRounds,
        @Nonnull final Map<String, Object> params)
        throws NoSuchMethodException, IllegalAccessException, InvocationTargetException,
        InstantiationException, XGBoostError {
    final Booster booster = XGBoostUtils.createBooster(dtrain, params);
    boolean completed = false;
    try {
        final boolean maximizeEvaluationMetrics =
                OptionUtils.getBoolean(params, "maximize_evaluation_metrics");
        // Start from the worst possible score so that the first iteration always becomes the best
        float bestScore = maximizeEvaluationMetrics ? -Float.MAX_VALUE : Float.MAX_VALUE;
        int bestIteration = 0;

        final DMatrix[] evalMatrices = new DMatrix[] {dtest};
        final String[] evalNames = new String[] {"test"};
        final float[] metricsOut = new float[1];
        for (int iter = 0; iter < round; iter++) {
            booster.update(dtrain, iter);

            String evalInfo = booster.evalSet(evalMatrices, evalNames, iter, metricsOut);
            logger.info(evalInfo);

            // Update best score if the current score is better (no update when equal)
            final float score = metricsOut[0];
            final boolean improved =
                    maximizeEvaluationMetrics ? (score > bestScore) : (score < bestScore);
            if (improved) {
                bestScore = score;
                bestIteration = iter;
            }

            if (shouldEarlyStop(earlyStoppingRounds, iter, bestIteration)) {
                logger.info(
                    String.format("early stopping after %d rounds away from the best iteration",
                        earlyStoppingRounds));
                break;
            }
        }
        completed = true;
        return booster;
    } finally {
        if (!completed) {
            try {
                XGBoostUtils.close(booster);
            } catch (Exception ignored) {
                // keep the original training failure as the reported error
            }
        }
    }
}
/**
 * Tells whether training should stop because the best iteration lies far enough behind.
 *
 * @param earlyStoppingRounds number of iterations without improvement that triggers a stop
 * @param iter index of the iteration just evaluated
 * @param bestIteration index of the iteration with the best score so far
 * @return true when at least earlyStoppingRounds iterations have passed since bestIteration
 */
private static boolean shouldEarlyStop(final int earlyStoppingRounds, final int iter,
        final int bestIteration) {
    final int roundsSinceBest = iter - bestIteration;
    return roundsSinceBest >= earlyStoppingRounds;
}
/**
 * Builds the model identifier: the prefix "xgbmodel-" followed by the task id string
 * obtained from HadoopUtils.
 *
 * @return the model id emitted in the model_id column
 */
@Nonnull
private static String generateUniqueModelId() {
    final String taskId = HadoopUtils.getUniqueTaskIdString();
    return "xgbmodel-" + taskId;
}
}
// © 2015 - 2025 Weber Informatics LLC | Privacy Policy