// hivemall.smile.tools.DecisionPathUDF Maven / Gradle / Ivy
// The newest version!
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package hivemall.smile.tools;
import hivemall.UDFWithOptions;
import matrix4j.vector.DenseVector;
import matrix4j.vector.SparseVector;
import matrix4j.vector.Vector;
import hivemall.smile.classification.DecisionTree;
import hivemall.smile.classification.PredictionHandler;
import hivemall.smile.regression.RegressionTree;
import hivemall.utils.codec.Base91;
import hivemall.utils.hadoop.HiveUtils;
import hivemall.utils.lang.Preconditions;
import hivemall.utils.lang.StringUtils;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import org.apache.commons.cli.CommandLine;
import org.apache.commons.cli.Options;
import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.udf.UDFType;
import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.StringObjectInspector;
import org.apache.hadoop.io.Text;
// @formatter:off
@Description(name = "decision_path",
value = "_FUNC_(string modelId, string model, array features [, const string options] [, optional array featureNames=null, optional array classNames=null])"
+ " - Returns a decision path for each prediction in array",
extended = "SELECT\n" +
" t.passengerid,\n" +
" decision_path(m.model_id, m.model, t.features, '-classification')\n" +
"FROM\n" +
" model_rf m\n" +
" LEFT OUTER JOIN\n" +
" test_rf t;\n" +
" | 892 | [\"2 [0.0] = 0.0\",\"0 [3.0] = 3.0\",\"1 [696.0] != 107.0\",\"7 [7.8292] <= 7.9104\",\"1 [696.0] != 828.0\",\"1 [696.0] != 391.0\",\"0 [0.961038961038961, 0.03896103896103896]\"] |\n\n" +
"-- Show 100 frequent branches\n" +
"WITH tmp as (\n" +
" SELECT\n" +
" decision_path(m.model_id, m.model, t.features, '-classification -no_verbose -no_leaf', array('pclass','name','sex','age','sibsp','parch','ticket','fare','cabin','embarked'), array('no','yes')) as path\n" +
" FROM\n" +
" model_rf m\n" +
" LEFT OUTER JOIN -- CROSS JOIN\n" +
" test_rf t\n" +
")\n" +
"select\n" +
" r.branch,\n" +
" count(1) as cnt\n" +
"from\n" +
" tmp l\n" +
" LATERAL VIEW explode(l.path) r as branch\n" +
"group by\n" +
" r.branch\n" +
"order by\n" +
" cnt desc\n" +
"limit 100;")
// @formatter:on
@UDFType(deterministic = true, stateful = false)
public final class DecisionPathUDF extends UDFWithOptions {
// Inspector for the serialized model (2nd argument); resolved in initialize().
private StringObjectInspector modelOI;
// Inspector for the features array (3rd argument); resolved in initialize().
private ListObjectInspector featureListOI;
// Inspector for a single element of the features array (numeric or string).
private PrimitiveObjectInspector featureElemOI;
// true when the feature elements are numeric (parsed into a DenseVector),
// false when they are strings (parsed into a SparseVector).
private boolean denseInput;
// options
// Set by the "classification" option; selects ClassificationEvaluator over RegressionEvaluator.
private boolean classification = false;
// Cleared by the "no_sumarize" option; when true, branches sharing the same key are merged.
private boolean summarize = true;
// Cleared by the "disable_verbose_output" option; when true, the observed feature value
// (and, for classification leaves, the posteriori array) is included in the output.
private boolean verbose = true;
// Set by the "disable_leaf_output" option; when true, no leaf entry is appended to the path.
private boolean noLeaf = false;
// Optional constant array used to print feature names instead of feature indexes.
@Nullable
private String[] featureNames;
// Optional constant array used to print class names instead of class labels.
@Nullable
private String[] classNames;
// Feature vector handed back into parseFeatures() on every evaluate() call so it can be reused.
@Nullable
private transient Vector featuresProbe;
// Created on the first evaluate() call and kept until close().
@Nullable
private transient Evaluator evaluator;
/**
 * Declares the options accepted in the optional constant {@code options} argument.
 *
 * <p>The short name {@code no_sumarize} is misspelled, but it is the name callers already pass
 * and the name {@code processOptions} looks up, so it is kept as-is for backward
 * compatibility.
 *
 * @return the option definitions
 */
@Override
protected Options getOptions() {
    final Options opts = new Options();
    opts.addOption("c", "classification", false,
        "Predict as classification [default: not enabled]");
    opts.addOption("no_sumarize", "disable_summarization", false,
        "Do not summarize decision paths");
    opts.addOption("no_verbose", "disable_verbose_output", false,
        "Disable verbose output [default: verbose]");
    // This flag suppresses the leaf entry (see the noLeaf checks in the evaluators),
    // so the help text must say so.
    opts.addOption("no_leaf", "disable_leaf_output", false,
        "Do not show leaf value [default: leaf value is shown]");
    return opts;
}
/**
 * Parses the option string and copies the recognized flags into this UDF's option fields.
 *
 * @param optionValue the raw option string, e.g. {@code "-classification -no_verbose"}
 * @return the parsed command line
 * @throws UDFArgumentException declared for option-parsing failures
 */
@Override
protected CommandLine processOptions(@Nonnull String optionValue) throws UDFArgumentException {
    final CommandLine parsed = parseOptions(optionValue);

    // The two "disable" style flags switch off features that are on by default.
    final boolean summarizationDisabled = parsed.hasOption("no_sumarize");
    final boolean verboseDisabled = parsed.hasOption("disable_verbose_output");

    this.classification = parsed.hasOption("classification");
    this.noLeaf = parsed.hasOption("disable_leaf_output");
    this.summarize = !summarizationDisabled;
    this.verbose = !verboseDisabled;

    return parsed;
}
/**
 * Validates the argument types and captures the constant arguments.
 *
 * <p>Accepted shapes after the three mandatory arguments (modelId, model, features):
 * <ul>
 * <li>{@code options [, featureNames [, classNames]]} when the 4th argument is a constant
 * string, or</li>
 * <li>{@code featureNames [, classNames]} when the 4th argument is a constant string array.</li>
 * </ul>
 * {@code classNames} is rejected unless the classification option is enabled.
 *
 * @param argOIs inspectors of the call arguments
 * @return a list-of-string inspector describing the value returned by {@code evaluate}
 * @throws UDFArgumentException if an argument has an unexpected type
 */
@Override
public ObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException {
    if (argOIs.length < 3 || argOIs.length > 6) {
        // NOTE(review): showHelp is expected to abort by throwing -- confirm in UDFWithOptions.
        showHelp("decision_path takes 3 ~ 6 arguments");
    }

    this.modelOI = HiveUtils.asStringOI(argOIs[1]);

    final ListObjectInspector listOI = HiveUtils.asListOI(argOIs[2]);
    this.featureListOI = listOI;
    final ObjectInspector elemOI = listOI.getListElementObjectInspector();
    if (HiveUtils.isNumberOI(elemOI)) {
        this.featureElemOI = HiveUtils.asDoubleCompatibleOI(elemOI);
        this.denseInput = true;
    } else if (HiveUtils.isStringOI(elemOI)) {
        this.featureElemOI = HiveUtils.asStringOI(elemOI);
        this.denseInput = false;
    } else {
        throw new UDFArgumentException(
            "decision_path takes array<double> or array<string> for the 3rd argument: "
                    + listOI.getTypeName());
    }

    if (argOIs.length >= 4) {
        final ObjectInspector argOI3 = argOIs[3];

        // Index of the featureNames argument: it follows the options string when one is given,
        // otherwise it takes the options' place.
        final int featureNamesPos;
        if (HiveUtils.isConstString(argOI3)) {
            processOptions(HiveUtils.getConstString(argOI3));
            featureNamesPos = 4;
        } else if (HiveUtils.isConstStringListOI(argOI3)) {
            featureNamesPos = 3;
        } else {
            throw new UDFArgumentException(
                "decision_path expects 'const string options' or 'const array<string> featureNames' for the 4th argument: "
                        + argOI3.getTypeName());
        }

        if (argOIs.length > featureNamesPos) {
            final ObjectInspector featureNamesOI = argOIs[featureNamesPos];
            requireConstStringList(featureNamesOI, "featureNames", featureNamesPos);
            this.featureNames = HiveUtils.getConstStringArray(featureNamesOI);

            final int classNamesPos = featureNamesPos + 1;
            if (argOIs.length > classNamesPos) {
                final ObjectInspector classNamesOI = argOIs[classNamesPos];
                requireConstStringList(classNamesOI, "classNames", classNamesPos);
                if (!classification) {
                    throw new UDFArgumentException(
                        "classNames should not be provided for regression");
                }
                this.classNames = HiveUtils.getConstStringArray(classNamesOI);
            }
        }
    }

    return ObjectInspectorFactory.getStandardListObjectInspector(
        PrimitiveObjectInspectorFactory.javaStringObjectInspector);
}

/**
 * Fails with a descriptive message unless the inspector describes a constant string array.
 *
 * @param oi the inspector to check
 * @param argName the argument name used in the error message
 * @param argIndex zero-based argument index (3..5), reported as "4th".."6th"
 */
private static void requireConstStringList(@Nonnull final ObjectInspector oi,
        @Nonnull final String argName, final int argIndex) throws UDFArgumentException {
    if (!HiveUtils.isConstStringListOI(oi)) {
        throw new UDFArgumentException("decision_path expects 'const array<string> " + argName
                + "' for the " + (argIndex + 1) + "th argument: " + oi.getTypeName());
    }
}
/**
 * Computes the decision path taken by one model for one feature vector.
 *
 * @param arguments the call arguments; only the first three (modelId, model, features) are
 *        read here, the remaining ones are constants consumed in {@code initialize}
 * @return the decision path as a list of strings, or {@code null} when the model is null
 * @throws HiveException if modelId or features is null, or if evaluation fails
 */
@Override
public List<String> evaluate(@Nonnull DeferredObject[] arguments) throws HiveException {
    final Object arg0 = arguments[0].get();
    if (arg0 == null) {
        throw new HiveException("modelId should not be null");
    }
    // Not using string OI for backward compatibilities
    final String modelId = arg0.toString();

    final Object arg1 = arguments[1].get();
    if (arg1 == null) {
        return null;
    }
    final Text model = modelOI.getPrimitiveWritableObject(arg1);

    final Object arg2 = arguments[2].get();
    if (arg2 == null) {
        throw new HiveException("features was null");
    }
    // The previous vector is passed back in so parseFeatures can reuse it.
    this.featuresProbe = parseFeatures(arg2, featuresProbe);

    // The evaluator is created lazily because the options are only known after initialize().
    if (evaluator == null) {
        this.evaluator = classification ? new ClassificationEvaluator(this)
                : new RegressionEvaluator(this);
    }
    return evaluator.evaluate(modelId, model, featuresProbe);
}
/**
 * Converts the features argument into a vector, reusing {@code probe} when possible.
 *
 * <p>Dense input (numeric elements): element {@code i} becomes component {@code i}; null
 * elements are stored as {@code 0.0}. A new {@link DenseVector} is allocated only when there
 * is no probe yet or its size differs from the input length.
 *
 * <p>Sparse input (string elements): each non-null element is either {@code "index:value"} or
 * a bare {@code "index"} (value {@code 1.0}); null elements are skipped. The probe is cleared
 * before it is refilled.
 *
 * @param argObj the features array object
 * @param probe the vector from the previous call, or {@code null}
 * @return the populated vector (either {@code probe} or a new instance)
 * @throws UDFArgumentException if a sparse element is malformed or has a negative index
 */
@Nonnull
private Vector parseFeatures(@Nonnull final Object argObj, @Nullable Vector probe)
        throws UDFArgumentException {
    final int length = featureListOI.getListLength(argObj);

    if (denseInput) {
        if (probe == null || length != probe.size()) {
            probe = new DenseVector(length);
        }
        for (int i = 0; i < length; i++) {
            final Object o = featureListOI.getListElement(argObj, i);
            if (o == null) {
                probe.set(i, 0.d);
            } else {
                probe.set(i, PrimitiveObjectInspectorUtils.getDouble(o, featureElemOI));
            }
        }
        return probe;
    }

    if (probe == null) {
        probe = new SparseVector();
    } else {
        probe.clear();
    }
    for (int i = 0; i < length; i++) {
        final Object o = featureListOI.getListElement(argObj, i);
        if (o == null) {
            continue;
        }
        final String col = o.toString();
        final int pos = col.indexOf(':');
        if (pos == 0) {
            throw new UDFArgumentException("Invalid feature value representation: " + col);
        }

        // The index part ends at the FIRST ':' (or is the whole string), so it can never
        // contain a ':' itself; no further separator check is needed.
        int colIndex;
        double value;
        try {
            if (pos > 0) {
                value = Double.parseDouble(col.substring(pos + 1));
                colIndex = Integer.parseInt(col.substring(0, pos));
            } else {
                value = 1.d;
                colIndex = Integer.parseInt(col);
            }
        } catch (NumberFormatException e) {
            // Report malformed numbers the same way as the other format errors instead of
            // leaking an unchecked exception.
            final UDFArgumentException ex =
                    new UDFArgumentException("Invalid feature value representation: " + col);
            ex.initCause(e);
            throw ex;
        }
        if (colIndex < 0) {
            throw new UDFArgumentException(
                "Col index MUST be greater than or equals to 0: " + colIndex);
        }
        probe.set(colIndex, value);
    }
    return probe;
}
/**
 * Drops every reference held by this UDF: the lazily created runtime state, the constant
 * name arrays, and the object inspectors.
 */
@Override
public void close() throws IOException {
    // runtime state created during evaluate()
    this.evaluator = null;
    this.featuresProbe = null;

    // constant arguments captured in initialize()
    this.classNames = null;
    this.featureNames = null;

    // object inspectors captured in initialize()
    this.featureListOI = null;
    this.featureElemOI = null;
    this.modelOI = null;
}
/**
 * Renders this call for display as {@code decision_path(child1,child2,...)}.
 *
 * @param children display strings of the arguments
 * @return the display string
 */
@Override
public String getDisplayString(String[] children) {
    final StringBuilder display = new StringBuilder("decision_path(");
    display.append(StringUtils.join(children, ','));
    display.append(")");
    return display.toString();
}
/**
 * Strategy that turns a serialized tree model and a feature vector into a decision path.
 */
interface Evaluator {

