/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package hivemall.ftvec.trans;
import hivemall.utils.hadoop.HiveUtils;
import hivemall.utils.hadoop.WritableUtils;
import hivemall.utils.lang.Identifier;
import hivemall.utils.lang.Preconditions;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Set;
import javax.annotation.CheckForNull;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.parse.SemanticException;
import org.apache.hadoop.hive.ql.udf.UDFType;
import org.apache.hadoop.hive.ql.udf.generic.AbstractGenericUDAFResolver;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator.AbstractAggregationBuffer;
import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.MapObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils.ObjectInspectorCopyOption;
import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.StructField;
import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo;
import org.apache.hadoop.io.Writable;
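
/**
 * One-hot encoding UDAF. For each input column, this aggregate collects the distinct
 * values and assigns every value an integer label that is unique across all columns,
 * returning one {@code map<value, int>} per column. The {@code extended} description
 * below sketches a typical usage against an illustrative {@code test} table with
 * {@code species} and {@code category} columns.
 */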
//@formatter:off
@Description(name = "onehot_encoding",
value = "_FUNC_(PRIMITIVE feature, ...) - Compute onehot encoded label for each feature",
extended = "WITH mapping as (\n" +
" select \n" +
" m.f1, m.f2 \n" +
" from (\n" +
" select onehot_encoding(species, category) m\n" +
" from test\n" +
" ) tmp\n" +
")\n" +
"select\n" +
" array(m.f1[t.species],m.f2[t.category],feature('count',count)) as sparse_features\n" +
"from\n" +
" test t\n" +
" CROSS JOIN mapping m;\n" +
"\n" +
"[\"2\",\"8\",\"count:9\"]\n" +
"[\"5\",\"8\",\"count:10\"]\n" +
"[\"1\",\"6\",\"count:101\"]")
@UDFType(deterministic = true, stateful = true)
//@formatter:on
public final class OnehotEncodingUDAF extends AbstractGenericUDAFResolver {
public OnehotEncodingUDAF() {
super();
}
@Override
public GenericUDAFEvaluator getEvaluator(@Nonnull TypeInfo[] argTypes)
throws SemanticException {
final int numFeatures = argTypes.length;
if (numFeatures == 0) {
throw new UDFArgumentException("_FUNC_ requires at least 1 argument");
}
for (int i = 0; i < numFeatures; i++) {
if (argTypes[i] == null) {
throw new UDFArgumentTypeException(i,
"Null type is found. Only primitive type arguments are accepted.");
}
if (argTypes[i].getCategory() != ObjectInspector.Category.PRIMITIVE) {
throw new UDFArgumentTypeException(i,
"Only primitive type arguments are accepted but " + argTypes[i].getTypeName()
+ " was passed as parameter 1.");
}
}
return new GenericUDAFOnehotEncodingEvaluator();
}
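    /**
     * Evaluator covering the four Hive aggregation modes: PARTIAL1/PARTIAL2 emit a
     * struct of per-column lists of distinct values, while COMPLETE/FINAL emit a
     * struct of per-column value-to-index maps.
     */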
public static final class GenericUDAFOnehotEncodingEvaluator extends GenericUDAFEvaluator {
// input OI
private PrimitiveObjectInspector[] inputElemOIs;
// merge input OI
private StructObjectInspector mergeOI;
private StructField[] fields;
private ListObjectInspector[] fieldOIs;
public GenericUDAFOnehotEncodingEvaluator() {}
@Override
public ObjectInspector init(Mode m, ObjectInspector[] argOIs) throws HiveException {
super.init(m, argOIs);
// initialize input
if (m == Mode.PARTIAL1 || m == Mode.COMPLETE) {// from original data
this.inputElemOIs = new PrimitiveObjectInspector[argOIs.length];
for (int i = 0; i < argOIs.length; i++) {
inputElemOIs[i] = HiveUtils.asPrimitiveObjectInspector(argOIs[i]);
}
} else {// from partial aggregation
Preconditions.checkArgument(argOIs.length == 1);
this.mergeOI = HiveUtils.asStructOI(argOIs[0]);
final int numFields = mergeOI.getAllStructFieldRefs().size();
this.fields = new StructField[numFields];
this.fieldOIs = new ListObjectInspector[numFields];
this.inputElemOIs = new PrimitiveObjectInspector[numFields];
for (int i = 0; i < numFields; i++) {
StructField field = mergeOI.getStructFieldRef("f" + String.valueOf(i));
fields[i] = field;
ListObjectInspector fieldOI =
HiveUtils.asListOI(field.getFieldObjectInspector());
fieldOIs[i] = fieldOI;
inputElemOIs[i] = HiveUtils.asPrimitiveObjectInspector(
fieldOI.getListElementObjectInspector());
}
}
// initialize output
final ObjectInspector outputOI;
switch (m) {
case PARTIAL1:// from original data to partial aggregation data
outputOI = internalMergeOutputOI(inputElemOIs);
break;
case PARTIAL2:// from partial aggregation data to partial aggregation data
outputOI = internalMergeOutputOI(inputElemOIs);
break;
case COMPLETE:// from original data directly to full aggregation
outputOI = terminalOutputOI(inputElemOIs);
break;
case FINAL: // from partial aggregation to full aggregation
outputOI = terminalOutputOI(inputElemOIs);
break;
default:
throw new IllegalStateException("Illegal mode: " + m);
}
return outputOI;
}
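        /**
         * Output OI for partial aggregation: a struct with fields f0..f{n-1}, each a
         * list of the distinct Writable values observed for that column.
         */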
@Nonnull
private static StructObjectInspector internalMergeOutputOI(
@CheckForNull PrimitiveObjectInspector[] inputOIs) throws UDFArgumentException {
Preconditions.checkNotNull(inputOIs);
final int numOIs = inputOIs.length;
            final List<String> fieldNames = new ArrayList<String>(numOIs);
            final List<ObjectInspector> fieldOIs = new ArrayList<ObjectInspector>(numOIs);
for (int i = 0; i < numOIs; i++) {
fieldNames.add("f" + String.valueOf(i));
ObjectInspector elemOI = ObjectInspectorUtils.getStandardObjectInspector(
inputOIs[i], ObjectInspectorCopyOption.WRITABLE);
ListObjectInspector listOI =
ObjectInspectorFactory.getStandardListObjectInspector(elemOI);
fieldOIs.add(listOI);
}
return ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs);
}
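        /**
         * Output OI for the final result: a struct with fields f1..f{n}, each a map
         * from a column's distinct value to its global one-hot index.
         */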
@Nonnull
private static StructObjectInspector terminalOutputOI(
@CheckForNull PrimitiveObjectInspector[] inputOIs) {
Preconditions.checkNotNull(inputOIs);
Preconditions.checkArgument(inputOIs.length >= 1, inputOIs.length);
            final List<String> fieldNames = new ArrayList<>(inputOIs.length);
            final List<ObjectInspector> fieldOIs = new ArrayList<>(inputOIs.length);
for (int i = 0; i < inputOIs.length; i++) {
fieldNames.add("f" + String.valueOf(i + 1));
ObjectInspector keyOI = ObjectInspectorUtils.getStandardObjectInspector(inputOIs[i],
ObjectInspectorCopyOption.WRITABLE);
MapObjectInspector mapOI = ObjectInspectorFactory.getStandardMapObjectInspector(
keyOI, PrimitiveObjectInspectorFactory.javaIntObjectInspector);
fieldOIs.add(mapOI);
}
return ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs);
}
@SuppressWarnings("deprecation")
@Override
public AggregationBuffer getNewAggregationBuffer() throws HiveException {
EncodingBuffer buf = new EncodingBuffer();
reset(buf);
return buf;
}
@SuppressWarnings("deprecation")
@Override
public void reset(AggregationBuffer aggregationBuffer) throws HiveException {
EncodingBuffer buf = (EncodingBuffer) aggregationBuffer;
buf.reset();
}
@SuppressWarnings("deprecation")
@Override
public void iterate(AggregationBuffer aggregationBuffer, Object[] parameters)
throws HiveException {
Preconditions.checkNotNull(inputElemOIs);
EncodingBuffer buf = (EncodingBuffer) aggregationBuffer;
buf.iterate(parameters, inputElemOIs);
}
@SuppressWarnings("deprecation")
@Override
public Object[] terminatePartial(AggregationBuffer aggregationBuffer) throws HiveException {
EncodingBuffer buf = (EncodingBuffer) aggregationBuffer;
return buf.partial();
}
@SuppressWarnings("deprecation")
@Override
public void merge(AggregationBuffer aggregationBuffer, Object partial)
throws HiveException {
if (partial == null) {
return;
}
EncodingBuffer buf = (EncodingBuffer) aggregationBuffer;
buf.merge(partial, mergeOI, fields, fieldOIs);
}
@SuppressWarnings("deprecation")
@Override
public Object[] terminate(AggregationBuffer aggregationBuffer) throws HiveException {
EncodingBuffer buf = (EncodingBuffer) aggregationBuffer;
return buf.terminate();
}
}
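    /**
     * Aggregation state: one {@link Identifier} per input column, assigning each
     * distinct value a sequential index. The array stays {@code null} until the
     * first row or partial result arrives.
     */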
public static final class EncodingBuffer extends AbstractAggregationBuffer {
@Nullable
        private Identifier<Writable>[] identifiers;
public EncodingBuffer() {}
void reset() {
this.identifiers = null;
}
@SuppressWarnings("unchecked")
void iterate(@Nonnull final Object[] args,
@Nonnull final PrimitiveObjectInspector[] inputOIs) throws HiveException {
Preconditions.checkArgument(args.length == inputOIs.length);
final int length = args.length;
if (identifiers == null) {
this.identifiers = new Identifier[length];
for (int i = 0; i < length; i++) {
identifiers[i] = new Identifier<>(1);
}
}
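            // Register each non-null value; nulls are skipped and get no label.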
for (int i = 0; i < length; i++) {
Object arg = args[i];
if (arg == null) {
continue;
}
Writable writable = WritableUtils.copyToWritable(arg, inputOIs[i]);
identifiers[i].put(writable);
}
}
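        /** Emits the partial state: for each column, the list of distinct values seen so far. */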
@Nullable
Object[] partial() throws HiveException {
if (identifiers == null) {
return null;
}
final int length = identifiers.length;
final Object[] partial = new Object[length];
for (int i = 0; i < length; i++) {
                Set<Writable> id = identifiers[i].getMap().keySet();
                final List<Writable> list = new ArrayList<Writable>(id.size());
for (Writable e : id) {
Preconditions.checkNotNull(e);
list.add(e);
}
partial[i] = list;
}
return partial;
}
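        /** Folds a partial result into this buffer, registering every value of each column. */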
@SuppressWarnings("unchecked")
void merge(@Nonnull final Object partial, @Nonnull final StructObjectInspector mergeOI,
@Nonnull final StructField[] fields,
@Nonnull final ListObjectInspector[] fieldOIs) {
Preconditions.checkArgument(fields.length == fieldOIs.length);
final int numFields = fieldOIs.length;
if (identifiers == null) {
this.identifiers = new Identifier[numFields];
}
Preconditions.checkArgument(fields.length == identifiers.length);
for (int i = 0; i < numFields; i++) {
                Identifier<Writable> id = identifiers[i];
if (id == null) {
id = new Identifier<>(1);
identifiers[i] = id;
}
final Object fieldData = mergeOI.getStructFieldData(partial, fields[i]);
final ListObjectInspector fieldOI = fieldOIs[i];
for (int j = 0, size = fieldOI.getListLength(fieldData); j < size; j++) {
Object o = fieldOI.getListElement(fieldData, j);
Preconditions.checkNotNull(o);
id.valueOf((Writable) o);
}
}
}
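        /**
         * Builds the final per-column maps. The indices of column i are shifted by the
         * total number of distinct values in columns 0..i-1, so labels are unique
         * across all columns.
         */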
@Nullable
Object[] terminate() {
if (identifiers == null) {
return null;
}
final Object[] ret = new Object[identifiers.length];
int max = 0;
for (int i = 0; i < identifiers.length; i++) {
                final Map<Writable, Integer> m = identifiers[i].getMap();
if (max != 0) {
                    for (Map.Entry<Writable, Integer> e : m.entrySet()) {
int original = e.getValue().intValue();
e.setValue(Integer.valueOf(max + original));
}
}
ret[i] = m;
max += m.size();
}
return ret;
}
}
}