hivemall.ftvec.trans.OnehotEncodingUDAF

/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
package hivemall.ftvec.trans;

import hivemall.utils.hadoop.HiveUtils;
import hivemall.utils.hadoop.WritableUtils;
import hivemall.utils.lang.Identifier;
import hivemall.utils.lang.Preconditions;

import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Set;

import javax.annotation.CheckForNull;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;

import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.parse.SemanticException;
import org.apache.hadoop.hive.ql.udf.UDFType;
import org.apache.hadoop.hive.ql.udf.generic.AbstractGenericUDAFResolver;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator.AbstractAggregationBuffer;
import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.MapObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils.ObjectInspectorCopyOption;
import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.StructField;
import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo;
import org.apache.hadoop.io.Writable;

//@formatter:off
@Description(name = "onehot_encoding",
        value = "_FUNC_(PRIMITIVE feature, ...) - Compute onehot encoded label for each feature",
        extended = "WITH mapping as (\n" + 
                "  select \n" + 
                "    m.f1, m.f2 \n" + 
                "  from (\n" + 
                "    select onehot_encoding(species, category) m\n" + 
                "    from test\n" + 
                "  ) tmp\n" + 
                ")\n" + 
                "select\n" + 
                "  array(m.f1[t.species],m.f2[t.category],feature('count',count)) as sparse_features\n" + 
                "from\n" + 
                "  test t\n" + 
                "  CROSS JOIN mapping m;\n" + 
                "\n" + 
                "[\"2\",\"8\",\"count:9\"]\n" + 
                "[\"5\",\"8\",\"count:10\"]\n" + 
                "[\"1\",\"6\",\"count:101\"]")
@UDFType(deterministic = true, stateful = true)
//@formatter:on
public final class OnehotEncodingUDAF extends AbstractGenericUDAFResolver {

    public OnehotEncodingUDAF() {
        super();
    }

    @Override
    public GenericUDAFEvaluator getEvaluator(@Nonnull TypeInfo[] argTypes)
            throws SemanticException {
        final int numFeatures = argTypes.length;
        if (numFeatures == 0) {
            throw new UDFArgumentException("_FUNC_ requires at least 1 argument");
        }
        for (int i = 0; i < numFeatures; i++) {
            if (argTypes[i] == null) {
                throw new UDFArgumentTypeException(i,
                    "Null type is found. Only primitive type arguments are accepted.");
            }
            if (argTypes[i].getCategory() != ObjectInspector.Category.PRIMITIVE) {
                throw new UDFArgumentTypeException(i,
                    "Only primitive type arguments are accepted but " + argTypes[i].getTypeName()
                            + " was passed as parameter 1.");
            }
        }

        return new GenericUDAFOnehotEncodingEvaluator();
    }

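    /**
     * Two-phase evaluator: each map task collects the distinct values observed for every
     * feature column; the final phase assigns each distinct value an integer label that is
     * unique across all features.
     */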
    public static final class GenericUDAFOnehotEncodingEvaluator extends GenericUDAFEvaluator {

        // input OI
        private PrimitiveObjectInspector[] inputElemOIs;
        // merge input OI
        private StructObjectInspector mergeOI;
        private StructField[] fields;
        private ListObjectInspector[] fieldOIs;

        public GenericUDAFOnehotEncodingEvaluator() {}

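        // PARTIAL1 and COMPLETE receive the raw feature columns; PARTIAL2 and FINAL receive
        // the struct<f0:array<...>, f1:array<...>, ...> emitted by terminatePartial().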
        @Override
        public ObjectInspector init(Mode m, ObjectInspector[] argOIs) throws HiveException {
            super.init(m, argOIs);

            // initialize input
            if (m == Mode.PARTIAL1 || m == Mode.COMPLETE) {// from original data
                this.inputElemOIs = new PrimitiveObjectInspector[argOIs.length];
                for (int i = 0; i < argOIs.length; i++) {
                    inputElemOIs[i] = HiveUtils.asPrimitiveObjectInspector(argOIs[i]);
                }
            } else {// from partial aggregation
                Preconditions.checkArgument(argOIs.length == 1);
                this.mergeOI = HiveUtils.asStructOI(argOIs[0]);
                final int numFields = mergeOI.getAllStructFieldRefs().size();
                this.fields = new StructField[numFields];
                this.fieldOIs = new ListObjectInspector[numFields];
                this.inputElemOIs = new PrimitiveObjectInspector[numFields];
                for (int i = 0; i < numFields; i++) {
                    StructField field = mergeOI.getStructFieldRef("f" + String.valueOf(i));
                    fields[i] = field;
                    ListObjectInspector fieldOI =
                            HiveUtils.asListOI(field.getFieldObjectInspector());
                    fieldOIs[i] = fieldOI;
                    inputElemOIs[i] = HiveUtils.asPrimitiveObjectInspector(
                        fieldOI.getListElementObjectInspector());
                }
            }

            // initialize output
            final ObjectInspector outputOI;
            switch (m) {
                case PARTIAL1:// from original data to partial aggregation data
                    outputOI = internalMergeOutputOI(inputElemOIs);
                    break;
                case PARTIAL2:// from partial aggregation data to partial aggregation data
                    outputOI = internalMergeOutputOI(inputElemOIs);
                    break;
                case COMPLETE:// from original data directly to full aggregation
                    outputOI = terminalOutputOI(inputElemOIs);
                    break;
                case FINAL: // from partial aggregation to full aggregation
                    outputOI = terminalOutputOI(inputElemOIs);
                    break;
                default:
                    throw new IllegalStateException("Illegal mode: " + m);
            }
            return outputOI;
        }

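        // Partial-aggregation output: a struct with one array field per feature ("f0", "f1",
        // ...), each holding the distinct values collected so far for that feature.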
        @Nonnull
        private static StructObjectInspector internalMergeOutputOI(
                @CheckForNull PrimitiveObjectInspector[] inputOIs) throws UDFArgumentException {
            Preconditions.checkNotNull(inputOIs);

            final int numOIs = inputOIs.length;
            final List<String> fieldNames = new ArrayList<>(numOIs);
            final List<ObjectInspector> fieldOIs = new ArrayList<>(numOIs);
            for (int i = 0; i < numOIs; i++) {
                fieldNames.add("f" + String.valueOf(i));
                ObjectInspector elemOI = ObjectInspectorUtils.getStandardObjectInspector(
                    inputOIs[i], ObjectInspectorCopyOption.WRITABLE);
                ListObjectInspector listOI =
                        ObjectInspectorFactory.getStandardListObjectInspector(elemOI);
                fieldOIs.add(listOI);
            }
            return ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs);
        }

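        // Final output: a struct with one map field per feature, from each distinct value to
        // its integer label. Fields are named "f1".."fN", which is what lets the example query
        // above reference m.f1 and m.f2.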
        @Nonnull
        private static StructObjectInspector terminalOutputOI(
                @CheckForNull PrimitiveObjectInspector[] inputOIs) {
            Preconditions.checkNotNull(inputOIs);
            Preconditions.checkArgument(inputOIs.length >= 1, inputOIs.length);

            final List<String> fieldNames = new ArrayList<>(inputOIs.length);
            final List<ObjectInspector> fieldOIs = new ArrayList<>(inputOIs.length);
            for (int i = 0; i < inputOIs.length; i++) {
                fieldNames.add("f" + String.valueOf(i + 1));
                ObjectInspector keyOI = ObjectInspectorUtils.getStandardObjectInspector(inputOIs[i],
                    ObjectInspectorCopyOption.WRITABLE);
                MapObjectInspector mapOI = ObjectInspectorFactory.getStandardMapObjectInspector(
                    keyOI, PrimitiveObjectInspectorFactory.javaIntObjectInspector);
                fieldOIs.add(mapOI);
            }
            return ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs);
        }

        @SuppressWarnings("deprecation")
        @Override
        public AggregationBuffer getNewAggregationBuffer() throws HiveException {
            EncodingBuffer buf = new EncodingBuffer();
            reset(buf);
            return buf;
        }

        @SuppressWarnings("deprecation")
        @Override
        public void reset(AggregationBuffer aggregationBuffer) throws HiveException {
            EncodingBuffer buf = (EncodingBuffer) aggregationBuffer;
            buf.reset();
        }

        @SuppressWarnings("deprecation")
        @Override
        public void iterate(AggregationBuffer aggregationBuffer, Object[] parameters)
                throws HiveException {
            Preconditions.checkNotNull(inputElemOIs);

            EncodingBuffer buf = (EncodingBuffer) aggregationBuffer;
            buf.iterate(parameters, inputElemOIs);
        }

        @SuppressWarnings("deprecation")
        @Override
        public Object[] terminatePartial(AggregationBuffer aggregationBuffer) throws HiveException {
            EncodingBuffer buf = (EncodingBuffer) aggregationBuffer;
            return buf.partial();
        }

        @SuppressWarnings("deprecation")
        @Override
        public void merge(AggregationBuffer aggregationBuffer, Object partial)
                throws HiveException {
            if (partial == null) {
                return;
            }

            EncodingBuffer buf = (EncodingBuffer) aggregationBuffer;
            buf.merge(partial, mergeOI, fields, fieldOIs);
        }

        @SuppressWarnings("deprecation")
        @Override
        public Object[] terminate(AggregationBuffer aggregationBuffer) throws HiveException {
            EncodingBuffer buf = (EncodingBuffer) aggregationBuffer;
            return buf.terminate();
        }
    }

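    /**
     * Aggregation state: one value dictionary ({@link Identifier}) per feature column,
     * allocated lazily on the first row or the first merged partial.
     */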
    public static final class EncodingBuffer extends AbstractAggregationBuffer {

        @Nullable
        private Identifier<Writable>[] identifiers;

        public EncodingBuffer() {}

        void reset() {
            this.identifiers = null;
        }

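        // Registers each non-null argument value in its feature's dictionary. The dictionaries
        // are constructed with an initial id of 1, so encoded labels start from 1 rather than 0.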
        @SuppressWarnings("unchecked")
        void iterate(@Nonnull final Object[] args,
                @Nonnull final PrimitiveObjectInspector[] inputOIs) throws HiveException {
            Preconditions.checkArgument(args.length == inputOIs.length);

            final int length = args.length;
            if (identifiers == null) {
                this.identifiers = new Identifier[length];
                for (int i = 0; i < length; i++) {
                    identifiers[i] = new Identifier<>(1);
                }
            }

            for (int i = 0; i < length; i++) {
                Object arg = args[i];
                if (arg == null) {
                    continue;
                }
                Writable writable = WritableUtils.copyToWritable(arg, inputOIs[i]);
                identifiers[i].put(writable);
            }
        }

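        // Emits the distinct values of each feature as a list; assigning labels is deferred to
        // terminate() so that ids stay consistent after partials are merged.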
        @Nullable
        Object[] partial() throws HiveException {
            if (identifiers == null) {
                return null;
            }

            final int length = identifiers.length;
            final Object[] partial = new Object[length];
            for (int i = 0; i < length; i++) {
                Set<Writable> id = identifiers[i].getMap().keySet();
                final List<Writable> list = new ArrayList<>(id.size());
                for (Writable e : id) {
                    Preconditions.checkNotNull(e);
                    list.add(e);
                }
                partial[i] = list;
            }
            return partial;
        }

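        // Folds another task's distinct-value lists into the local dictionaries, creating a
        // dictionary on demand for each field of the partial struct.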
        @SuppressWarnings("unchecked")
        void merge(@Nonnull final Object partial, @Nonnull final StructObjectInspector mergeOI,
                @Nonnull final StructField[] fields,
                @Nonnull final ListObjectInspector[] fieldOIs) {
            Preconditions.checkArgument(fields.length == fieldOIs.length);

            final int numFields = fieldOIs.length;
            if (identifiers == null) {
                this.identifiers = new Identifier[numFields];
            }
            Preconditions.checkArgument(fields.length == identifiers.length);

            for (int i = 0; i < numFields; i++) {
                Identifier<Writable> id = identifiers[i];
                if (id == null) {
                    id = new Identifier<>(1);
                    identifiers[i] = id;
                }
                final Object fieldData = mergeOI.getStructFieldData(partial, fields[i]);
                final ListObjectInspector fieldOI = fieldOIs[i];
                for (int j = 0, size = fieldOI.getListLength(fieldData); j < size; j++) {
                    Object o = fieldOI.getListElement(fieldData, j);
                    Preconditions.checkNotNull(o);
                    id.valueOf((Writable) o);
                }
            }
        }

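        // Builds one value-to-label map per feature. Each feature's labels are shifted by the
        // running total of labels already assigned to preceding features ("max"), so a label
        // identifies both the feature and the value, as in the example output above.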
        @Nullable
        Object[] terminate() {
            if (identifiers == null) {
                return null;
            }
            final Object[] ret = new Object[identifiers.length];
            int max = 0;
            for (int i = 0; i < identifiers.length; i++) {
                final Map<Writable, Integer> m = identifiers[i].getMap();
                if (max != 0) {
                    for (Map.Entry<Writable, Integer> e : m.entrySet()) {
                        int original = e.getValue().intValue();
                        e.setValue(Integer.valueOf(max + original));
                    }
                }
                ret[i] = m;
                max += m.size();
            }
            return ret;
        }

    }
}