    /**
     * Evaluates one feature vector against one model.
     *
     * @param modelId identifier of the model
     * @param model the serialized model
     * @param features the feature vector to route through the tree
     * @return the decision path, one string per entry
     * @throws HiveException if the model cannot be evaluated
     */
    @Nonnull
    List<String> evaluate(@Nonnull String modelId, @Nonnull Text model,
            @Nonnull Vector features) throws HiveException;

}
/**
 * {@link Evaluator} for classification trees.
 *
 * <p>The deserialized tree is cached and only rebuilt when the model id changes between
 * calls. The list returned by {@link #evaluate} is a single instance that is cleared and
 * refilled on every call.
 */
static final class ClassificationEvaluator implements Evaluator {

    @Nullable
    private final String[] featureNames;
    @Nullable
    private final String[] classNames;

    @Nonnull
    private final List<String> result;
    @Nonnull
    private final PredictionHandler handler;

    @Nullable
    private String prevModelId = null;
    private DecisionTree.Node cNode = null;

    ClassificationEvaluator(@Nonnull final DecisionPathUDF udf) {
        this.featureNames = udf.featureNames;
        this.classNames = udf.classNames;

        // Scratch buffer shared by the handler callbacks; always left empty after use.
        final StringBuilder buf = new StringBuilder();
        final ArrayList<String> result = new ArrayList<>();
        this.result = result;

        if (udf.summarize) {
            // key -> fully rendered branch, in first-seen order.
            final LinkedHashMap<String, String> branches = new LinkedHashMap<>();

            this.handler = new PredictionHandler() {

                @Override
                public void init() {
                    branches.clear();
                    result.clear();
                }

                @Override
                public void visitBranch(Operator op, int splitFeatureIndex, double splitFeature,
                        double splitValue) {
                    appendFeature(buf, splitFeatureIndex, splitFeature, udf.verbose);
                    buf.append(op);

                    final String key;
                    final String rendered;
                    if (op == Operator.EQ || op == Operator.NE) {
                        // The split value is part of the key, so comparisons against
                        // different values are kept as separate entries.
                        buf.append(' ').append(splitValue);
                        key = buf.toString();
                        rendered = key;
                    } else {
                        // The split value is NOT part of the key, so a later comparison on
                        // the same feature/operator replaces the earlier one.
                        key = buf.toString();
                        rendered = key + ' ' + splitValue;
                    }
                    branches.put(key, rendered);
                    StringUtils.clear(buf);
                }

                @Override
                public void visitLeaf(int output, double[] posteriori) {
                    result.addAll(branches.values());
                    if (udf.noLeaf) {
                        return;
                    }
                    addLeaf(result, buf, udf.verbose, output, posteriori);
                }

                @SuppressWarnings("unchecked")
                @Override
                public ArrayList<String> getResult() {
                    return result;
                }

            };
        } else {
            this.handler = new PredictionHandler() {

                @Override
                public void init() {
                    result.clear();
                }

                @Override
                public void visitBranch(Operator op, int splitFeatureIndex, double splitFeature,
                        double splitValue) {
                    appendFeature(buf, splitFeatureIndex, splitFeature, udf.verbose);
                    buf.append(op).append(' ').append(splitValue);
                    result.add(buf.toString());
                    StringUtils.clear(buf);
                }

                @Override
                public void visitLeaf(int output, double[] posteriori) {
                    if (udf.noLeaf) {
                        return;
                    }
                    addLeaf(result, buf, udf.verbose, output, posteriori);
                }

                @SuppressWarnings("unchecked")
                @Override
                public ArrayList<String> getResult() {
                    return result;
                }

            };
        }
    }

