All Downloads are FREE. Search and download functionalities are using the official Maven repository.

hivemall.ftvec.binning.FeatureBinningUDF Maven / Gradle / Ivy

The newest version!
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
package hivemall.ftvec.binning;

import hivemall.annotations.VisibleForTesting;
import hivemall.utils.hadoop.HiveUtils;
import hivemall.utils.lang.StringUtils;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import javax.annotation.Nonnull;

import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.exec.UDFArgumentLengthException;
import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.udf.UDFType;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDF;
import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.MapObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.StringObjectInspector;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.Text;

/**
 * Hive UDF {@code feature_binning} that maps quantitative values to bin IDs using quantile
 * boundaries.
 *
 * <p>Two call forms are accepted, selected in {@link #initialize(ObjectInspector[])}:
 * <ul>
 * <li>{@code feature_binning(array<string> features, map<string, array<number>> quantiles_map)}
 * returns an array of strings in which every {@code name:value} element whose {@code name} has
 * an entry in {@code quantiles_map} is rewritten to {@code name:binId}; other elements are
 * passed through.</li>
 * <li>{@code feature_binning(number weight, array<number> quantiles)} returns the bin ID of
 * {@code weight} as an int.</li>
 * </ul>
 *
 * <p>The quantiles argument is converted once, on the first row that reaches the conversion,
 * and the converted value is reused for every later row; the value passed on later rows is not
 * inspected again.
 */
// @formatter:off
@Description(name = "feature_binning",
        value = "_FUNC_(array<string> features, map<string, array<number>> quantiles_map)"
                + " - returns a binned feature vector as an array<string>\n"
                + "_FUNC_(number weight, array<number> quantiles) - returns bin ID as int",
                extended = "WITH extracted as (\n" +
                        "  select \n" +
                        "    extract_feature(feature) as index,\n" +
                        "    extract_weight(feature) as value\n" +
                        "  from\n" +
                        "    input l\n" +
                        "    LATERAL VIEW explode(features) r as feature\n" +
                        "),\n" +
                        "mapping as (\n" +
                        "  select\n" +
                        "    index, \n" +
                        "    build_bins(value, 5, true) as quantiles -- 5 bins with auto bin shrinking\n" +
                        "  from\n" +
                        "    extracted\n" +
                        "  group by\n" +
                        "    index\n" +
                        "),\n" +
                        "bins as (\n" +
                        "   select \n" +
                        "    to_map(index, quantiles) as quantiles \n" +
                        "   from\n" +
                        "    mapping\n" +
                        ")\n" +
                        "select\n" +
                        "  l.features as original,\n" +
                        "  feature_binning(l.features, r.quantiles) as features\n" +
                        "from\n" +
                        "  input l\n" +
                        "  cross join bins r\n\n" +
                        "> [\"name#Jacob\",\"gender#Male\",\"age:20.0\"] [\"name#Jacob\",\"gender#Male\",\"age:2\"]\n" +
                        "> [\"name#Isabella\",\"gender#Female\",\"age:20.0\"]    [\"name#Isabella\",\"gender#Female\",\"age:2\"]")
// @formatter:on
@UDFType(deterministic = true, stateful = false)
public final class FeatureBinningUDF extends GenericUDF {
    /** {@code true} for the (features, quantiles_map) form, {@code false} for (weight, quantiles). */
    private boolean multiple = true;

    // Inspectors for the (features, quantiles_map) form.
    private ListObjectInspector featuresOI;
    private StringObjectInspector featureOI;
    private MapObjectInspector quantilesMapOI;
    private StringObjectInspector keyOI;
    // Inspectors shared by both forms (list of quantiles and its elements).
    private ListObjectInspector quantilesOI;
    private PrimitiveObjectInspector quantileOI;
    // Inspector for the (weight, quantiles) form.
    private PrimitiveObjectInspector weightOI;

    /**
     * Validates the argument types, records the object inspectors needed by
     * {@link #evaluate(DeferredObject[])} and selects the call form.
     *
     * @param argOIs inspectors of the two call arguments
     * @return a list-of-writable-string inspector for the vector form, or a writable-int
     *         inspector for the scalar form
     * @throws UDFArgumentException if the number of arguments is not two or their types match
     *         neither call form
     */
    @Override
    public ObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException {
        if (argOIs.length != 2) {
            throw new UDFArgumentLengthException("Specify two arguments :" + argOIs.length);
        }

        if (HiveUtils.isListOI(argOIs[0]) && HiveUtils.isMapOI(argOIs[1])) {
            // feature_binning(array<string> features, map<string, array<number>> quantiles_map)

            if (!HiveUtils.isStringOI(
                ((ListObjectInspector) argOIs[0]).getListElementObjectInspector())) {
                throw new UDFArgumentTypeException(0,
                    "Only array<string> type argument can be accepted but "
                            + argOIs[0].getTypeName() + " was passed as `features`");
            }
            featuresOI = HiveUtils.asListOI(argOIs[0]);
            featureOI = HiveUtils.asStringOI(featuresOI.getListElementObjectInspector());

            quantilesMapOI = HiveUtils.asMapOI(argOIs[1]);
            // The cast below is only reached when the map value is known to be a list,
            // because `||` short-circuits on the preceding isListOI check.
            if (!HiveUtils.isStringOI(quantilesMapOI.getMapKeyObjectInspector())
                    || !HiveUtils.isListOI(quantilesMapOI.getMapValueObjectInspector())
                    || !HiveUtils.isNumberOI(
                        ((ListObjectInspector) quantilesMapOI.getMapValueObjectInspector()).getListElementObjectInspector())) {
                throw new UDFArgumentTypeException(1,
                    "Only map<string, array<number>> type argument can be accepted but "
                            + argOIs[1].getTypeName() + " was passed as `quantiles_map`");
            }
            keyOI = HiveUtils.asStringOI(quantilesMapOI.getMapKeyObjectInspector());
            quantilesOI = HiveUtils.asListOI(quantilesMapOI.getMapValueObjectInspector());
            quantileOI =
                    HiveUtils.asDoubleCompatibleOI(quantilesOI.getListElementObjectInspector());

            multiple = true;

            return ObjectInspectorFactory.getStandardListObjectInspector(
                PrimitiveObjectInspectorFactory.writableStringObjectInspector);
        } else if (HiveUtils.isPrimitiveOI(argOIs[0]) && HiveUtils.isListOI(argOIs[1])) {
            // feature_binning(number weight, array<number> quantiles)

            weightOI = HiveUtils.asDoubleCompatibleOI(argOIs[0]);

            quantilesOI = HiveUtils.asListOI(argOIs[1]);
            if (!HiveUtils.isNumberOI(quantilesOI.getListElementObjectInspector())) {
                throw new UDFArgumentTypeException(1,
                    "Only array<number> type argument can be accepted but "
                            + argOIs[1].getTypeName() + " was passed as `quantiles`");
            }
            quantileOI =
                    HiveUtils.asDoubleCompatibleOI(quantilesOI.getListElementObjectInspector());

            multiple = false;

            return PrimitiveObjectInspectorFactory.writableIntObjectInspector;
        } else {
            throw new UDFArgumentTypeException(0,
                "Only <array<string>, map<string, array<number>>> "
                        + "or <number, array<number>> type arguments can be accepted but <"
                        + argOIs[0].getTypeName() + ", " + argOIs[1].getTypeName()
                        + "> was passed.");
        }
    }

    /** Feature name to quantile boundaries; converted lazily on first use, then reused. */
    private transient Map<String, double[]> quantilesMap;
    /** Quantile boundaries for the scalar form; converted lazily on first use, then reused. */
    private transient double[] quantilesArray;

