// Please wait. This can take a few minutes ...
// Downloading a project requires considerable server resources. Please understand that we have to cover our server costs. Thank you in advance.
// Project price: only $1
// You can buy this project and download/modify it as often as you want.
// org.elasticsearch.xpack.core.ml.inference.trainedmodel.inference.TreeInferenceModel Maven / Gradle / Ivy
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.core.ml.inference.trainedmodel.inference;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.logging.log4j.message.ParameterizedMessage;
import org.apache.lucene.util.Accountable;
import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.Numbers;
import org.elasticsearch.common.collect.Tuple;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.ObjectParser;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.RawInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.TopClassEntry;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceHelpers;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NullInferenceConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ShapPath;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType;
import org.elasticsearch.xpack.core.ml.inference.utils.Statistics;
import org.elasticsearch.xpack.core.ml.job.config.Operator;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import static org.apache.lucene.util.RamUsageEstimator.shallowSizeOfInstance;
import static org.apache.lucene.util.RamUsageEstimator.sizeOf;
import static org.apache.lucene.util.RamUsageEstimator.sizeOfCollection;
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg;
import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceHelpers.classificationLabel;
import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceHelpers.decodeFeatureImportances;
import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.Tree.CLASSIFICATION_LABELS;
import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.Tree.FEATURE_NAMES;
import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.Tree.TARGET_TYPE;
import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.Tree.TREE_STRUCTURE;
import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.TreeNode.DECISION_TYPE;
import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.TreeNode.DEFAULT_LEFT;
import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.TreeNode.LEAF_VALUE;
import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.TreeNode.LEFT_CHILD;
import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.TreeNode.NUMBER_SAMPLES;
import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.TreeNode.RIGHT_CHILD;
import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.TreeNode.SPLIT_FEATURE;
import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.TreeNode.THRESHOLD;
public class TreeInferenceModel implements InferenceModel {
// Class-scoped log4j logger; used by rewriteFeatureIndices for its debug message.
private static final Logger LOGGER = LogManager.getLogger(TreeInferenceModel.class);
// Shallow size of one TreeInferenceModel instance as computed by Lucene's RamUsageEstimator;
// this is the base term that ramBytesUsed() adds the referenced structures to.
public static final long SHALLOW_SIZE = shallowSizeOfInstance(TreeInferenceModel.class);
/**
 * Parser that builds a {@link TreeInferenceModel} from its XContent form.
 *
 * The type arguments are spelled out so that {@code PARSER.apply(...)} yields a
 * {@code TreeInferenceModel} rather than {@code Object}; the context type is {@code Void}
 * because {@link #fromXContent} always passes a {@code null} context.
 * The {@code true} constructor argument is the parser's lenient/ignore-unknown-fields switch.
 *
 * The casts on {@code a[...]} are unchecked by nature: the parser hands constructor arguments
 * back as an {@code Object[]} in declaration order (feature names, tree structure, target type,
 * classification labels).
 */
@SuppressWarnings("unchecked")
private static final ConstructingObjectParser<TreeInferenceModel, Void> PARSER = new ConstructingObjectParser<>(
    "tree_inference_model",
    true,
    a -> new TreeInferenceModel(
        (List<String>) a[0],
        (List<NodeBuilder>) a[1],
        // target_type is optional; null lets the constructor pick its default
        a[2] == null ? null : TargetType.fromString((String) a[2]),
        (List<String>) a[3]));

static {
    // Declaration order must match the a[0]..a[3] indices used in the builder lambda above.
    PARSER.declareStringArray(constructorArg(), FEATURE_NAMES);
    PARSER.declareObjectArray(constructorArg(), NodeBuilder.PARSER::apply, TREE_STRUCTURE);
    PARSER.declareString(optionalConstructorArg(), TARGET_TYPE);
    PARSER.declareStringArray(optionalConstructorArg(), CLASSIFICATION_LABELS);
}
/**
 * Parses a tree inference model from the given parser using {@code PARSER}.
 * No parser context is used ({@code null} is passed).
 *
 * @param parser the XContent parser positioned at the model definition
 * @return the parsed model
 */
public static TreeInferenceModel fromXContent(XContentParser parser) {
return PARSER.apply(parser, null);
}
// Flattened tree; child links inside InnerNode are indices into this array and index 0 is the root.
private final Node[] nodes;
// Not final: rewriteFeatureIndices replaces it with an empty array once split features are remapped.
private String[] featureNames;
// Never null; the constructor falls back to REGRESSION when no target type is supplied.
private final TargetType targetType;
// Nullable. Not final: rewriteFeatureIndices clears it when this tree is embedded in a parent model.
private List<String> classificationLabels;
// Result of maxLeafValue(): NaN unless this is a classification tree.
private final double highOrderCategory;
// Depth of the tree as computed by getDepth(); sizes the SHAP path buffers in featureImportance().
private final int maxDepth;
// Length of the first leaf's value array (1 if the node array holds no leaf).
private final int leafSize;
// Set by rewriteFeatureIndices; innerInfer refuses to run while this is false.
private volatile boolean preparedForInference = false;
/**
 * Builds the inference representation of a tree.
 *
 * @param featureNames         names of the input features, required
 * @param nodes                node builders in tree order, required and non-empty; index 0 is the root
 * @param targetType           target type; {@code null} means {@link TargetType#REGRESSION}
 * @param classificationLabels optional labels; wrapped in an unmodifiable view when present
 * @throws IllegalArgumentException if {@code nodes} is empty
 */
TreeInferenceModel(List<String> featureNames,
                   List<NodeBuilder> nodes,
                   @Nullable TargetType targetType,
                   @Nullable List<String> classificationLabels) {
    this.featureNames = ExceptionsHelper.requireNonNull(featureNames, FEATURE_NAMES).toArray(new String[0]);
    if (ExceptionsHelper.requireNonNull(nodes, TREE_STRUCTURE).isEmpty()) {
        throw new IllegalArgumentException("[tree_structure] must not be empty");
    }
    this.nodes = nodes.stream().map(NodeBuilder::build).toArray(Node[]::new);
    this.targetType = targetType == null ? TargetType.REGRESSION : targetType;
    this.classificationLabels = classificationLabels == null ? null : Collections.unmodifiableList(classificationLabels);
    // Must run after targetType and nodes are assigned: maxLeafValue() reads both.
    this.highOrderCategory = maxLeafValue();
    // All leaves are taken to share one value length, so the first leaf found decides it.
    int leafSize = 1;
    for (Node node : this.nodes) {
        if (node instanceof LeafNode) {
            leafSize = ((LeafNode) node).leafValue.length;
            break;
        }
    }
    this.leafSize = leafSize;
    this.maxDepth = getDepth(this.nodes, 0, 0);
}
/**
 * Returns the feature names this tree currently holds. The internal array is returned
 * directly (no copy). After rewriteFeatureIndices has applied a non-empty mapping this
 * is an empty array.
 */
@Override
public String[] getFeatureNames() {
return featureNames;
}
/**
 * Returns the target type of this tree. Never null: the constructor substitutes
 * REGRESSION when none was supplied.
 */
@Override
public TargetType targetType() {
return targetType;
}
/**
 * Runs inference on a document given as a field map.
 * The feature vector is pulled out of {@code fields} in the order of this tree's feature names
 * and then handed to the shared inference path.
 */
@Override
public InferenceResults infer(Map fields, InferenceConfig config, Map featureDecoderMap) {
    final double[] extracted = InferenceModel.extractFeatures(featureNames, fields);
    return innerInfer(extracted, config, featureDecoderMap);
}
/**
 * Runs inference on an already extracted feature vector.
 * Delegates to innerInfer with an empty feature decoder map.
 */
@Override
public InferenceResults infer(double[] features, InferenceConfig config) {
return innerInfer(features, config, Collections.emptyMap());
}
/**
 * Shared inference path for both public infer overloads.
 *
 * Validates, in this order, that the config supports this tree's target type (bad request
 * otherwise) and that the model has been prepared via rewriteFeatureIndices (server error
 * otherwise). Feature importance is computed only when the config asks for it.
 */
private InferenceResults innerInfer(double[] features, InferenceConfig config, Map featureDecoderMap) {
    final boolean targetSupported = config.isTargetTypeSupported(targetType);
    if (targetSupported == false) {
        throw ExceptionsHelper.badRequestException(
            "Cannot infer using configuration for [{}] when model target_type is [{}]", config.getName(), targetType.toString());
    }
    if (preparedForInference == false) {
        throw ExceptionsHelper.serverError("model is not prepared for inference");
    }

    final double[][] importance;
    if (config.requestingImportance()) {
        importance = featureImportance(features);
    } else {
        importance = new double[0][];
    }
    final double[] leafValue = getLeaf(features);
    return buildResult(leafValue, importance, featureDecoderMap, config);
}
/**
 * Wraps a raw leaf value (and optional feature importance) in the result type matching the
 * config and this tree's target type.
 *
 * @param value             leaf value reached for the input; must be non-empty
 * @param featureImportance per-feature importance rows, indexed like {@code featureNames}
 * @param featureDecoderMap passed through to {@code decodeFeatureImportances}
 * @param config            inference configuration; a {@link NullInferenceConfig} short-circuits
 *                          to {@link RawInferenceResults}
 * @throws UnsupportedOperationException for target types other than classification/regression
 */
private InferenceResults buildResult(double[] value,
                                     double[][] featureImportance,
                                     Map featureDecoderMap,
                                     InferenceConfig config) {
    assert value != null && value.length > 0;
    // Indicates that the config is useless and the caller just wants the raw value
    if (config instanceof NullInferenceConfig) {
        return new RawInferenceResults(value, featureImportance);
    }
    // NOTE(review): the type arguments of this map (and of featureDecoderMap) are not
    // recoverable from this file alone, so both are left raw here.
    Map decodedFeatureImportance = config.requestingImportance() ?
        decodeFeatureImportances(featureDecoderMap,
            IntStream.range(0, featureImportance.length)
                .boxed()
                .collect(Collectors.toMap(i -> featureNames[i], i -> featureImportance[i]))) :
        Collections.emptyMap();
    switch (targetType) {
        case CLASSIFICATION:
            ClassificationConfig classificationConfig = (ClassificationConfig) config;
            // v1 is the winning class, v2 the list of top class entries.
            Tuple<InferenceHelpers.TopClassificationValue, List<TopClassEntry>> topClasses = InferenceHelpers.topClasses(
                classificationProbability(value),
                classificationLabels,
                null,
                classificationConfig.getNumTopClasses(),
                classificationConfig.getPredictionFieldType());
            final InferenceHelpers.TopClassificationValue classificationValue = topClasses.v1();
            return new ClassificationInferenceResults(classificationValue.getValue(),
                classificationLabel(classificationValue.getValue(), classificationLabels),
                topClasses.v2(),
                InferenceHelpers.transformFeatureImportanceClassification(decodedFeatureImportance,
                    classificationLabels,
                    classificationConfig.getPredictionFieldType()),
                config,
                classificationValue.getProbability(),
                classificationValue.getScore());
        case REGRESSION:
            // Regression uses only the first leaf value.
            return new RegressionInferenceResults(value[0],
                config,
                InferenceHelpers.transformFeatureImportanceRegression(decodedFeatureImportance));
        default:
            throw new UnsupportedOperationException("unsupported target_type [" + targetType + "] for inference on tree model");
    }
}
/**
 * Converts a leaf value into a per-class probability array.
 *
 * Multi-valued leaves are passed through soft-max. A single-valued leaf is treated as the
 * predicted class index and turned into a one-hot array of length {@code highOrderCategory + 1}.
 *
 * @param inferenceValue the leaf value; must be non-empty
 * @return one probability per class
 */
private double[] classificationProbability(double[] inferenceValue) {
    // Multi-value leaves, indicates that the leaves contain an array of values.
    // The index of which corresponds to classification values
    if (inferenceValue.length > 1) {
        return Statistics.softMax(inferenceValue);
    }
    // If we are classification, we should assume that the inference return value is whole.
    assert inferenceValue[0] == Math.rint(inferenceValue[0]);
    double maxCategory = this.highOrderCategory;
    // If we are classification, we should assume that the largest leaf value is whole.
    assert maxCategory == Math.rint(maxCategory);
    // A fresh double[] is already zero-filled, so no boxed list/stream is needed to build it.
    double[] probabilities = new double[(int) (maxCategory + 1)];
    probabilities[(int) inferenceValue[0]] = 1.0;
    return probabilities;
}
/**
 * Walks the tree from the root (index 0), following the child index each inner node's
 * compare() selects, and returns the value array of the leaf that is reached.
 * The leaf's own array is returned, not a copy.
 */
private double[] getLeaf(double[] features) {
    int index = 0;
    while (nodes[index].isLeaf() == false) {
        index = nodes[index].compare(features);
    }
    return ((LeafNode) nodes[index]).leafValue;
}
/**
 * Computes feature importance for one input by running shapRecursive from the root.
 *
 * @param fieldValues feature values, indexed by the split-feature indices stored in the nodes
 * @return a {@code [fieldValues.length][leafSize]} matrix: one row per feature, one column per
 *         leaf value component
 */
public double[][] featureImportance(double[] fieldValues) {
    // The two-dimensional allocation already creates every zero-filled row.
    double[][] featureImportance = new double[fieldValues.length][leafSize];
    // Working buffers for the path state, sized from the tree depth: (d + 1)(d + 2) / 2 entries.
    int arrSize = ((this.maxDepth + 1) * (this.maxDepth + 2)) / 2;
    ShapPath.PathElement[] elements = new ShapPath.PathElement[arrSize];
    for (int i = 0; i < arrSize; i++) {
        elements[i] = new ShapPath.PathElement();
    }
    double[] scale = new double[arrSize];
    ShapPath initialPath = new ShapPath(elements, scale);
    // Root call: fractions of 1.0 and feature index -1 (no parent split).
    shapRecursive(fieldValues, initialPath, 0, 1.0, 1.0, -1, featureImportance, 0);
    return featureImportance;
}
/**
* Note, this is a port from https://github.com/elastic/ml-cpp/blob/master/lib/maths/CTreeShapFeatureImportance.cc
*
* If improvements in performance or accuracy have been found, it is probably best that the changes are implemented on the native
* side first and then ported to the Java side.
*
* Recursively visits the subtree rooted at {@code nodeIndex} and accumulates per-feature
* contributions into {@code featureImportance} (mutated in place).
*
* @param processedFeatures feature values of the input; handed to {@code Node.compare}
* @param parentSplitPath path state of the caller; a new ShapPath is derived from it for this call
* @param nodeIndex index into {@code nodes} of the node to visit
* @param parentFractionZero value passed to {@code ShapPath.extend} as the "zero" fraction
* @param parentFractionOne value passed to {@code ShapPath.extend} as the "one" fraction
* @param parentFeatureIndex split feature of the parent node, or -1 for the root call
* @param featureImportance accumulator indexed [feature][leaf value component]
* @param nextIndex position in the path handed to/returned by the ShapPath operations
*                  (presumably the current path length - confirm against ShapPath)
*/
private void shapRecursive(double[] processedFeatures,
ShapPath parentSplitPath,
int nodeIndex,
double parentFractionZero,
double parentFractionOne,
int parentFeatureIndex,
double[][] featureImportance,
int nextIndex) {
ShapPath splitPath = new ShapPath(parentSplitPath, nextIndex);
Node currNode = nodes[nodeIndex];
nextIndex = splitPath.extend(parentFractionZero, parentFractionOne, parentFeatureIndex, nextIndex);
if (currNode.isLeaf()) {
double[] leafValue = ((LeafNode)currNode).leafValue;
// Starts at 1: the element added by the root call carries feature index -1, which
// cannot index featureImportance (presumably it is the entry at position 0 - confirm).
for (int i = 1; i < nextIndex; ++i) {
int inputColumnIndex = splitPath.featureIndex(i);
double scaled = splitPath.sumUnwoundPath(i, nextIndex) * (splitPath.fractionOnes(i) - splitPath.fractionZeros(i));
// Every component of the leaf value receives the same scaling factor.
for (int j = 0; j < leafValue.length; j++) {
featureImportance[inputColumnIndex][j] += scaled * leafValue[j];
}
}
} else {
InnerNode innerNode = (InnerNode)currNode;
// "hot" is the child this input actually takes; "cold" is the other child.
int hotIndex = currNode.compare(processedFeatures);
int coldIndex = hotIndex == innerNode.leftChild ? innerNode.rightChild : innerNode.leftChild;
double incomingFractionZero = 1.0;
double incomingFractionOne = 1.0;
int splitFeature = innerNode.splitFeature;
int pathIndex = splitPath.findFeatureIndex(splitFeature, nextIndex);
// If this feature is already on the path, reuse its fractions and unwind that entry
// before descending.
if (pathIndex > -1) {
incomingFractionZero = splitPath.fractionZeros(pathIndex);
incomingFractionOne = splitPath.fractionOnes(pathIndex);
nextIndex = splitPath.unwind(pathIndex, nextIndex);
}
// Share of this node's samples that went to each child.
double hotFractionZero = nodes[hotIndex].getNumberSamples() / (double)currNode.getNumberSamples();
double coldFractionZero = nodes[coldIndex].getNumberSamples() / (double)currNode.getNumberSamples();
shapRecursive(processedFeatures, splitPath,
hotIndex, incomingFractionZero * hotFractionZero,
incomingFractionOne, splitFeature, featureImportance, nextIndex);
// The branch not taken always gets a "one" fraction of 0.0.
shapRecursive(processedFeatures, splitPath,
coldIndex, incomingFractionZero * coldFractionZero,
0.0, splitFeature, featureImportance, nextIndex);
}
}
/**
 * Always returns {@code true}: this model implements featureImportance(double[]).
 */
@Override
public boolean supportsFeatureImportance() {
return true;
}
/**
 * Returns the fixed name of this model type, {@code "tree"}.
 */
@Override
public String getName() {
return "tree";
}
/**
 * Prepares the tree for inference, optionally remapping every inner node's split feature
 * to the index given for its feature name.
 *
 * The call is a no-op once the model is prepared. A {@code null} or empty mapping only marks
 * the model as prepared. Otherwise all new indices are resolved first and applied only if
 * every lookup succeeded, so a failed call leaves the tree untouched and unprepared. The
 * {@code volatile} prepared flag is written last, after all node updates.
 *
 * After a successful remap the local feature names and classification labels are dropped.
 *
 * @param newFeatureIndexMapping feature name to new feature index; may be {@code null} or empty
 * @throws IllegalArgumentException if a split feature's name has no entry in the mapping
 */
@Override
public void rewriteFeatureIndices(Map<String, Integer> newFeatureIndexMapping) {
    LOGGER.debug(() -> new ParameterizedMessage("rewriting features {}", newFeatureIndexMapping));
    if (preparedForInference) {
        return;
    }
    if (newFeatureIndexMapping == null || newFeatureIndexMapping.isEmpty()) {
        preparedForInference = true;
        return;
    }
    // Phase 1: resolve every new index without touching the nodes.
    int[] resolved = new int[nodes.length];
    for (int i = 0; i < nodes.length; i++) {
        if (nodes[i].isLeaf()) {
            continue;
        }
        InnerNode treeNode = (InnerNode) nodes[i];
        Integer newSplitFeatureIndex = newFeatureIndexMapping.get(featureNames[treeNode.splitFeature]);
        if (newSplitFeatureIndex == null) {
            throw new IllegalArgumentException("[tree] failed to optimize for inference");
        }
        resolved[i] = newSplitFeatureIndex;
    }
    // Phase 2: all lookups succeeded, apply them.
    for (int i = 0; i < nodes.length; i++) {
        if (nodes[i].isLeaf() == false) {
            ((InnerNode) nodes[i]).splitFeature = resolved[i];
        }
    }
    this.featureNames = new String[0];
    // Since we are not top level, we no longer need local classification labels
    this.classificationLabels = null;
    // Publish last so the flag is only visible once the rewrite above is complete.
    preparedForInference = true;
}
/**
 * Estimated heap footprint: the shallow instance size plus the classification labels,
 * the feature name array and the node array.
 */
@Override
public long ramBytesUsed() {
    return SHALLOW_SIZE
        + sizeOfCollection(classificationLabels)
        + sizeOf(featureNames)
        + sizeOf(nodes);
}
/**
 * Determines the value used to size the one-hot class array.
 *
 * Returns NaN for non-classification trees. For classification: if a multi-valued leaf is
 * encountered its value count is returned immediately; otherwise the largest single leaf
 * value seen (never less than 0.0) is returned.
 */
private double maxLeafValue() {
    if (targetType != TargetType.CLASSIFICATION) {
        return Double.NaN;
    }
    double largest = 0.0;
    for (Node candidate : this.nodes) {
        if ((candidate instanceof LeafNode) == false) {
            continue;
        }
        final double[] values = ((LeafNode) candidate).leafValue;
        if (values.length > 1) {
            return values.length;
        }
        largest = Math.max(values[0], largest);
    }
    return largest;
}
/**
 * Returns the node array backing this tree. The internal array is returned directly (no copy).
 */
public Node[] getNodes() {
return nodes;
}
/**
 * Debug representation listing every field of the model.
 */
@Override
public String toString() {
    final StringBuilder text = new StringBuilder("TreeInferenceModel{");
    text.append("nodes=").append(Arrays.toString(nodes));
    text.append(", featureNames=").append(Arrays.toString(featureNames));
    text.append(", targetType=").append(targetType);
    text.append(", classificationLabels=").append(classificationLabels);
    text.append(", highOrderCategory=").append(highOrderCategory);
    text.append(", maxDepth=").append(maxDepth);
    text.append(", leafSize=").append(leafSize);
    text.append(", preparedForInference=").append(preparedForInference);
    text.append('}');
    return text.toString();
}
/**
 * Computes the height of the subtree rooted at {@code nodeIndex}: 0 for a leaf, otherwise
 * one more than the taller child subtree.
 *
 * @param nodes     the flattened tree
 * @param nodeIndex index of the subtree root
 * @param depth     number of edges already followed from the tree root; used to detect cycles
 * @return the subtree height
 * @throws IllegalArgumentException if child links form a cycle
 */
private static int getDepth(Node[] nodes, int nodeIndex, int depth) {
    // A path without repeated nodes can follow at most nodes.length - 1 edges, so reaching
    // this depth means a node was revisited. Fail cleanly instead of recursing until the
    // stack overflows.
    if (depth >= nodes.length) {
        throw new IllegalArgumentException("[tree_structure] must not contain cycles");
    }
    Node node = nodes[nodeIndex];
    if (node instanceof LeafNode) {
        return 0;
    }
    InnerNode innerNode = (InnerNode) node;
    int depthLeft = getDepth(nodes, innerNode.leftChild, depth + 1);
    int depthRight = getDepth(nodes, innerNode.rightChild, depth + 1);
    return Math.max(depthLeft, depthRight) + 1;
}
/**
 * Mutable builder for a single tree node, populated either by {@link #PARSER} or through the
 * fluent setters, and turned into an immutable-shaped {@link Node} by {@link #build()}.
 */
static class NodeBuilder {

    // Typed so that PARSER.apply yields a NodeBuilder; no parser context is used.
    private static final ObjectParser<NodeBuilder, Void> PARSER = new ObjectParser<>(
        "tree_inference_model_node",
        true,
        NodeBuilder::new);

    static {
        PARSER.declareDouble(NodeBuilder::setThreshold, THRESHOLD);
        PARSER.declareField(NodeBuilder::setOperator,
            p -> Operator.fromString(p.text()),
            DECISION_TYPE,
            ObjectParser.ValueType.STRING);
        PARSER.declareInt(NodeBuilder::setLeftChild, LEFT_CHILD);
        PARSER.declareInt(NodeBuilder::setRightChild, RIGHT_CHILD);
        PARSER.declareBoolean(NodeBuilder::setDefaultLeft, DEFAULT_LEFT);
        PARSER.declareInt(NodeBuilder::setSplitFeature, SPLIT_FEATURE);
        PARSER.declareDoubleArray(NodeBuilder::setLeafValue, LEAF_VALUE);
        PARSER.declareLong(NodeBuilder::setNumberSamples, NUMBER_SAMPLES);
    }

    // Defaults used for any field the definition omits.
    private Operator operator = Operator.LTE;
    private double threshold = Double.NaN;
    private int splitFeature = -1;
    private boolean defaultLeft = false;
    // A negative left child marks the node as a leaf (see build()).
    private int leftChild = -1;
    private int rightChild = -1;
    private long numberSamples;
    private double[] leafValue = new double[0];

    public NodeBuilder setOperator(Operator operator) {
        this.operator = operator;
        return this;
    }

    public NodeBuilder setThreshold(double threshold) {
        this.threshold = threshold;
        return this;
    }

    public NodeBuilder setSplitFeature(int splitFeature) {
        this.splitFeature = splitFeature;
        return this;
    }

    public NodeBuilder setDefaultLeft(boolean defaultLeft) {
        this.defaultLeft = defaultLeft;
        return this;
    }

    public NodeBuilder setLeftChild(int leftChild) {
        this.leftChild = leftChild;
        return this;
    }

    public NodeBuilder setRightChild(int rightChild) {
        this.rightChild = rightChild;
        return this;
    }

    public NodeBuilder setNumberSamples(long numberSamples) {
        this.numberSamples = numberSamples;
        return this;
    }

    // Parser-facing overload: unboxes the parsed list into a primitive array.
    private NodeBuilder setLeafValue(List<Double> leafValue) {
        return setLeafValue(leafValue.stream().mapToDouble(Double::doubleValue).toArray());
    }

    public NodeBuilder setLeafValue(double[] leafValue) {
        this.leafValue = leafValue;
        return this;
    }

    /**
     * Builds the node: a {@link LeafNode} when no left child was set (index below zero),
     * otherwise an {@link InnerNode} carrying the split definition.
     */
    Node build() {
        if (this.leftChild < 0) {
            return new LeafNode(leafValue, numberSamples);
        }
        return new InnerNode(operator,
            threshold,
            splitFeature,
            defaultLeft,
            leftChild,
            rightChild,
            numberSamples);
    }
}
/**
 * Base type of the nodes stored in the tree's node array.
 */
public abstract static class Node implements Accountable {
// Default implementation for nodes that cannot be compared against; InnerNode overrides it
// to return the index of the child to visit next.
int compare(double[] features) {
throw new IllegalArgumentException("cannot call compare against a leaf node.");
}
// Sample count recorded for this node; read by shapRecursive to weight child branches.
abstract long getNumberSamples();
// True exactly when this node is a LeafNode.
public boolean isLeaf() {
return this instanceof LeafNode;
}
}
/**
 * A decision node: tests one feature against a threshold and points at two child indices.
 */
public static class InnerNode extends Node {

    public static final long SHALLOW_SIZE = shallowSizeOfInstance(InnerNode.class);

    private final Operator operator;
    private final double threshold;
    // Allowed to be adjusted for inference optimization
    private int splitFeature;
    private final boolean defaultLeft;
    private final int leftChild;
    private final int rightChild;
    private final long numberSamples;

    InnerNode(Operator operator,
              double threshold,
              int splitFeature,
              boolean defaultLeft,
              int leftChild,
              int rightChild,
              long numberSamples) {
        this.numberSamples = numberSamples;
        this.rightChild = rightChild;
        this.leftChild = leftChild;
        this.defaultLeft = defaultLeft;
        this.splitFeature = splitFeature;
        this.threshold = threshold;
        this.operator = operator;
    }

    /**
     * Picks the child to visit for the given feature vector. A missing feature value follows
     * the default direction; otherwise the operator test against the threshold decides.
     *
     * @return the index of the left or right child
     */
    @Override
    public int compare(double[] features) {
        final double value = features[splitFeature];
        final boolean goLeft = isMissing(value) ? defaultLeft : operator.test(value, threshold);
        return goLeft ? leftChild : rightChild;
    }

    @Override
    long getNumberSamples() {
        return numberSamples;
    }

    // A value counts as missing when Numbers.isValidDouble rejects it.
    private static boolean isMissing(double feature) {
        final boolean valid = Numbers.isValidDouble(feature);
        return valid == false;
    }

    @Override
    public long ramBytesUsed() {
        return SHALLOW_SIZE;
    }

    @Override
    public String toString() {
        final StringBuilder text = new StringBuilder("InnerNode{");
        text.append("operator=").append(operator);
        text.append(", threshold=").append(threshold);
        text.append(", splitFeature=").append(splitFeature);
        text.append(", defaultLeft=").append(defaultLeft);
        text.append(", leftChild=").append(leftChild);
        text.append(", rightChild=").append(rightChild);
        text.append(", numberSamples=").append(numberSamples);
        text.append('}');
        return text.toString();
    }
}
/**
 * A terminal node holding the value array returned by inference and its sample count.
 */
public static class LeafNode extends Node {

    public static final long SHALLOW_SIZE = shallowSizeOfInstance(LeafNode.class);

    private final double[] leafValue;
    private final long numberSamples;

    LeafNode(double[] leafValue, long numberSamples) {
        this.numberSamples = numberSamples;
        this.leafValue = leafValue;
    }

    /** Shallow size of the node plus the size of its value array. */
    @Override
    public long ramBytesUsed() {
        final long valueBytes = sizeOf(leafValue);
        return SHALLOW_SIZE + valueBytes;
    }

    @Override
    long getNumberSamples() {
        return numberSamples;
    }

    /** Returns the stored value array itself, not a copy. */
    public double[] getLeafValue() {
        return leafValue;
    }

    @Override
    public String toString() {
        final StringBuilder text = new StringBuilder("LeafNode{");
        text.append("leafValue=").append(Arrays.toString(leafValue));
        text.append(", numberSamples=").append(numberSamples);
        text.append('}');
        return text.toString();
    }
}
}