    /** Appends {@code "<feature> [<observed value>] "} (verbose) or {@code "<feature> "}. */
    private void appendFeature(@Nonnull final StringBuilder buf, final int splitFeatureIndex,
            final double splitFeature, final boolean verbose) {
        buf.append(resolveFeatureName(splitFeatureIndex));
        if (verbose) {
            buf.append(" [").append(splitFeature).append("] ");
        } else {
            buf.append(' ');
        }
    }

    /** Adds the leaf entry: the class, followed by the posteriori array when verbose. */
    private void addLeaf(@Nonnull final List<String> out, @Nonnull final StringBuilder buf,
            final boolean verbose, final int output, final double[] posteriori) {
        if (verbose) {
            buf.append(resolveClassName(output));
            buf.append(' ');
            buf.append(Arrays.toString(posteriori));
            out.add(buf.toString());
            StringUtils.clear(buf);
        } else {
            out.add(resolveClassName(output));
        }
    }

    /** Returns the configured feature name, or the index itself when no names were given. */
    @Nonnull
    private String resolveFeatureName(final int splitFeatureIndex) {
        if (featureNames == null) {
            return Integer.toString(splitFeatureIndex);
        } else {
            return featureNames[splitFeatureIndex];
        }
    }

    /** Returns the configured class name, or the label itself when no names were given. */
    @Nonnull
    private String resolveClassName(final int classLabel) {
        if (classNames == null) {
            return Integer.toString(classLabel);
        } else {
            return classNames[classLabel];
        }
    }

    /**
     * Routes {@code features} through the tree identified by {@code modelId}.
     *
     * @param modelId identifier used to decide whether the cached tree can be reused
     * @param script the Base91-encoded serialized tree
     * @param features the feature vector
     * @return the decision path (the same list instance on every call)
     * @throws HiveException if the tree cannot be deserialized or evaluated
     */
    @Nonnull
    @Override
    public List<String> evaluate(@Nonnull final String modelId, @Nonnull final Text script,
            @Nonnull final Vector features) throws HiveException {
        if (!modelId.equals(prevModelId)) {
            final int length = script.getLength();
            final byte[] decoded = Base91.decode(script.getBytes(), 0, length);
            this.cNode = DecisionTree.deserialize(decoded, decoded.length, true);
            // Record the id only after the tree was built; otherwise a failed
            // deserialization would leave the previous model's tree cached under this id.
            this.prevModelId = modelId;
        }
        Preconditions.checkNotNull(cNode);

        handler.init();
        cNode.predict(features, handler);
        return handler.getResult();
    }

}
/**
 * {@link Evaluator} for regression trees.
 *
 * <p>The deserialized tree is cached and only rebuilt when the model id changes between
 * calls. The list returned by {@link #evaluate} is a single instance that is cleared and
 * refilled on every call.
 */
static final class RegressionEvaluator implements Evaluator {

    @Nullable
    private final String[] featureNames;

    @Nonnull
    private final List<String> result;
    @Nonnull
    private final PredictionHandler handler;

    @Nullable
    private String prevModelId = null;
    private RegressionTree.Node rNode = null;

    RegressionEvaluator(@Nonnull final DecisionPathUDF udf) {
        this.featureNames = udf.featureNames;

        // Scratch buffer shared by the handler callbacks; always left empty after use.
        final StringBuilder buf = new StringBuilder();
        final ArrayList<String> result = new ArrayList<>();
        this.result = result;

        if (udf.summarize) {
            // key -> fully rendered branch, in first-seen order.
            final LinkedHashMap<String, String> branches = new LinkedHashMap<>();

            this.handler = new PredictionHandler() {

                @Override
                public void init() {
                    branches.clear();
                    result.clear();
                }

                @Override
                public void visitBranch(Operator op, int splitFeatureIndex, double splitFeature,
                        double splitValue) {
                    appendFeature(buf, splitFeatureIndex, splitFeature, udf.verbose);
                    buf.append(op);

                    final String key;
                    final String rendered;
                    if (op == Operator.EQ || op == Operator.NE) {
                        // The split value is part of the key, so comparisons against
                        // different values are kept as separate entries.
                        buf.append(' ').append(splitValue);
                        key = buf.toString();
                        rendered = key;
                    } else {
                        // The split value is NOT part of the key, so a later comparison on
                        // the same feature/operator replaces the earlier one.
                        key = buf.toString();
                        rendered = key + ' ' + splitValue;
                    }
                    branches.put(key, rendered);
                    StringUtils.clear(buf);
                }

                @Override
                public void visitLeaf(double output) {
                    result.addAll(branches.values());
                    if (udf.noLeaf) {
                        return;
                    }
                    result.add(Double.toString(output));
                }

                @SuppressWarnings("unchecked")
                @Override
                public ArrayList<String> getResult() {
                    return result;
                }

            };
        } else {
            this.handler = new PredictionHandler() {

                @Override
                public void init() {
                    result.clear();
                }

                @Override
                public void visitBranch(Operator op, int splitFeatureIndex, double splitFeature,
                        double splitValue) {
                    // appendFeature always separates the feature name from the operator,
                    // also when verbose output is disabled.
                    appendFeature(buf, splitFeatureIndex, splitFeature, udf.verbose);
                    buf.append(op).append(' ').append(splitValue);
                    result.add(buf.toString());
                    StringUtils.clear(buf);
                }

                @Override
                public void visitLeaf(double output) {
                    if (udf.noLeaf) {
                        return;
                    }
                    result.add(Double.toString(output));
                }

                @SuppressWarnings("unchecked")
                @Override
                public ArrayList<String> getResult() {
                    return result;
                }

            };
        }
    }

    /** Appends {@code "<feature> [<observed value>] "} (verbose) or {@code "<feature> "}. */
    private void appendFeature(@Nonnull final StringBuilder buf, final int splitFeatureIndex,
            final double splitFeature, final boolean verbose) {
        buf.append(resolveFeatureName(splitFeatureIndex));
        if (verbose) {
            buf.append(" [").append(splitFeature).append("] ");
        } else {
            buf.append(' ');
        }
    }

    /** Returns the configured feature name, or the index itself when no names were given. */
    @Nonnull
    private String resolveFeatureName(final int splitFeatureIndex) {
        if (featureNames == null) {
            return Integer.toString(splitFeatureIndex);
        } else {
            return featureNames[splitFeatureIndex];
        }
    }

    /**
     * Routes {@code features} through the tree identified by {@code modelId}.
     *
     * @param modelId identifier used to decide whether the cached tree can be reused
     * @param script the Base91-encoded serialized tree
     * @param features the feature vector
     * @return the decision path (the same list instance on every call)
     * @throws HiveException if the tree cannot be deserialized or evaluated
     */
    @Nonnull
    @Override
    public List<String> evaluate(@Nonnull final String modelId, @Nonnull final Text script,
            @Nonnull final Vector features) throws HiveException {
        if (!modelId.equals(prevModelId)) {
            final int length = script.getLength();
            final byte[] decoded = Base91.decode(script.getBytes(), 0, length);
            this.rNode = RegressionTree.deserialize(decoded, decoded.length, true);
            // Record the id only after the tree was built; otherwise a failed
            // deserialization would leave the previous model's tree cached under this id.
            this.prevModelId = modelId;
        }
        Preconditions.checkNotNull(rNode);

        handler.init();
        rNode.predict(features, handler);
        return handler.getResult();
    }

}
}
// © 2015 - 2025 Weber Informatics LLC | Privacy Policy