    /**
     * Bins either a whole feature vector or a single weight, depending on the call form chosen
     * in {@link #initialize(ObjectInspector[])}.
     *
     * @param args the two deferred call arguments
     * @return {@code null} when the first argument is null; otherwise a {@code List<Text>} for
     *         the vector form or an {@link IntWritable} bin ID for the scalar form
     * @throws HiveException if the second argument is null, a quantile array has fewer than
     *         three elements, or a value to be binned cannot be parsed as a number
     */
    @Override
    public Object evaluate(DeferredObject[] args) throws HiveException {
        final Object arg0 = args[0].get();
        if (arg0 == null) {
            return null;
        }
        final Object arg1 = args[1].get();
        if (arg1 == null) {
            throw new UDFArgumentException(
                "The second argument (i.e., quantiles) MUST be non-null value");
        }

        if (multiple) {
            if (quantilesMap == null) {
                // Assign only after a complete conversion so that a failure part-way through
                // never leaves a partially filled map cached for later rows.
                quantilesMap = toQuantilesMap(arg1);
            }
            return binFeatures(arg0);
        }

        if (quantilesArray == null) {
            quantilesArray = HiveUtils.asDoubleArray(arg1, quantilesOI, quantileOI);
        }
        return new IntWritable(
            findBin(quantilesArray, PrimitiveObjectInspectorUtils.getDouble(arg0, weightOI)));
    }

    /** Converts the Hive map argument into a Java map of feature name to quantile boundaries. */
    @Nonnull
    private Map<String, double[]> toQuantilesMap(@Nonnull final Object arg)
            throws HiveException {
        final Map<?, ?> map = quantilesMapOI.getMap(arg);
        final Map<String, double[]> converted = new HashMap<String, double[]>(map.size() * 2);
        for (Map.Entry<?, ?> e : map.entrySet()) {
            final String k = keyOI.getPrimitiveJavaObject(e.getKey());
            final double[] v = HiveUtils.asDoubleArray(e.getValue(), quantilesOI, quantileOI);
            converted.put(k, v);
        }
        return converted;
    }

    /** Rewrites each {@code name:value} element that has bins for {@code name} to {@code name:binId}. */
    @Nonnull
    private List<Text> binFeatures(@Nonnull final Object arg) throws HiveException {
        final List<?> features = featuresOI.getList(arg);
        final List<Text> result = new ArrayList<Text>(features.size());
        for (Object f : features) {
            final String entry = featureOI.getPrimitiveJavaObject(f);
            if (entry == null) {
                // A null array element carries no feature; skip it instead of failing with a
                // NullPointerException on entry.indexOf below.
                continue;
            }

            final int pos = entry.indexOf(':');
            if (pos < 0) { // categorical: no value part, pass through unchanged
                result.add(new Text(entry));
                continue;
            }

            // quantitative
            final String k = entry.substring(0, pos);
            String v = entry.substring(pos + 1);
            final double[] bins = quantilesMap.get(k);
            if (bins != null) { // binning; features without bins keep their original value
                final double value;
                try {
                    value = Double.parseDouble(v);
                } catch (NumberFormatException nfe) {
                    throw new HiveException(
                        "Failed to parse the value of feature `" + entry + "` as a number", nfe);
                }
                v = String.valueOf(findBin(bins, value));
            }
            result.add(new Text(k + ':' + v));
        }
        return result;
    }

    /**
     * Returns the bin ID of {@code value} for the given ascending quantile boundaries.
     *
     * <p>A value strictly between {@code quantiles[i]} and {@code quantiles[i + 1]}, or equal
     * to {@code quantiles[i + 1]}, falls in bin {@code i}; a value equal to
     * {@code quantiles[0]} falls in bin {@code 0}. A value smaller than {@code quantiles[0]}
     * yields {@code -1}, and a value larger than the last boundary yields
     * {@code quantiles.length - 1}; neither case is rejected here.
     *
     * @param quantiles boundaries, which must be sorted ascending for the binary search to be
     *        meaningful; at least three elements
     * @param value the value to bin
     * @return the bin ID
     * @throws HiveException if {@code quantiles} has fewer than three elements
     */
    @VisibleForTesting
    static int findBin(@Nonnull final double[] quantiles, final double value) throws HiveException {
        if (quantiles.length < 3) {
            throw new HiveException(
                "Length of `quantiles` should be greater than or equal to three but "
                        + quantiles.length + ".");
        }

        // binarySearch returns the index on an exact hit, or -(insertionPoint) - 1 otherwise,
        // so ~pos recovers the insertion point.
        final int pos = Arrays.binarySearch(quantiles, value);
        return (pos < 0) ? ~pos - 1 : (pos == 0) ? 0 : pos - 1;
    }

    @Override
    public String getDisplayString(String[] children) {
        return "feature_binning(" + StringUtils.join(children, ',') + ')';
    }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy