All Downloads are FREE. Search and download functionalities are using the official Maven repository.

hivemall.ftvec.conv.ToLibSVMFormatUDF Maven / Gradle / Ivy

The newest version!
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
package hivemall.ftvec.conv;

import hivemall.UDFWithOptions;
import hivemall.utils.hadoop.HiveUtils;
import hivemall.utils.hashing.MurmurHash3;
import hivemall.utils.lang.NumberUtils;
import hivemall.utils.lang.Primitives;
import hivemall.utils.lang.StringUtils;
import hivemall.utils.struct.Pair;

import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;

import javax.annotation.Nonnegative;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;

import org.apache.commons.cli.CommandLine;
import org.apache.commons.cli.Options;
import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.udf.UDFType;
import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;

// @formatter:off
@Description(name = "to_libsvm_format",
        value = "_FUNC_(array feautres [, double/integer target, const string options])"
                + " - Returns a string representation of libsvm",
                extended = "Usage:\n" + 
                        " select to_libsvm_format(array('apple:3.4','orange:2.1'))\n" + 
                        " > 6284535:3.4 8104713:2.1\n" + 
                        " select to_libsvm_format(array('apple:3.4','orange:2.1'), '-features 10')\n" + 
                        " > 3:2.1 7:3.4\n" + 
                        " select to_libsvm_format(array('7:3.4','3:2.1'), 5.0)\n" + 
                        " > 5.0 3:2.1 7:3.4")
// @formatter:on
@UDFType(deterministic = true, stateful = false)
public final class ToLibSVMFormatUDF extends UDFWithOptions {

    private ListObjectInspector _featuresOI;
    @Nullable
    private PrimitiveObjectInspector _targetOI = null;
    private int _numFeatures = MurmurHash3.DEFAULT_NUM_FEATURES;
    private StringBuilder _buf;

    @Override
    protected Options getOptions() {
        Options opts = new Options();
        opts.addOption("features", "num_features", true,
            "The number of features [default: 16777217 (2^24)]");
        return opts;
    }

    @Override
    protected CommandLine processOptions(@Nonnull String optionValue) throws UDFArgumentException {
        CommandLine cl = parseOptions(optionValue);
        this._numFeatures = Primitives.parseInt(cl.getOptionValue("num_features"),
            MurmurHash3.DEFAULT_NUM_FEATURES);
        assumeTrue(_numFeatures > 0, "num_features must be greater than 0: " + _numFeatures);
        return cl;
    }

    @Override
    public ObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException {
        assumeTrue(argOIs.length >= 1 || argOIs.length <= 3,
            "to_libsvm_format UDF takes 1~3 arguments");

        this._featuresOI = HiveUtils.asListOI(argOIs[0]);
        if (argOIs.length == 2) {
            ObjectInspector argOI1 = argOIs[1];
            if (HiveUtils.isNumberOI(argOI1)) {
                this._targetOI = HiveUtils.asNumberOI(argOI1);
            } else if (HiveUtils.isConstString(argOI1)) { // no target
                String opts = HiveUtils.getConstString(argOI1);
                processOptions(opts);
            } else {
                throw new UDFArgumentException(
                    "Unexpected argument type for 2nd argument: " + argOI1.getTypeName());
            }
        } else if (argOIs.length == 3) {
            this._targetOI = HiveUtils.asNumberOI(argOIs[1]);
            String opts = HiveUtils.getConstString(argOIs[2]);
            processOptions(opts);
        }

        this._buf = new StringBuilder();

        return PrimitiveObjectInspectorFactory.javaStringObjectInspector;
    }

    @Nullable
    @Override
    public String evaluate(DeferredObject[] args) throws HiveException {
        final StringBuilder buf = this._buf;
        StringUtils.clear(buf);

        Object arg0 = args[0].get();
        if (arg0 == null) {
            return null;
        }

        final int featureSize = _featuresOI.getListLength(arg0);
        List> features = new ArrayList<>(featureSize);
        for (int i = 0; i < featureSize; i++) {
            Object e = _featuresOI.getListElement(arg0, i);
            if (e == null) {
                continue;
            }
            Pair fv = parse(e.toString(), _numFeatures);
            features.add(fv);
        }
        Collections.sort(features, comparator);

        if (_targetOI != null) {
            Object arg1 = args[1].get();
            if (arg1 == null) {
                throw new HiveException("Detected NULL for the 2nd argument");
            }
            if (HiveUtils.isIntegerOI(_targetOI)) {
                int label = HiveUtils.getInt(arg1, _targetOI);
                buf.append(label);
            } else {
                double label = HiveUtils.getDouble(arg1, _targetOI);
                buf.append(label);
            }
            buf.append(' ');
        }
        for (int i = 0, size = features.size(); i < size; i++) {
            if (i != 0) {
                buf.append(' ');
            }
            Pair fv = features.get(i);
            buf.append(fv.getKey().intValue());
            buf.append(':');
            buf.append(fv.getValue().doubleValue());
        }

        return buf.toString();
    }

    @Nonnull
    public static Pair parse(@Nonnull final String fv,
            @Nonnegative final int numFeatures) throws UDFArgumentException {
        final int headPos = fv.indexOf(':');
        if (headPos == -1) {
            if (NumberUtils.isDigits(fv)) {
                final int f;
                try {
                    f = Integer.parseInt(fv);
                } catch (NumberFormatException e) {
                    throw new UDFArgumentException("Invalid feature value: " + fv);
                }
                return new Pair<>(f, 1.d);
            } else {
                return new Pair<>(mhash(fv, numFeatures), 1.d);
            }
        } else {
            final int tailPos = fv.lastIndexOf(':');
            if (headPos != tailPos) {
                throw new UDFArgumentException("Unsupported feature format: " + fv);
            }
            String f = fv.substring(0, headPos);
            String v = fv.substring(headPos + 1);
            final double d;
            try {
                d = Double.parseDouble(v);
            } catch (NumberFormatException e) {
                throw new UDFArgumentException("Invalid feature value: " + fv);
            }
            if (NumberUtils.isDigits(f)) {
                final int i;
                try {
                    i = Integer.parseInt(f);
                } catch (NumberFormatException e) {
                    throw new UDFArgumentException("Invalid feature value: " + fv);
                }
                return new Pair<>(i, d);
            } else {
                return new Pair<>(mhash(f, numFeatures), d);
            }
        }
    }

    private static int mhash(@Nonnull final String word, final int numFeatures) {
        int r = MurmurHash3.murmurhash3_x86_32(word, 0, word.length(), 0x9747b28c) % numFeatures;
        if (r < 0) {
            r += numFeatures;
        }
        return r + 1;
    }

    private static final Comparator> comparator =
            new Comparator>() {
                @Override
                public int compare(Pair l, Pair r) {
                    return l.getKey().compareTo(r.getKey());
                }
            };

    @Override
    public String getDisplayString(String[] args) {
        return "to_libsvm_format( " + StringUtils.join(args, ',') + " )";
    }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy