All downloads are free. Search and download functionality uses the official Maven repository.

hivemall.smile.tools.DecisionPathUDF Maven / Gradle / Ivy

The newest version!
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
package hivemall.smile.tools;

import hivemall.UDFWithOptions;
import matrix4j.vector.DenseVector;
import matrix4j.vector.SparseVector;
import matrix4j.vector.Vector;
import hivemall.smile.classification.DecisionTree;
import hivemall.smile.classification.PredictionHandler;
import hivemall.smile.regression.RegressionTree;
import hivemall.utils.codec.Base91;
import hivemall.utils.hadoop.HiveUtils;
import hivemall.utils.lang.Preconditions;
import hivemall.utils.lang.StringUtils;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

import javax.annotation.Nonnull;
import javax.annotation.Nullable;

import org.apache.commons.cli.CommandLine;
import org.apache.commons.cli.Options;
import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.udf.UDFType;
import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.StringObjectInspector;
import org.apache.hadoop.io.Text;

// @formatter:off
@Description(name = "decision_path",
        value = "_FUNC_(string modelId, string model, array features [, const string options] [, optional array featureNames=null, optional array classNames=null])"
                + " - Returns a decision path for each prediction in array",
        extended = "SELECT\n" + 
                "  t.passengerid,\n" + 
                "  decision_path(m.model_id, m.model, t.features, '-classification')\n" + 
                "FROM\n" + 
                "  model_rf m\n" + 
                "  LEFT OUTER JOIN\n" + 
                "  test_rf t;\n" +
                " | 892 | [\"2 [0.0] = 0.0\",\"0 [3.0] = 3.0\",\"1 [696.0] != 107.0\",\"7 [7.8292] <= 7.9104\",\"1 [696.0] != 828.0\",\"1 [696.0] != 391.0\",\"0 [0.961038961038961, 0.03896103896103896]\"] |\n\n" +
                "-- Show 100 frequent branches\n" +
                "WITH tmp as (\n" + 
                "  SELECT\n" + 
                "    decision_path(m.model_id, m.model, t.features, '-classification -no_verbose -no_leaf', array('pclass','name','sex','age','sibsp','parch','ticket','fare','cabin','embarked'), array('no','yes')) as path\n" + 
                "  FROM\n" + 
                "    model_rf m\n" + 
                "    LEFT OUTER JOIN -- CROSS JOIN\n" + 
                "    test_rf t\n" + 
                ")\n" + 
                "select\n" + 
                "  r.branch,\n" + 
                "  count(1) as cnt\n" + 
                "from\n" + 
                "  tmp l\n" + 
                "  LATERAL VIEW explode(l.path) r as branch\n" + 
                "group by\n" + 
                "  r.branch\n" + 
                "order by\n" + 
                "  cnt desc\n" + 
                "limit 100;")
// @formatter:on
@UDFType(deterministic = true, stateful = false)
public final class DecisionPathUDF extends UDFWithOptions {

    private StringObjectInspector modelOI;
    private ListObjectInspector featureListOI;
    private PrimitiveObjectInspector featureElemOI;
    private boolean denseInput;

    // options
    private boolean classification = false;
    private boolean summarize = true;
    private boolean verbose = true;
    private boolean noLeaf = false;

    @Nullable
    private String[] featureNames;
    @Nullable
    private String[] classNames;

    @Nullable
    private transient Vector featuresProbe;

    @Nullable
    private transient Evaluator evaluator;

    @Override
    protected Options getOptions() {
        Options opts = new Options();
        opts.addOption("c", "classification", false,
            "Predict as classification [default: not enabled]");
        opts.addOption("no_sumarize", "disable_summarization", false,
            "Do not summarize decision paths");
        opts.addOption("no_verbose", "disable_verbose_output", false,
            "Disable verbose output [default: verbose]");
        opts.addOption("no_leaf", "disable_leaf_output", false,
            "Show leaf value [default: not enabled]");
        return opts;
    }

    @Override
    protected CommandLine processOptions(@Nonnull String optionValue) throws UDFArgumentException {
        CommandLine cl = parseOptions(optionValue);

        this.classification = cl.hasOption("classification");
        this.summarize = !cl.hasOption("no_sumarize");
        this.verbose = !cl.hasOption("disable_verbose_output");
        this.noLeaf = cl.hasOption("disable_leaf_output");

        return cl;
    }

    @Override
    public ObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException {
        if (argOIs.length < 3 || argOIs.length > 6) {
            showHelp("tree_predict takes 3 ~ 6 arguments");
        }

        this.modelOI = HiveUtils.asStringOI(argOIs[1]);

        ListObjectInspector listOI = HiveUtils.asListOI(argOIs[2]);
        this.featureListOI = listOI;
        ObjectInspector elemOI = listOI.getListElementObjectInspector();
        if (HiveUtils.isNumberOI(elemOI)) {
            this.featureElemOI = HiveUtils.asDoubleCompatibleOI(elemOI);
            this.denseInput = true;
        } else if (HiveUtils.isStringOI(elemOI)) {
            this.featureElemOI = HiveUtils.asStringOI(elemOI);
            this.denseInput = false;
        } else {
            throw new UDFArgumentException(
                "tree_predict takes array or array for the 3rd argument: "
                        + listOI.getTypeName());
        }

        if (argOIs.length >= 4) {
            ObjectInspector argOI3 = argOIs[3];
            if (HiveUtils.isConstString(argOI3)) {
                String opts = HiveUtils.getConstString(argOI3);
                processOptions(opts);
                if (argOIs.length >= 5) {
                    ObjectInspector argOI4 = argOIs[4];
                    if (HiveUtils.isConstStringListOI(argOI4)) {
                        this.featureNames = HiveUtils.getConstStringArray(argOI4);
                        if (argOIs.length >= 6) {
                            ObjectInspector argOI5 = argOIs[5];
                            if (HiveUtils.isConstStringListOI(argOI5)) {
                                if (!classification) {
                                    throw new UDFArgumentException(
                                        "classNames should not be provided for regression");
                                }
                                this.classNames = HiveUtils.getConstStringArray(argOI5);
                            } else {
                                throw new UDFArgumentException(
                                    "decision_path expects 'const array classNames' for the 6th argument: "
                                            + argOI5.getTypeName());
                            }
                        }
                    } else {
                        throw new UDFArgumentException(
                            "decision_path expects 'const array featureNames' for the 5th argument: "
                                    + argOI4.getTypeName());
                    }
                }
            } else if (HiveUtils.isConstStringListOI(argOI3)) {
                this.featureNames = HiveUtils.getConstStringArray(argOI3);
                if (argOIs.length >= 5) {
                    ObjectInspector argOI4 = argOIs[4];
                    if (HiveUtils.isConstStringListOI(argOI4)) {
                        if (!classification) {
                            throw new UDFArgumentException(
                                "classNames should not be provided for regression");
                        }
                        this.classNames = HiveUtils.getConstStringArray(argOI4);
                    } else {
                        throw new UDFArgumentException(
                            "decision_path expects 'const array classNames' for the 5th argument: "
                                    + argOI4.getTypeName());
                    }
                }
            } else {
                throw new UDFArgumentException(
                    "decision_path expects 'const array options' or 'const array featureNames' for the 4th argument: "
                            + argOI3.getTypeName());
            }
        }

        return ObjectInspectorFactory.getStandardListObjectInspector(
            PrimitiveObjectInspectorFactory.javaStringObjectInspector);
    }

    @Override
    public List evaluate(@Nonnull DeferredObject[] arguments) throws HiveException {
        Object arg0 = arguments[0].get();
        if (arg0 == null) {
            throw new HiveException("modelId should not be null");
        }
        // Not using string OI for backward compatibilities
        String modelId = arg0.toString();

        Object arg1 = arguments[1].get();
        if (arg1 == null) {
            return null;
        }
        Text model = modelOI.getPrimitiveWritableObject(arg1);

        Object arg2 = arguments[2].get();
        if (arg2 == null) {
            throw new HiveException("features was null");
        }
        this.featuresProbe = parseFeatures(arg2, featuresProbe);

        if (evaluator == null) {
            this.evaluator = classification ? new ClassificationEvaluator(this)
                    : new RegressionEvaluator(this);
        }
        return evaluator.evaluate(modelId, model, featuresProbe);
    }

    @Nonnull
    private Vector parseFeatures(@Nonnull final Object argObj, @Nullable Vector probe)
            throws UDFArgumentException {
        if (denseInput) {
            final int length = featureListOI.getListLength(argObj);
            if (probe == null) {
                probe = new DenseVector(length);
            } else if (length != probe.size()) {
                probe = new DenseVector(length);
            }

            for (int i = 0; i < length; i++) {
                final Object o = featureListOI.getListElement(argObj, i);
                if (o == null) {
                    probe.set(i, 0.d);
                } else {
                    double v = PrimitiveObjectInspectorUtils.getDouble(o, featureElemOI);
                    probe.set(i, v);
                }
            }
        } else {
            if (probe == null) {
                probe = new SparseVector();
            } else {
                probe.clear();
            }

            final int length = featureListOI.getListLength(argObj);
            for (int i = 0; i < length; i++) {
                Object o = featureListOI.getListElement(argObj, i);
                if (o == null) {
                    continue;
                }
                String col = o.toString();

                final int pos = col.indexOf(':');
                if (pos == 0) {
                    throw new UDFArgumentException("Invalid feature value representation: " + col);
                }

                final String feature;
                final double value;
                if (pos > 0) {
                    feature = col.substring(0, pos);
                    String s2 = col.substring(pos + 1);
                    value = Double.parseDouble(s2);
                } else {
                    feature = col;
                    value = 1.d;
                }

                if (feature.indexOf(':') != -1) {
                    throw new UDFArgumentException(
                        "Invalid feature format `:`: " + col);
                }

                final int colIndex = Integer.parseInt(feature);
                if (colIndex < 0) {
                    throw new UDFArgumentException(
                        "Col index MUST be greater than or equals to 0: " + colIndex);
                }
                probe.set(colIndex, value);
            }
        }
        return probe;
    }

    @Override
    public void close() throws IOException {
        this.modelOI = null;
        this.featureElemOI = null;
        this.featureListOI = null;
        this.featureNames = null;
        this.classNames = null;
        this.featuresProbe = null;
        this.evaluator = null;
    }

    @Override
    public String getDisplayString(String[] children) {
        return "decision_path(" + StringUtils.join(children, ',') + ")";
    }

    interface Evaluator {

        @Nonnull
        List evaluate(@Nonnull String modelId, @Nonnull Text model,
                @Nonnull Vector features) throws HiveException;

    }

    static final class ClassificationEvaluator implements Evaluator {

        @Nullable
        private final String[] featureNames;
        @Nullable
        private final String[] classNames;

        @Nonnull
        private final List result;
        @Nonnull
        private final PredictionHandler handler;

        @Nullable
        private String prevModelId = null;
        private DecisionTree.Node cNode = null;

        ClassificationEvaluator(@Nonnull final DecisionPathUDF udf) {
            this.featureNames = udf.featureNames;
            this.classNames = udf.classNames;

            final StringBuilder buf = new StringBuilder();
            final ArrayList result = new ArrayList<>();
            this.result = result;

            if (udf.summarize) {
                final LinkedHashMap map = new LinkedHashMap<>();

                this.handler = new PredictionHandler() {

                    @Override
                    public void init() {
                        map.clear();
                        result.clear();
                    }

                    @Override
                    public void visitBranch(Operator op, int splitFeatureIndex, double splitFeature,
                            double splitValue) {
                        buf.append(resolveFeatureName(splitFeatureIndex));
                        if (udf.verbose) {
                            buf.append(" [" + splitFeature + "] ");
                        } else {
                            buf.append(' ');
                        }
                        buf.append(op);
                        if (op == Operator.EQ || op == Operator.NE) {
                            buf.append(' ');
                            buf.append(splitValue);
                        }
                        String key = buf.toString();
                        map.put(key, splitValue);
                        StringUtils.clear(buf);
                    }

                    @Override
                    public void visitLeaf(int output, double[] posteriori) {
                        for (Map.Entry e : map.entrySet()) {
                            final String key = e.getKey();
                            if (key.indexOf('<') == -1 && key.indexOf('>') == -1) {
                                result.add(key);
                            } else {
                                double value = e.getValue().doubleValue();
                                result.add(key + ' ' + value);
                            }
                        }
                        if (udf.noLeaf) {
                            return;
                        }

                        if (udf.verbose) {
                            buf.append(resolveClassName(output));
                            buf.append(' ');
                            buf.append(Arrays.toString(posteriori));
                            result.add(buf.toString());
                            StringUtils.clear(buf);
                        } else {
                            result.add(resolveClassName(output));
                        }
                    }

                    @SuppressWarnings("unchecked")
                    @Override
                    public ArrayList getResult() {
                        return result;
                    }

                };
            } else {
                this.handler = new PredictionHandler() {

                    @Override
                    public void init() {
                        result.clear();
                    }

                    @Override
                    public void visitBranch(Operator op, int splitFeatureIndex, double splitFeature,
                            double splitValue) {
                        buf.append(resolveFeatureName(splitFeatureIndex));
                        if (udf.verbose) {
                            buf.append(" [" + splitFeature + "] ");
                        } else {
                            buf.append(' ');
                        }
                        buf.append(op);
                        buf.append(' ');
                        buf.append(splitValue);
                        result.add(buf.toString());
                        StringUtils.clear(buf);
                    }

                    @Override
                    public void visitLeaf(int output, double[] posteriori) {
                        if (udf.noLeaf) {
                            return;
                        }

                        if (udf.verbose) {
                            buf.append(resolveClassName(output));
                            buf.append(' ');
                            buf.append(Arrays.toString(posteriori));
                            result.add(buf.toString());
                            StringUtils.clear(buf);
                        } else {
                            result.add(resolveClassName(output));
                        }
                    }

                    @SuppressWarnings("unchecked")
                    @Override
                    public ArrayList getResult() {
                        return result;
                    }

                };
            }
        }

        @Nonnull
        private String resolveFeatureName(final int splitFeatureIndex) {
            if (featureNames == null) {
                return Integer.toString(splitFeatureIndex);
            } else {
                return featureNames[splitFeatureIndex];
            }
        }

        @Nonnull
        private String resolveClassName(final int classLabel) {
            if (classNames == null) {
                return Integer.toString(classLabel);
            } else {
                return classNames[classLabel];
            }
        }

        @Nonnull
        public List evaluate(@Nonnull final String modelId, @Nonnull final Text script,
                @Nonnull final Vector features) throws HiveException {
            if (!modelId.equals(prevModelId)) {
                this.prevModelId = modelId;
                int length = script.getLength();
                byte[] b = script.getBytes();
                b = Base91.decode(b, 0, length);
                this.cNode = DecisionTree.deserialize(b, b.length, true);
            }
            Preconditions.checkNotNull(cNode);

            handler.init();
            cNode.predict(features, handler);
            return handler.getResult();
        }

    }

    static final class RegressionEvaluator implements Evaluator {

        @Nullable
        private final String[] featureNames;

        @Nonnull
        private final List result;
        @Nonnull
        private final PredictionHandler handler;

        @Nullable
        private String prevModelId = null;
        private RegressionTree.Node rNode = null;

        RegressionEvaluator(@Nonnull final DecisionPathUDF udf) {
            this.featureNames = udf.featureNames;

            final StringBuilder buf = new StringBuilder();
            final ArrayList result = new ArrayList<>();
            this.result = result;

            if (udf.summarize) {
                final LinkedHashMap map = new LinkedHashMap<>();

                this.handler = new PredictionHandler() {

                    @Override
                    public void init() {
                        map.clear();
                        result.clear();
                    }

                    @Override
                    public void visitBranch(Operator op, int splitFeatureIndex, double splitFeature,
                            double splitValue) {
                        buf.append(resolveFeatureName(splitFeatureIndex));
                        if (udf.verbose) {
                            buf.append(" [" + splitFeature + "] ");
                        } else {
                            buf.append(' ');
                        }
                        buf.append(op);
                        if (op == Operator.EQ || op == Operator.NE) {
                            buf.append(' ');
                            buf.append(splitValue);
                        }
                        String key = buf.toString();
                        map.put(key, splitValue);
                        StringUtils.clear(buf);
                    }

                    @Override
                    public void visitLeaf(double output) {
                        for (Map.Entry e : map.entrySet()) {
                            final String key = e.getKey();
                            if (key.indexOf('<') == -1 && key.indexOf('>') == -1) {
                                result.add(key);
                            } else {
                                double value = e.getValue().doubleValue();
                                result.add(key + ' ' + value);
                            }
                        }
                        if (udf.noLeaf) {
                            return;
                        }

                        result.add(Double.toString(output));
                    }

                    @SuppressWarnings("unchecked")
                    @Override
                    public ArrayList getResult() {
                        return result;
                    }

                };
            } else {
                this.handler = new PredictionHandler() {

                    @Override
                    public void init() {
                        result.clear();
                    }

                    @Override
                    public void visitBranch(Operator op, int splitFeatureIndex, double splitFeature,
                            double splitValue) {
                        buf.append(resolveFeatureName(splitFeatureIndex));
                        if (udf.verbose) {
                            buf.append(" [" + splitFeature + "] ");
                        }
                        buf.append(op);
                        buf.append(' ');
                        buf.append(splitValue);
                        result.add(buf.toString());
                        StringUtils.clear(buf);
                    }

                    @Override
                    public void visitLeaf(double output) {
                        if (udf.noLeaf) {
                            return;
                        }

                        result.add(Double.toString(output));
                    }

                    @SuppressWarnings("unchecked")
                    @Override
                    public ArrayList getResult() {
                        return result;
                    }

                };
            }
        }

        @Nonnull
        private String resolveFeatureName(final int splitFeatureIndex) {
            if (featureNames == null) {
                return Integer.toString(splitFeatureIndex);
            } else {
                return featureNames[splitFeatureIndex];
            }
        }

        @Nonnull
        public List evaluate(@Nonnull final String modelId, @Nonnull final Text script,
                @Nonnull final Vector features) throws HiveException {
            if (!modelId.equals(prevModelId)) {
                this.prevModelId = modelId;
                int length = script.getLength();
                byte[] b = script.getBytes();
                b = Base91.decode(b, 0, length);
                this.rNode = RegressionTree.deserialize(b, b.length, true);
            }
            Preconditions.checkNotNull(rNode);

            handler.init();
            rNode.predict(features, handler);
            return handler.getResult();
        }
    }

}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy