All Downloads are FREE. Search and download functionalities are using the official Maven repository.

ml.dmlc.xgboost4j.java.flink.XGBoostModel Maven / Gradle / Ivy

The newest version!
/*
 Copyright (c) 2014-2023 by Contributors

 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at

 http://www.apache.org/licenses/LICENSE-2.0

 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
 */

package ml.dmlc.xgboost4j.java.flink;
import java.io.IOException;
import java.io.Serializable;
import java.util.Arrays;
import java.util.Iterator;
import java.util.stream.StreamSupport;

import org.apache.commons.lang3.ArrayUtils;
import org.apache.flink.api.common.functions.MapPartitionFunction;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.ml.linalg.SparseVector;
import org.apache.flink.ml.linalg.Vector;
import org.apache.flink.util.Collector;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;

import ml.dmlc.xgboost4j.LabeledPoint;
import ml.dmlc.xgboost4j.java.Booster;
import ml.dmlc.xgboost4j.java.DMatrix;
import ml.dmlc.xgboost4j.java.XGBoostError;


public class XGBoostModel implements Serializable {
  private static final org.slf4j.Logger logger =
      org.slf4j.LoggerFactory.getLogger(XGBoostModel.class);

  private final Booster booster;
  private final PredictorFunction predictorFunction;


  public XGBoostModel(Booster booster) {
    this.booster = booster;
    this.predictorFunction = new PredictorFunction(booster);
  }

  /**
   * Save the model as a Hadoop filesystem file.
   *
   * @param modelPath The model path as in Hadoop path.
   */
  public void saveModelAsHadoopFile(String modelPath) throws IOException, XGBoostError {
    booster.saveModel(FileSystem.get(new Configuration()).create(new Path(modelPath)));
  }

  public byte[] toByteArray(String format) throws XGBoostError {
    return booster.toByteArray(format);
  }

  /**
   * Save the model as a Hadoop filesystem file.
   *
   * @param modelPath The model path as in Hadoop path.
   * @param format The model format (ubj, json, deprecated)
   * @throws XGBoostError internal error
   * @throws IOException save error
   */
  public void saveModelAsHadoopFile(String modelPath, String format)
      throws IOException, XGBoostError {
    booster.saveModel(FileSystem.get(new Configuration()).create(new Path(modelPath)), format);
  }

  /**
   * predict with the given DMatrix
   *
   * @param testSet the local test set represented as DMatrix
   * @return prediction result
   */
  public float[][] predict(DMatrix testSet) throws XGBoostError {
    return booster.predict(testSet, true, 0);
  }

  /**
   * Predict given vector dataset.
   *
   * @param data The dataset to be predicted.
   * @return The prediction result.
   */
  public DataSet predict(DataSet data) {
    return data.mapPartition(predictorFunction);
  }


  private static class PredictorFunction implements MapPartitionFunction {

    private final Booster booster;

    public PredictorFunction(Booster booster) {
      this.booster = booster;
    }

    @Override
    public void mapPartition(Iterable it, Collector out) throws Exception {
      final Iterator dataIter =
          StreamSupport.stream(it.spliterator(), false)
            .map(Vector::toSparse)
            .map(PredictorFunction::fromVector)
            .iterator();

      if (dataIter.hasNext()) {
        final DMatrix data = new DMatrix(dataIter, null);
        float[][] predictions = booster.predict(data, true, 2);
        Arrays.stream(predictions).map(ArrayUtils::toObject).forEach(out::collect);
      } else {
        logger.debug("Empty partition");
      }
    }

    private static LabeledPoint fromVector(SparseVector vector) {
      final int[] index = vector.indices;
      final double[] value = vector.values;
      int size = value.length;
      final float[] values = new float[size];
      for (int i = 0; i < size; i++) {
        values[i] = (float) value[i];
      }
      return new LabeledPoint(0.0f, vector.size(), index, values);
    }
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy