All downloads are FREE. The search and download functionality uses the official Maven repository.

hivemall.smile.tools.TreeExportUDF Maven / Gradle / Ivy

The newest version!
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
package hivemall.smile.tools;

import hivemall.UDFWithOptions;
import hivemall.smile.classification.DecisionTree;
import hivemall.smile.regression.RegressionTree;
import hivemall.smile.utils.SmileExtUtils;
import hivemall.utils.codec.Base91;
import hivemall.utils.hadoop.HiveUtils;
import hivemall.utils.lang.mutable.MutableInt;

import java.util.Arrays;

import javax.annotation.Nonnull;
import javax.annotation.Nullable;

import org.apache.commons.cli.CommandLine;
import org.apache.commons.cli.Options;
import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.udf.UDFType;
import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.StringObjectInspector;
import org.apache.hadoop.io.Text;

/**
 * Hive UDF that renders a serialized decision/regression tree model as text, either as
 * JavaScript source or as a Graphviz (dot) digraph.
 *
 * <p>
 * Arguments: the Base91-encoded model string, a constant option string, and optionally an array
 * of feature names and an array of class names (both must be arrays of strings).
 */
@Description(name = "tree_export",
        value = "_FUNC_(string model, const string options, optional array<string> featureNames=null, optional array<string> classNames=null)"
                + " - exports a Decision Tree model as javascript/dot")
@UDFType(deterministic = true, stateful = false)
public final class TreeExportUDF extends UDFWithOptions {

    /** Output type used when the {@code -type} option is not given (matches the help text). */
    private static final String DEFAULT_OUTPUT_TYPE = "js";
    /** Output name used when the {@code -output_name} option is not given. */
    private static final String DEFAULT_OUTPUT_NAME = "predicted";

    /** Built from the option string in {@link #processOptions(String)}. */
    private transient Evaluator evaluator;

    private transient StringObjectInspector modelOI;
    /** Set only when a third argument (feature names) is supplied. */
    @Nullable
    private transient ListObjectInspector featureNamesOI;
    /** Set only when a fourth argument (class names) is supplied. */
    @Nullable
    private transient ListObjectInspector classNamesOI;

    /**
     * Declares the options understood by this UDF: {@code -t/-type}, {@code -r/-regression} and
     * {@code -output_name/-outputName}.
     */
    @Override
    protected Options getOptions() {
        Options opts = new Options();
        opts.addOption("t", "type", true,
            "Type of output [default: js, javascript/js, graphviz/dot]");
        opts.addOption("r", "regression", false, "Is regression tree or not");
        opts.addOption("output_name", "outputName", true, "output name [default: predicted]");
        return opts;
    }

    /**
     * Parses the option string and creates the {@link Evaluator} used by
     * {@link #evaluate(DeferredObject[])}.
     *
     * @param opts option string passed as the second UDF argument
     * @return the parsed command line
     * @throws UDFArgumentException if the {@code -type} value is not a recognized output type
     */
    @Override
    protected CommandLine processOptions(@Nonnull String opts) throws UDFArgumentException {
        CommandLine cl = parseOptions(opts);

        // Without an explicit default, an omitted -type yields null and resolve() rejects it,
        // contradicting the documented "default: js".
        String typeName = cl.getOptionValue("type", DEFAULT_OUTPUT_TYPE);
        OutputType outputType = OutputType.resolve(typeName);
        boolean regression = cl.hasOption("regression");
        String outputName = cl.getOptionValue("output_name", DEFAULT_OUTPUT_NAME);
        this.evaluator = new Evaluator(outputType, outputName, regression);

        return cl;
    }

