package hex.genmodel.easy;

import hex.ModelCategory;
import hex.genmodel.*;
import hex.genmodel.algos.deeplearning.DeeplearningMojoModel;
import hex.genmodel.algos.drf.DrfMojoModel;
import hex.genmodel.algos.glrm.GlrmMojoModel;
import hex.genmodel.algos.targetencoder.TargetEncoderMojoModel;
import hex.genmodel.algos.tree.SharedTreeMojoModel;
import hex.genmodel.algos.tree.TreeBackedMojoModel;
import hex.genmodel.algos.word2vec.WordEmbeddingModel;
import hex.genmodel.attributes.ModelAttributes;
import hex.genmodel.attributes.VariableImportances;
import hex.genmodel.attributes.parameters.FeatureContribution;
import hex.genmodel.attributes.parameters.KeyValue;
import hex.genmodel.attributes.parameters.VariableImportancesHolder;
import hex.genmodel.easy.error.VoidErrorConsumer;
import hex.genmodel.easy.exception.PredictException;
import hex.genmodel.easy.prediction.*;

import java.io.IOException;
import java.io.Serializable;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;

import static hex.genmodel.utils.ArrayUtils.nanArray;

/**
 * An easy-to-use prediction wrapper for generated models.  Instantiate as follows.  The following two are equivalent.
 *
 *     EasyPredictModelWrapper model = new EasyPredictModelWrapper(rawModel);
 *
 *     EasyPredictModelWrapper model = new EasyPredictModelWrapper(
 *                                         new EasyPredictModelWrapper.Config()
 *                                             .setModel(rawModel)
 *                                             .setConvertUnknownCategoricalLevelsToNa(false));
 *
 * Note that for any given model, you must use the one predict method below that matches the
 * model's category.
 *
 * By default, unknown categorical levels result in a thrown PredictUnknownCategoricalLevelException.
 * The API defaults to this behavior so that even the simplest possible setup informs the user of
 * data-quality concerns.
 * An alternate behavior is to automatically convert unknown categorical levels to N/A.  To do this, use
 * setConvertUnknownCategoricalLevelsToNa(true) instead.
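 *
 * With the default (throwing) behavior, callers typically guard the predict call. A minimal sketch,
 * where model and row stand for an already-built wrapper and a populated RowData, and the surrounding
 * code is assumed to handle the remaining checked PredictException:
 *
 *     try {
 *       BinomialModelPrediction p = model.predictBinomial(row);
 *     } catch (PredictUnknownCategoricalLevelException e) {
 *       // decide how to handle the offending row, e.g. log it and skip
 *     }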
 *
 * Detection of unknown categoricals may be observed by registering an implementation of {@link ErrorConsumer}
 * in the process of {@link Config} creation.
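 *
 * For example, a simple logging consumer can be registered like this (an illustrative sketch only,
 * using the two callbacks declared on {@link ErrorConsumer} and the Config.setErrorConsumer() setter):
 *
 *     EasyPredictModelWrapper.ErrorConsumer logging = new EasyPredictModelWrapper.ErrorConsumer() {
 *       public void dataTransformError(String columnName, Object value, String message) {
 *         System.err.println("transform error on " + columnName + ": " + message);
 *       }
 *       public void unseenCategorical(String columnName, Object value, String message) {
 *         System.err.println("unseen level in " + columnName + ": " + value);
 *       }
 *     };
 *     EasyPredictModelWrapper model = new EasyPredictModelWrapper(
 *                                         new EasyPredictModelWrapper.Config()
 *                                             .setModel(rawModel)
 *                                             .setErrorConsumer(logging));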
 * 
 * Advanced scoring features are disabled by default for performance reasons. Configuration flags
 * allow the user to additionally output (see the example after this list):
 *  - leaf node assignment,
 *  - GLRM reconstructed matrix,
 *  - staged probabilities,
 *  - prediction contributions (SHAP values).
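 *
 * For instance, SHAP contributions can be requested as follows (an illustrative sketch; rawModel and
 * row stand for a previously loaded GenModel and a populated RowData, and setEnableContributions()
 * declares a checked IOException that the surrounding code must handle):
 *
 *     EasyPredictModelWrapper model = new EasyPredictModelWrapper(
 *                                         new EasyPredictModelWrapper.Config()
 *                                             .setModel(rawModel)
 *                                             .setEnableContributions(true));
 *     BinomialModelPrediction p = model.predictBinomial(row);
 *     // p.contributions holds one value per feature plus a trailing BiasTerm,
 *     // in the order reported by model.getContributionNames()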
 *
 * Deprecation note: The total number of unknown categorical variables is now accessible by registering a {@link hex.genmodel.easy.error.CountingErrorConsumer}.
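 *
 * A sketch of that counting variant follows; note that the CountingErrorConsumer constructor argument
 * shown here is an assumption, so consult that class for its exact API:
 *
 *     CountingErrorConsumer counter = new CountingErrorConsumer(rawModel);
 *     EasyPredictModelWrapper model = new EasyPredictModelWrapper(
 *                                         new EasyPredictModelWrapper.Config()
 *                                             .setModel(rawModel)
 *                                             .setErrorConsumer(counter));
 *     // counter accumulates counts of unseen categoricals as rows are scored (assumed behavior)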
 *
 * See the top-of-tree master version of this file on github.
 */
public class EasyPredictModelWrapper implements Serializable {
  // These private members are read-only after the constructor.
  public final GenModel m;
  private final RowToRawDataConverter rowDataConverter;
  private final boolean useExtendedOutput;
  private final boolean enableLeafAssignment;
  private final boolean enableGLRMReconstruct;  // if set to true, will return the GLRM reconstructed value, A_hat=X*Y instead of just X
  private final boolean enableStagedProbabilities; // if set to true, staged probabilities from tree algos are returned
  private final boolean enableContributions; // if set to true, will return prediction contributions (SHAP values) - for GBM & XGBoost
  private final int glrmIterNumber; // allow user to set GLRM mojo iteration number in constructing x
  private final PredictContributions predictContributions;

  public boolean getEnableLeafAssignment() { return enableLeafAssignment; }
  public boolean getEnableGLRMReconstruct() { return enableGLRMReconstruct; }
  public boolean getEnableStagedProbabilities() { return enableStagedProbabilities; }
  public boolean getEnableContributions() { return enableContributions; }

  /**
   * Observer interface with methods corresponding to errors during the prediction.
   */
  public static abstract class ErrorConsumer implements Serializable {
    /**
     * Observe a transformation error for data from the predicted dataset.
     *
     * @param columnName Name of the column for which the error is raised
     * @param value      Original value that could not be transformed properly
     * @param message    Transformation error message
     */
    public abstract void dataTransformError(String columnName, Object value, String message);

    /**
     * A previously unseen categorical level has been detected.
     *
     * @param columnName Name of the column to which the categorical value belongs
     * @param value      Original value
     * @param message    Reason and/or actions taken
     */
    public abstract void unseenCategorical(String columnName, Object value, String message);
  }

  /**
   * Configuration builder for instantiating a Wrapper.
   */
  public static class Config {
    private GenModel model;
    private boolean convertUnknownCategoricalLevelsToNa = false;
    private boolean convertInvalidNumbersToNa = false;
    private boolean useExtendedOutput = false;
    private ErrorConsumer errorConsumer;
    private boolean enableLeafAssignment = false; // default to false
    private boolean enableGLRMReconstrut = false;
    private boolean enableStagedProbabilities = false;
    private boolean enableContributions = false;
    private boolean useExternalEncoding = false;
    private int glrmIterNumber = 100; // default set to 100

    /**
     * Specify the model object to wrap.
     *
     * @param value model
     * @return this config object
     */
    public Config setModel(GenModel value) {
      model = value;
      return this;
    }

    /**
     * @return model object being wrapped
     */
    public GenModel getModel() { return model; }

    /**
     * Specify how to handle unknown categorical levels.
     *
     * @param value false: throw exception; true: convert to N/A
     * @return this config object
     */
    public Config setConvertUnknownCategoricalLevelsToNa(boolean value) {
      convertUnknownCategoricalLevelsToNa = value;
      return this;
    }

    public Config setEnableLeafAssignment(boolean val) throws IOException {
      if (val && (model == null))
        throw new IOException("enableLeafAssignment cannot be set with null model. Call setModel() first.");
      if (val && !(model instanceof TreeBackedMojoModel))
        throw new IOException("enableLeafAssignment can be set to true only with TreeBackedMojoModel," +
                " i.e. with GBM, DRF, Isolation forest or XGBoost.");
      enableLeafAssignment = val;
      return this;
    }

    public Config setEnableGLRMReconstrut(boolean value) throws IOException {
      if (value && (model == null))
        throw new IOException("Cannot set enableGLRMReconstruct for a null model. Call config.setModel() first.");
      if (value && !(model instanceof GlrmMojoModel))
        throw new IOException("enableGLRMReconstruct shall only be used with GlrmMojoModels.");
      enableGLRMReconstrut = value;
      return this;
    }

    public Config setGLRMIterNumber(int value) throws IOException {
      if (model == null)
        throw new IOException("Cannot set glrmIterNumber for a null model. Call config.setModel() first.");
      if (!(model instanceof GlrmMojoModel))
        throw new IOException("glrmIterNumber shall only be used with GlrmMojoModels.");
      if (value <= 0)
        throw new IllegalArgumentException("GLRMIterNumber must be positive.");
      glrmIterNumber = value;
      return this;
    }

    public Config setEnableStagedProbabilities(boolean val) throws IOException {
      if (val && (model == null))
        throw new IOException("enableStagedProbabilities cannot be set with null model. Call setModel() first.");
      if (val && !(model instanceof SharedTreeMojoModel))
        throw new IOException("enableStagedProbabilities can be set to true only with SharedTreeMojoModel," +
                " i.e. with GBM or DRF.");
      enableStagedProbabilities = val;
      return this;
    }

    public boolean getEnableGLRMReconstrut() { return enableGLRMReconstrut; }

    public Config setEnableContributions(boolean val) throws IOException {
      if (val && (model == null))
        throw new IOException("setEnableContributions cannot be set with null model. Call setModel() first.");
      if (val && !(model instanceof PredictContributionsFactory))
        throw new IOException("setEnableContributions can be set to true only with DRF, GBM, or XGBoost models.");
      if (val && (ModelCategory.Multinomial.equals(model.getModelCategory()))) {
        throw new IOException("setEnableContributions is not yet supported for multinomial classification models.");
      }
      if (val && model instanceof DrfMojoModel && ((DrfMojoModel) model).isBinomialDoubleTrees()) {
        throw new IOException("setEnableContributions is not yet supported for model with binomial_double_trees parameter set.");
      }
      enableContributions = val;
      return this;
    }

    public boolean getEnableContributions() { return enableContributions; }

    /**
     * Allows switching on/off applying categorical encoding in EasyPredictModelWrapper.
     * In the current implementation only AUTO encoding is supported by the Wrapper; users are required to set
     * this flag to true if they want to use POJOs/MOJOs with encodings other than AUTO.
     *
     * This requirement will be removed in https://github.com/h2oai/h2o-3/issues/8707
     * @param val if true, the user needs to provide already encoded input in the RowData structure
     * @return self
     */
    public Config setUseExternalEncoding(boolean val) {
      useExternalEncoding = val;
      return this;
    }

    public boolean getUseExternalEncoding() { return useExternalEncoding; }

    /**
     * @return Setting for unknown categorical levels handling
     */
    public boolean getConvertUnknownCategoricalLevelsToNa() { return convertUnknownCategoricalLevelsToNa; }

    public int getGLRMIterNumber() { return glrmIterNumber; }

    /**
     * Specify the default action when a string value cannot be converted to
     * a number.
     *
     * @param value if true, then an N/A value will be produced; if false, an
     *              exception will be thrown.
     */
    public Config setConvertInvalidNumbersToNa(boolean value) {
      convertInvalidNumbersToNa = value;
      return this;
    }

    public boolean getConvertInvalidNumbersToNa() { return convertInvalidNumbersToNa; }

    /**
     * Specify whether to include additional metadata in the prediction output.
     * This feature needs to be supported by a particular model, and the type of metadata
     * is model specific.
     *
     * @param value if true, then the Prediction result will contain extended information
     *              about the prediction (this will be specific to a particular model).
     * @return this config object
     */
    public Config setUseExtendedOutput(boolean value) {
      useExtendedOutput = value;
      return this;
    }

    public boolean getUseExtendedOutput() { return useExtendedOutput; }

    public boolean getEnableLeafAssignment() { return enableLeafAssignment; }

    public boolean getEnableStagedProbabilities() { return enableStagedProbabilities; }

    /**
     * @return An instance of ErrorConsumer used to build the {@link EasyPredictModelWrapper}. Null if there is no instance.
     */
    public ErrorConsumer getErrorConsumer() { return errorConsumer; }

    /**
     * Specify an instance of {@link ErrorConsumer} the {@link EasyPredictModelWrapper} is going to call
     * whenever an error defined by the {@link ErrorConsumer} instance occurs.
     *
     * @param errorConsumer An instance of {@link ErrorConsumer}
     * @return This {@link Config} object
     */
    public Config setErrorConsumer(final ErrorConsumer errorConsumer) {
      this.errorConsumer = errorConsumer;
      return this;
    }
  }

  /**
   * Create a wrapper for a generated model.
   *
   * @param config The wrapper configuration
   */
  public EasyPredictModelWrapper(Config config) {
    m = config.getModel();
    // Ensure an error consumer is always instantiated to avoid missing null-check errors.
    ErrorConsumer errorConsumer = config.getErrorConsumer() == null ? new VoidErrorConsumer() : config.getErrorConsumer();

    // How to handle unknown categorical levels.
    useExtendedOutput = config.getUseExtendedOutput();
    enableLeafAssignment = config.getEnableLeafAssignment();
    enableGLRMReconstruct = config.getEnableGLRMReconstrut();
    enableStagedProbabilities = config.getEnableStagedProbabilities();
    enableContributions = config.getEnableContributions();
    glrmIterNumber = config.getGLRMIterNumber();
    if (m instanceof GlrmMojoModel)
      ((GlrmMojoModel) m)._iterNumber = glrmIterNumber;
    if (enableContributions) {
      if (!(m instanceof PredictContributionsFactory)) {
        throw new IllegalStateException("Model " + m.getClass().getName() + " cannot be used to predict contributions.");
      }
      predictContributions = ((PredictContributionsFactory) m).makeContributionsPredictor();
    } else {
      predictContributions = null;
    }

    CategoricalEncoding categoricalEncoding = config.getUseExternalEncoding() ?
            CategoricalEncoding.AUTO : m.getCategoricalEncoding();
    Map<String, Integer> columnMapping = categoricalEncoding.createColumnMapping(m);
    Map<Integer, CategoricalEncoder> domainMap = categoricalEncoding.createCategoricalEncoders(m, columnMapping);

    if (m instanceof ConverterFactoryProvidingModel) {
      rowDataConverter = ((ConverterFactoryProvidingModel) m).makeConverterFactory(columnMapping, domainMap, errorConsumer, config);
    } else {
      rowDataConverter = new RowToRawDataConverter(m, columnMapping, domainMap, errorConsumer, config);
    }
  }

  /**
   * Create a wrapper for a generated model.
   *
   * @param model The generated model
   */
  public EasyPredictModelWrapper(GenModel model) {
    this(new Config().setModel(model));
  }

  /**
   * Make a prediction on a new data point.
   *
   * The type of prediction returned depends on the model type.
   * The caller needs to know what type of prediction to expect.
   *
   * This call is convenient for generically automating model deployment.
   * For specific applications (where the kind of model is known and doesn't change), it is recommended to call
   * the specific prediction method, such as predictBinomial(), directly.
   *
   * @param data A new data point. Unknown or missing column name is treated as a NaN or ignored. Column names are case sensitive.
   * @return The prediction.
   * @throws PredictException
   */
  public AbstractPrediction predict(RowData data, ModelCategory mc) throws PredictException {
    switch (mc) {
      case AutoEncoder:
        return predictAutoEncoder(data);
      case Binomial:
        return predictBinomial(data);
      case Multinomial:
        return predictMultinomial(data);
      case Ordinal:
        return predictOrdinal(data);
      case Clustering:
        return predictClustering(data);
      case Regression:
        return predictRegression(data);
      case DimReduction:
        return predictDimReduction(data);
      case WordEmbedding:
        return predictWord2Vec(data);
      case TargetEncoder:
        return predictTargetEncoding(data);
      case AnomalyDetection:
        return predictAnomalyDetection(data);
      case KLime:
        return predictKLime(data);
      case CoxPH:
        return predictCoxPH(data);
      case BinomialUplift:
        return predictUpliftBinomial(data);
      case Unknown:
        throw new PredictException("Unknown model category");
      default:
        throw new PredictException("Unhandled model category (" + m.getModelCategory() + ") in switch statement");
    }
  }

  /**
   * Make a prediction on a new data point.
   *
   * This method has the same input as predict. The only difference is that
   * it returns an array instead of a prediction object.
   *
   * The meaning of the returned values can be decoded by calling getOutputNames,
   * and if any returned values are categorical, the method getOutputDomain can be
   * used to find the mapping of indexes to categorical values for the particular column.
   *
   * @param data A new data point. Column names are case-sensitive.
   * @param offset Value of offset (use 0 if the model was trained without offset).
   * @return An array representing a prediction.
   * @throws PredictException if prediction cannot be made (e.g.: input is invalid)
   */
  public double[] predictRaw(RowData data, double offset) throws PredictException {
    return preamble(m.getModelCategory(), data, offset);
  }

  /**
   * See {@link #predict(RowData, ModelCategory)}
   */
  public AbstractPrediction predict(RowData data) throws PredictException {
    return predict(data, m.getModelCategory());
  }

  ErrorConsumer getErrorConsumer() {
    return rowDataConverter.getErrorConsumer();
  }

  /**
   * Returns names of contributions for prediction results with contributions enabled.
   * @return array of contribution names (the array has the same length as the actual contributions, the last element is the BiasTerm)
   */
  public String[] getContributionNames() {
    if (predictContributions == null) {
      throw new IllegalStateException(
              "Contributions were not enabled in EasyPredictModelWrapper (use setEnableContributions).");
    }
    return predictContributions.getContributionNames();
  }

  /**
   * Make a prediction on a new data point using an AutoEncoder model.
   * @param data A new data point. Unknown or missing column name is treated as a NaN or ignored. Column names are case sensitive.
   * @return The prediction.
   * @throws PredictException
   */
  public AutoEncoderModelPrediction predictAutoEncoder(RowData data) throws PredictException {
    validateModelCategory(ModelCategory.AutoEncoder);

    int size = m.getPredsSize(ModelCategory.AutoEncoder);
    double[] output = new double[size];
    double[] rawData = nanArray(m.nfeatures());
    rawData = fillRawData(data, rawData);
    output = m.score0(rawData, output);

    AutoEncoderModelPrediction p = new AutoEncoderModelPrediction();
    p.original = expandRawData(rawData, output.length);
    p.reconstructed = output;
    p.reconstructedRowData = reconstructedToRowData(output);
    if (m instanceof DeeplearningMojoModel) {
      DeeplearningMojoModel mojoModel = ((DeeplearningMojoModel) m);
      p.mse = mojoModel.calculateReconstructionErrorPerRowData(p.original, p.reconstructed);
    }

    return p;
  }

  /**
   * Creates a 1-hot encoded representation of the input data.
   * @param data raw input as seen by the score0 function
   * @param size target size of the output array
   * @return 1-hot encoded data
   */
  private double[] expandRawData(double[] data, int size) {
    double[] expanded = new double[size];
    int pos = 0;
    for (int i = 0; i < data.length; i++) {
      if (m._domains[i] == null) {
        expanded[pos] = data[i];
        pos++;
      } else {
        int idx = Double.isNaN(data[i]) ? m._domains[i].length : (int) data[i];
        expanded[pos + idx] = 1.0;
        pos += m._domains[i].length + 1;
      }
    }
    return expanded;
  }

  /**
   * Converts the output of an AutoEncoder to a RowData structure. Categorical fields are represented by
   * a map of domain values -> reconstructed values; a missing domain value is represented by a 'null' key.
   * @param reconstructed raw output of the AutoEncoder
   * @return reconstructed RowData structure
   */
  private RowData reconstructedToRowData(double[] reconstructed) {
    RowData rd = new RowData();
    int pos = 0;
    for (int i = 0; i < m.nfeatures(); i++) {
      Object value;
      if (m._domains[i] == null) {
        value = reconstructed[pos++];
      } else {
        value = catValuesAsMap(m._domains[i], reconstructed, pos);
        pos += m._domains[i].length + 1;
      }
      rd.put(m._names[i], value);
    }
    return rd;
  }

  private static Map<String, Double> catValuesAsMap(String[] cats, double[] reconstructed, int offset) {
    Map<String, Double> result = new HashMap<>(cats.length + 1);
    for (int i = 0; i < cats.length; i++) {
      result.put(cats[i], reconstructed[i + offset]);
    }
    result.put(null, reconstructed[offset + cats.length]);
    return result;
  }

  /**
   * Make a prediction on a new data point using a Dimension Reduction model (PCA, GLRM).
   * @param data A new data point. Unknown column name is treated as a NaN. Column names are case sensitive.
   * @return The prediction.
   * @throws PredictException
   */
  public DimReductionModelPrediction predictDimReduction(RowData data) throws PredictException {
    double[] preds = preamble(ModelCategory.DimReduction, data);  // preds contains the x factor

    DimReductionModelPrediction p = new DimReductionModelPrediction();
    p.dimensions = preds;
    if (m instanceof GlrmMojoModel && ((GlrmMojoModel) m)._archetypes_raw != null && this.enableGLRMReconstruct)  // only for version 1.10 or higher
      p.reconstructed = ((GlrmMojoModel) m).impute_data(preds, new double[m.nfeatures()], ((GlrmMojoModel) m)._nnums,
              ((GlrmMojoModel) m)._ncats, ((GlrmMojoModel) m)._permutation, ((GlrmMojoModel) m)._reverse_transform,
              ((GlrmMojoModel) m)._normMul, ((GlrmMojoModel) m)._normSub, ((GlrmMojoModel) m)._losses,
              ((GlrmMojoModel) m)._transposed, ((GlrmMojoModel) m)._archetypes_raw, ((GlrmMojoModel) m)._catOffsets,
              ((GlrmMojoModel) m)._numLevels);
    return p;
  }

  /**
   * Calculate an aggregated word-embedding for a given input sentence (sequence of words).
   *
   * @param sentence array of words forming a sentence
   * @return word-embedding for the given sentence calculated by averaging the embeddings of the input words
   * @throws PredictException if the model is not a WordEmbedding model
   */
  public float[] predictWord2Vec(String[] sentence) throws PredictException {
    final WordEmbeddingModel weModel = asWordEmbeddingModel();
    final int vecSize = weModel.getVecSize();

    final float[] aggregated = new float[vecSize];
    final float[] current = new float[vecSize];
    int embeddings = 0;
    for (String word : sentence) {
      final float[] embedding = weModel.transform0(word, current);
      if (embedding == null)
        continue;
      embeddings++;
      for (int i = 0; i < vecSize; i++)
        aggregated[i] += embedding[i];
    }
    if (embeddings > 0) {
      for (int i = 0; i < vecSize; i++) {
        aggregated[i] /= (float) embeddings;
      }
    } else {
      Arrays.fill(aggregated, Float.NaN);
    }
    return aggregated;
  }

  /**
   * Lookup word embeddings for a given word (or set of words). The result is a dictionary of
   * words mapped to their respective embeddings.
   *
   * @param data RowData structure, every key with a String value will be translated to an embedding;
   *             note: the keys' only purpose is to link the output embedding to the input word.
   * @return The prediction
   * @throws PredictException if the model is not a WordEmbedding model
   */
  public Word2VecPrediction predictWord2Vec(RowData data) throws PredictException {
    final WordEmbeddingModel weModel = asWordEmbeddingModel();
    final int vecSize = weModel.getVecSize();

    HashMap<String, float[]> embeddings = new HashMap<>(data.size());
    for (String wordKey : data.keySet()) {
      Object value = data.get(wordKey);
      if (value instanceof String) {
        String word = (String) value;
        embeddings.put(wordKey, weModel.transform0(word, new float[vecSize]));
      }
    }

    Word2VecPrediction p = new Word2VecPrediction();
    p.wordEmbeddings = embeddings;

    return p;
  }

  private WordEmbeddingModel asWordEmbeddingModel() throws PredictException {
    validateModelCategory(ModelCategory.WordEmbedding);

    if (!(m instanceof WordEmbeddingModel))
      throw new PredictException("Model is not of the expected type, class = " + m.getClass().getSimpleName());
    return (WordEmbeddingModel) m;
  }

  /**
   * Make a prediction on a new data point using an Anomaly Detection model.
   *
   * @param data A new data point. Unknown column name is treated as a NaN. Column names are case sensitive.
   * @return The prediction.
   * @throws PredictException
   */
  public AnomalyDetectionPrediction predictAnomalyDetection(RowData data) throws PredictException {
    double[] preds = preamble(ModelCategory.AnomalyDetection, data, 0.0);

    AnomalyDetectionPrediction p = new AnomalyDetectionPrediction(preds);
    if (enableLeafAssignment) { // only get leaf node assignment if enabled
      SharedTreeMojoModel.LeafNodeAssignments assignments = leafNodeAssignmentExtended(data);
      p.leafNodeAssignments = assignments._paths;
      p.leafNodeAssignmentIds = assignments._nodeIds;
    }
    if (enableStagedProbabilities) {
      double[] rawData = nanArray(m.nfeatures());
      rawData = fillRawData(data, rawData);
      p.stageProbabilities = ((SharedTreeMojoModel) m).scoreStagedPredictions(rawData, preds.length);
    }
    return p;
  }

  /**
   * Make a prediction on a new data point using a Binomial model.
   *
   * @param data A new data point. Unknown or missing column name is treated as a NaN or ignored. Column names are case sensitive.
   * @return The prediction.
   * @throws PredictException
   */
  public BinomialModelPrediction predictBinomial(RowData data) throws PredictException {
    return predictBinomial(data, 0.0);
  }

  /**
   * Make a prediction on a new data point using a Binomial model.
   *
   * @param data A new data point. Unknown or missing column name is treated as a NaN or ignored. Column names are case sensitive.
   * @param offset An offset for the prediction.
   * @return The prediction.
   * @throws PredictException
   */
  public BinomialModelPrediction predictBinomial(RowData data, double offset) throws PredictException {
    double[] preds = preamble(ModelCategory.Binomial, data, offset);

    BinomialModelPrediction p = new BinomialModelPrediction();
    if (enableLeafAssignment) { // only get leaf node assignment if enabled
      SharedTreeMojoModel.LeafNodeAssignments assignments = leafNodeAssignmentExtended(data);
      p.leafNodeAssignments = assignments._paths;
      p.leafNodeAssignmentIds = assignments._nodeIds;
    }
    double d = preds[0];
    p.labelIndex = (int) d;
    String[] domainValues = m.getDomainValues(m.getResponseIdx());
    if (domainValues == null && m.getNumResponseClasses() == 2)
      domainValues = new String[]{"0", "1"}; // quasibinomial
    p.label = domainValues[p.labelIndex];
    p.classProbabilities = new double[m.getNumResponseClasses()];
    System.arraycopy(preds, 1, p.classProbabilities, 0, p.classProbabilities.length);
    if (m.calibrateClassProbabilities(preds)) {
      p.calibratedClassProbabilities = new double[m.getNumResponseClasses()];
      System.arraycopy(preds, 1, p.calibratedClassProbabilities, 0, p.calibratedClassProbabilities.length);
    }
    if (enableStagedProbabilities) {
      double[] rawData = nanArray(m.nfeatures());
      rawData = fillRawData(data, rawData);
      p.stageProbabilities = ((SharedTreeMojoModel) m).scoreStagedPredictions(rawData, preds.length);
    }
    if (enableContributions) {
      p.contributions = predictContributions(data);
    }
    return p;
  }

  /**
   * Make a prediction on a new data point using an Uplift Binomial model.
   * @param data A new data point. Unknown or missing column name is treated as a NaN or ignored. Column names are case sensitive.
   * @return The prediction.
   * @throws PredictException
   */
  public UpliftBinomialModelPrediction predictUpliftBinomial(RowData data) throws PredictException {
    double[] preds = preamble(ModelCategory.BinomialUplift, data, 0);
    UpliftBinomialModelPrediction p = new UpliftBinomialModelPrediction();
    p.predictions = preds;
    return p;
  }

  /**
   * @deprecated Use {@link #predictTargetEncoding(RowData)} instead.
   */
  @Deprecated
  public TargetEncoderPrediction transformWithTargetEncoding(RowData data) throws PredictException {
    return predictTargetEncoding(data);
  }

  /**
   * Perform target encoding based on a TargetEncoderMojoModel.
   * @param data RowData structure with data for which we want to produce transformations.
   *             Unknown column name is treated as a NaN. Column names are case sensitive.
   * @return TargetEncoderPrediction with transformations ordered in accordance with the corresponding categorical columns' indices in the training data
   * @throws PredictException
   */
  public TargetEncoderPrediction predictTargetEncoding(RowData data) throws PredictException {
    if (!(m instanceof TargetEncoderMojoModel))
      throw new PredictException("Model is not of the expected type, class = " + m.getClass().getSimpleName());
    TargetEncoderMojoModel tem = (TargetEncoderMojoModel) this.m;
    double[] preds = new double[tem.getPredsSize()];
    TargetEncoderPrediction prediction = new TargetEncoderPrediction();
    prediction.transformations = predict(data, 0, preds);
    return prediction;
  }

  @SuppressWarnings("unused") // not used in this class directly, kept for backwards compatibility
  public String[] leafNodeAssignment(RowData data) throws PredictException {
    double[] rawData = nanArray(m.nfeatures());
    rawData = fillRawData(data, rawData);
    return ((TreeBackedMojoModel) m).getDecisionPath(rawData);
  }

  public SharedTreeMojoModel.LeafNodeAssignments leafNodeAssignmentExtended(RowData data) throws PredictException {
    double[] rawData = nanArray(m.nfeatures());
    rawData = fillRawData(data, rawData);
    return ((TreeBackedMojoModel) m).getLeafNodeAssignments(rawData);
  }

  /**
   * Make a prediction on a new data point using a Multinomial model.
   *
   * @param data A new data point. Unknown or missing column name is treated as a NaN or ignored. Column names are case sensitive.
   * @return The prediction.
   * @throws PredictException
   */
  public MultinomialModelPrediction predictMultinomial(RowData data) throws PredictException {
    return predictMultinomial(data, 0D);
  }

  /**
   * Make a prediction on a new data point using a Multinomial model.
   *
   * @param data A new data point. Unknown or missing column name is treated as a NaN or ignored. Column names are case sensitive.
   * @param offset Prediction offset
   * @return The prediction.
   * @throws PredictException
   */
  public MultinomialModelPrediction predictMultinomial(RowData data, double offset) throws PredictException {
    double[] preds = preamble(ModelCategory.Multinomial, data, offset);

    MultinomialModelPrediction p = new MultinomialModelPrediction();
    if (enableLeafAssignment) { // only get leaf node assignment if enabled
      SharedTreeMojoModel.LeafNodeAssignments assignments = leafNodeAssignmentExtended(data);
      p.leafNodeAssignments = assignments._paths;
      p.leafNodeAssignmentIds = assignments._nodeIds;
    }
    p.classProbabilities = new double[m.getNumResponseClasses()];
    p.labelIndex = (int) preds[0];
    String[] domainValues = m.getDomainValues(m.getResponseIdx());
    p.label = domainValues[p.labelIndex];
    System.arraycopy(preds, 1, p.classProbabilities, 0, p.classProbabilities.length);
    if (enableStagedProbabilities) {
      double[] rawData = nanArray(m.nfeatures());
      rawData = fillRawData(data, rawData);
      p.stageProbabilities = ((SharedTreeMojoModel) m).scoreStagedPredictions(rawData, preds.length);
    }
    return p;
  }

  /**
   * Make a prediction on a new data point using an Ordinal model.
   *
   * @param data A new data point. Unknown or missing column name is treated as a NaN or ignored. Column names are case sensitive.
   * @return The prediction.
   * @throws PredictException
   */
  public OrdinalModelPrediction predictOrdinal(RowData data) throws PredictException {
    return predictOrdinal(data, 0D);
  }

  /**
   * Make a prediction on a new data point using an Ordinal model.
   *
   * @param data A new data point. Unknown or missing column name is treated as a NaN or ignored. Column names are case sensitive.
   * @param offset Prediction offset
   * @return The prediction.
   * @throws PredictException
   */
  public OrdinalModelPrediction predictOrdinal(RowData data, double offset) throws PredictException {
    double[] preds = preamble(ModelCategory.Ordinal, data, offset);

    OrdinalModelPrediction p = new OrdinalModelPrediction();
    p.classProbabilities = new double[m.getNumResponseClasses()];
    p.labelIndex = (int) preds[0];
    String[] domainValues = m.getDomainValues(m.getResponseIdx());
    p.label = domainValues[p.labelIndex];
    System.arraycopy(preds, 1, p.classProbabilities, 0, p.classProbabilities.length);

    return p;
  }

  /**
   * Sort in descending order.
   */
  private SortedClassProbability[] sortByDescendingClassProbability(String[] domainValues, double[] classProbabilities) {
    assert (classProbabilities.length == domainValues.length);
    SortedClassProbability[] arr = new SortedClassProbability[domainValues.length];
    for (int i = 0; i < domainValues.length; i++) {
      arr[i] = new SortedClassProbability();
      arr[i].name = domainValues[i];
      arr[i].probability = classProbabilities[i];
    }
    Arrays.sort(arr, Collections.reverseOrder());
    return arr;
  }

  /**
   * A helper function to return an array of binomial class probabilities for a prediction in sorted order.
   * The returned array has the most probable class in position 0.
   *
   * @param p The prediction.
   * @return An array with sorted class probabilities.
   */
  public SortedClassProbability[] sortByDescendingClassProbability(BinomialModelPrediction p) {
    String[] domainValues = m.getDomainValues(m.getResponseIdx());
    double[] classProbabilities = p.classProbabilities;
    return sortByDescendingClassProbability(domainValues, classProbabilities);
  }

  /**
   * Make a prediction on a new data point using a Clustering model.
   *
   * @param data A new data point. Unknown or missing column name is treated as a NaN or ignored. Column names are case sensitive.
   * @return The prediction.
   * @throws PredictException
   */
  public ClusteringModelPrediction predictClustering(RowData data) throws PredictException {
    ClusteringModelPrediction p = new ClusteringModelPrediction();
    if (useExtendedOutput && (m instanceof IClusteringModel)) {
      IClusteringModel cm = (IClusteringModel) m;
      // setup raw input
      double[] rawData = nanArray(m.nfeatures());
      rawData = fillRawData(data, rawData);
      // get cluster assignment & distances
      final int k = cm.getNumClusters();
      p.distances = new double[k];
      p.cluster = cm.distances(rawData, p.distances);
    } else {
      double[] preds = preamble(ModelCategory.Clustering, data);
      p.cluster = (int) preds[0];
    }
    return p;
  }

  /**
   * Make a prediction on a new data point using a Regression model.
   *
   * @param data A new data point. Unknown or missing column name is treated as a NaN or ignored. Column names are case sensitive.
   * @return The prediction.
   * @throws PredictException
   */
  public RegressionModelPrediction predictRegression(RowData data) throws PredictException {
    return predictRegression(data, 0D);
  }

  /**
   * Make a prediction on a new data point using a Regression model.
   *
   * @param data A new data point. Unknown or missing column name is treated as a NaN or ignored. Column names are case sensitive.
   * @param offset Prediction offset
   * @return The prediction.
   * @throws PredictException
   */
  public RegressionModelPrediction predictRegression(RowData data, double offset) throws PredictException {
    double[] preds = preamble(ModelCategory.Regression, data, offset);

    RegressionModelPrediction p = new RegressionModelPrediction();
    if (enableLeafAssignment) { // only get leaf node assignment if enabled
      SharedTreeMojoModel.LeafNodeAssignments assignments = leafNodeAssignmentExtended(data);
      p.leafNodeAssignments = assignments._paths;
      p.leafNodeAssignmentIds = assignments._nodeIds;
    }
    p.value = preds[0];
    if (enableStagedProbabilities) {
      double[] rawData = nanArray(m.nfeatures());
      rawData = fillRawData(data, rawData);
      p.stageProbabilities = ((SharedTreeMojoModel) m).scoreStagedPredictions(rawData, preds.length);
    }
    if (enableContributions) {
      double[] rawData = nanArray(m.nfeatures());
      rawData = fillRawData(data, rawData);
      p.contributions = predictContributions.calculateContributions(rawData);
    }
    return p;
  }

  /**
   * Make a prediction on a new data point using a k-LIME model.
   *
   * @param data A new data point. Unknown or missing column name is treated as a NaN or ignored. Column names are case sensitive.
   * @return The prediction.
   * @throws PredictException
   */
  public KLimeModelPrediction predictKLime(RowData data) throws PredictException {
    double[] preds = preamble(ModelCategory.KLime, data);

    KLimeModelPrediction p = new KLimeModelPrediction();
    p.value = preds[0];
    p.cluster = (int) preds[1];
    p.reasonCodes = new double[preds.length - 2];
    System.arraycopy(preds, 2, p.reasonCodes, 0, p.reasonCodes.length);

    return p;
  }

  public CoxPHModelPrediction predictCoxPH(RowData data, double offset) throws PredictException {
    final double[] preds = preamble(ModelCategory.CoxPH, data, offset);
    CoxPHModelPrediction p = new CoxPHModelPrediction();
    p.value = preds[0];

    return p;
  }

  public CoxPHModelPrediction predictCoxPH(RowData data) throws PredictException {
    return predictCoxPH(data, 0);
  }

  public float[] predictContributions(RowData data) throws PredictException {
    double[] rawData = nanArray(m.nfeatures());
    rawData = fillRawData(data, rawData);
    return predictContributions.calculateContributions(rawData);
  }

  /**
   * Calculate and sort Shapley values.
   *
   * @param data A new data point. Unknown or missing column name is treated as a NaN or ignored. Column names are case sensitive.
   * @param topN Return only the #topN highest contributions + bias.
   *             If topN&lt;0 then sort all SHAP values in descending order.
   *             If topN&lt;0 &amp;&amp; bottomN&lt;0 then sort all SHAP values in descending order.
   * @param bottomN Return only the #bottomN lowest contributions + bias.
   *                If topN and bottomN are defined together then return an array of #topN + #bottomN + bias.
   *                If bottomN&lt;0 then sort all SHAP values in ascending order.
   *                If topN&lt;0 &amp;&amp; bottomN&lt;0 then sort all SHAP values in descending order.
   * @param compareAbs True to compare absolute values of contributions.
   * @return Sorted FeatureContribution array of contributions of size #topN + #bottomN + bias.
   *         If topN &lt; 0 || bottomN &lt; 0 then all descending/ascending sorted contributions are returned.
   * @throws PredictException When #data cannot be properly translated to raw data.
   */
  public FeatureContribution[] predictContributions(RowData data, int topN, int bottomN, boolean compareAbs) throws PredictException {
    double[] rawData = nanArray(m.nfeatures());
    rawData = fillRawData(data, rawData);
    return predictContributions.calculateContributions(rawData, topN, bottomN, compareAbs);
  }

  /**
   * See {@link #varimp(int)}
   * @return array of all variables in the model, sorted descending by relative importance
   */
  public KeyValue[] varimp() {
    return varimp(-1);
  }

  /**
   * See {@link VariableImportances#topN(int)}
   */
  public KeyValue[] varimp(int n) {
    if (m instanceof MojoModel) {
      ModelAttributes attributes = ((MojoModel) m)._modelAttributes;
      if (attributes == null) {
        throw new IllegalStateException("Model attributes are not available. Did you load metadata from model? MojoModel.load(\"model\", true)");
      } else if (attributes instanceof VariableImportancesHolder) {
        return ((VariableImportancesHolder) attributes).getVariableImportances().topN(n);
      }
    }
    throw new IllegalStateException("Model does not support variable importance");
  }

  //----------------------------------------------------------------------
  // Transparent methods passed through to GenModel.
  //----------------------------------------------------------------------

  public GenModel getModel() {
    return m;
  }

  /**
   * Get the category (type) of model.
   * @return The category.
   */
  public ModelCategory getModelCategory() {
    return m.getModelCategory();
  }

  /**
   * Get the array of levels for the response column.
   * "Domain" just means the list of level names for a categorical (aka factor, enum) column.
   * If the response column is numerical and not categorical, this will return null.
   *
   * @return The array.
   */
  public String[] getResponseDomainValues() {
    return m.getDomainValues(m.getResponseIdx());
  }

  /**
   * Some autoencoder thing, I'm not sure what this does.
   * @return CSV header for autoencoder.
   */
  public String getHeader() {
    return m.getHeader();
  }

  //----------------------------------------------------------------------
  // Private methods below this line.
  //----------------------------------------------------------------------

  private void validateModelCategory(ModelCategory c) throws PredictException {
    if (!m.getModelCategories().contains(c))
      throw new PredictException(c + " prediction type is not supported for this model.");
  }

  // This should have been called predict(), because that's what it does
  protected double[] preamble(ModelCategory c, RowData data) throws PredictException {
    return preamble(c, data, 0.0);
  }

  protected double[] preamble(ModelCategory c, RowData data, double offset) throws PredictException {
    validateModelCategory(c);
    final int predsSize = m.getPredsSize(c);
    return predict(data, offset, new double[predsSize]);
  }

  protected double[] fillRawData(RowData data, double[] rawData) throws PredictException {
    return rowDataConverter.convert(data, rawData);
  }

  protected double[] predict(RowData data, double offset, double[] preds) throws PredictException {
    double[] rawData = nanArray(m.nfeatures());
    rawData = fillRawData(data, rawData);
    if (m.requiresOffset() || offset != 0) {
      preds = m.score0(rawData, offset, preds);
    } else {
      preds = m.score0(rawData, preds);
    }
    return preds;
  }
}