    /**
     * Validates the argument inspectors, processes the constant option string and records the
     * inspectors needed at evaluation time.
     *
     * @param argOIs inspectors for 2 to 4 arguments: model, options, [featureNames, [classNames]]
     * @return the writable string object inspector describing the UDF result
     * @throws UDFArgumentException if featureNames/classNames are not arrays of strings, or if
     *         option processing fails
     */
    @Override
    public ObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException {
        final int argLen = argOIs.length;
        if (argLen < 2 || argLen > 4) {
            // NOTE(review): showHelp is inherited; presumably it raises a UDFArgumentException
            // so execution does not continue past this point -- confirm in UDFWithOptions.
            showHelp("tree_export UDF takes 2~4 arguments: " + argLen);
        }

        this.modelOI = HiveUtils.asStringOI(argOIs, 0);

        String options = HiveUtils.getConstString(argOIs, 1);
        processOptions(options);

        if (argLen >= 3) {
            this.featureNamesOI = HiveUtils.asListOI(argOIs, 2);
            // The argument is already known to be a list here; what is verified is the
            // element type, hence "array<string>" in the message.
            if (!HiveUtils.isStringOI(featureNamesOI.getListElementObjectInspector())) {
                throw new UDFArgumentException("_FUNC_ expected array<string> for featureNames: "
                        + featureNamesOI.getTypeName());
            }
            if (argLen == 4) {
                this.classNamesOI = HiveUtils.asListOI(argOIs, 3);
                if (!HiveUtils.isStringOI(classNamesOI.getListElementObjectInspector())) {
                    throw new UDFArgumentException("_FUNC_ expected array<string> for classNames: "
                            + classNamesOI.getTypeName());
                }
            }
        }

        return PrimitiveObjectInspectorFactory.writableStringObjectInspector;
    }

    /**
     * Exports the model given as the first argument.
     *
     * @param arguments model, options, and optionally featureNames and classNames
     * @return the exported text, or {@code null} when the model argument is {@code null}
     * @throws HiveException if the export fails; non-Hive failures are wrapped
     */
    @Override
    public Object evaluate(DeferredObject[] arguments) throws HiveException {
        Object arg0 = arguments[0].get();
        if (arg0 == null) {
            return null;
        }
        Text model = modelOI.getPrimitiveWritableObject(arg0);

        String[] featureNames = null, classNames = null;
        if (arguments.length >= 3) {
            featureNames = HiveUtils.asStringArray(arguments[2], featureNamesOI);
            if (arguments.length >= 4) {
                classNames = HiveUtils.asStringArray(arguments[3], classNamesOI);
            }
        }

        try {
            return evaluator.export(model, featureNames, classNames);
        } catch (HiveException he) {
            // Already the declared exception type; propagate unchanged.
            throw he;
        } catch (Throwable e) {
            // Anything else is wrapped so the cause is preserved for the caller.
            throw new HiveException(e);
        }
    }

    /** Returns {@code tree_export(...)} with the children rendered by {@link Arrays#toString}. */
    @Override
    public String getDisplayString(String[] children) {
        return "tree_export(" + Arrays.toString(children) + ")";
    }

    /** Supported export formats. */
    public enum OutputType {
        javascript, graphviz;

        /**
         * Maps an option value to an output type, ignoring case.
         *
         * <p>
         * {@code js} and {@code javascript} map to {@link #javascript}; {@code dot},
         * {@code graphviz} and the legacy spelling {@code graphvis} map to {@link #graphviz}.
         *
         * @param name value of the {@code -type} option
         * @return the matching output type
         * @throws UDFArgumentException if {@code name} matches none of the accepted values
         */
        @Nonnull
        public static OutputType resolve(@Nonnull String name) throws UDFArgumentException {
            if ("js".equalsIgnoreCase(name) || "javascript".equalsIgnoreCase(name)) {
                return javascript;
            } else if ("dot".equalsIgnoreCase(name) || "graphviz".equalsIgnoreCase(name)
                    || "graphvis".equalsIgnoreCase(name)) { // "graphvis" for backward compatibility (HIVEMALL-192)
                return graphviz;
            } else {
                throw new UDFArgumentException(
                    "Please provide a valid `-type` option from [javascript, graphviz]: " + name);
            }
        }
    }

    /**
     * Decodes a serialized tree model and renders it in the configured output format.
     */
    public static class Evaluator {

        /** Preamble shared by the classifier and regressor Graphviz exports. */
        private static final String GRAPHVIZ_HEADER =
                "digraph Tree {\n node [shape=box, style=\"filled, rounded\", color=\"black\", fontname=helvetica];\n edge [fontname=helvetica];\n";

        @Nonnull
        private final OutputType outputType;
        @Nonnull
        private final String outputName;
        private final boolean regression;

        /**
         * @param outputType format to export to
         * @param outputName name handed to the Graphviz export of the tree nodes
         * @param regression {@code true} to treat the model as a regression tree, {@code false}
         *        to treat it as a classification tree
         */
        public Evaluator(@Nonnull OutputType outputType, @Nonnull String outputName,
                boolean regression) {
            this.outputType = outputType;
            this.outputName = outputName;
            this.regression = regression;
        }

        /**
         * Base91-decodes the model and exports it as a regressor or classifier depending on
         * the {@code regression} flag.
         *
         * @param model Base91-encoded serialized tree
         * @param featureNames optional feature names passed to the node export
         * @param classNames optional class names; only used for classification trees
         * @return the exported text
         * @throws HiveException if the export fails
         */
        @Nonnull
        public Text export(@Nonnull Text model, @Nullable String[] featureNames,
                @Nullable String[] classNames) throws HiveException {
            // Only the first getLength() bytes of the array returned by getBytes() are decoded.
            int length = model.getLength();
            byte[] b = model.getBytes();
            b = Base91.decode(b, 0, length);

            final String exported;
            if (regression) {
                exported = exportRegressor(b, featureNames);
            } else {
                exported = exportClassifier(b, featureNames, classNames);
            }
            return new Text(exported);
        }

        /**
         * Deserializes a classification tree and renders it in the configured format.
         *
         * @throws HiveException if the output type is not handled
         */
        @Nonnull
        private String exportClassifier(@Nonnull byte[] b, @Nullable String[] featureNames,
                @Nullable String[] classNames) throws HiveException {
            final DecisionTree.Node node = DecisionTree.deserialize(b, b.length, true);

            final StringBuilder buf = new StringBuilder(8192);
            switch (outputType) {
                case javascript: {
                    node.exportJavascript(buf, featureNames, classNames, 0);
                    break;
                }
                case graphviz: {
                    buf.append(GRAPHVIZ_HEADER);
                    // A color table is only requested when class names are available.
                    double[] colorBrew = (classNames == null) ? null
                            : SmileExtUtils.getColorBrew(classNames.length);
                    node.exportGraphviz(buf, featureNames, classNames, outputName, colorBrew,
                        new MutableInt(0), 0);
                    buf.append("}");
                    break;
                }
                default:
                    throw new HiveException("Unsupported outputType: " + outputType);
            }
            return buf.toString();
        }

        /**
         * Deserializes a regression tree and renders it in the configured format.
         *
         * @throws HiveException if the output type is not handled
         */
        @Nonnull
        private String exportRegressor(@Nonnull byte[] b, @Nullable String[] featureNames)
                throws HiveException {
            final RegressionTree.Node node = RegressionTree.deserialize(b, b.length, true);

            final StringBuilder buf = new StringBuilder(8192);
            switch (outputType) {
                case javascript: {
                    node.exportJavascript(buf, featureNames, 0);
                    break;
                }
                case graphviz: {
                    buf.append(GRAPHVIZ_HEADER);
                    node.exportGraphviz(buf, featureNames, outputName, new MutableInt(0), 0);
                    buf.append("}");
                    break;
                }
                default:
                    throw new HiveException("Unsupported outputType: " + outputType);
            }
            return buf.toString();
        }

    }

}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy