
hivemall.smile.tools.RandomForestEnsembleUDAF
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package hivemall.smile.tools;
import hivemall.utils.collections.lists.IntArrayList;
import hivemall.utils.hadoop.HiveUtils;
import hivemall.utils.hadoop.WritableUtils;
import hivemall.utils.lang.Counter;
import hivemall.utils.lang.Preconditions;
import hivemall.utils.lang.SizeOf;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import javax.annotation.CheckForNull;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDFArgumentLengthException;
import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.parse.SemanticException;
import org.apache.hadoop.hive.ql.udf.generic.AbstractGenericUDAFResolver;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator.AbstractAggregationBuffer;
import org.apache.hadoop.hive.serde2.io.DoubleWritable;
import org.apache.hadoop.hive.serde2.lazybinary.LazyBinaryArray;
import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.StandardListObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.StandardMapObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.StructField;
import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.DoubleObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.IntObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo;
import org.apache.hadoop.io.IntWritable;
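// Usage sketch (illustrative; the function registration, the model_rf/test_rf
// tables, the tree_predict call, and all column names follow the Hivemall user
// guide and are assumptions, not defined in this file):
//
//   CREATE TEMPORARY FUNCTION rf_ensemble AS 'hivemall.smile.tools.RandomForestEnsembleUDAF';
//
//   SELECT rowid,
//          rf_ensemble(predicted.value, predicted.posteriori, model_weight) AS predicted
//   FROM (
//     SELECT t.rowid, m.model_weight,
//            tree_predict(m.model_id, m.model, t.features, "-classification") AS predicted
//     FROM model_rf m LEFT OUTER JOIN test_rf t
//   ) t1
//   GROUP BY rowid;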
@Description(name = "rf_ensemble",
value = "_FUNC_(int yhat [, array proba [, double model_weight=1.0]])"
+ " - Returns ensembled prediction results in probabilities>")
public final class RandomForestEnsembleUDAF extends AbstractGenericUDAFResolver {
public RandomForestEnsembleUDAF() {
super();
}
@Override
public GenericUDAFEvaluator getEvaluator(@Nonnull final TypeInfo[] typeInfo)
throws SemanticException {
switch (typeInfo.length) {
case 1: {
if (!HiveUtils.isIntegerTypeInfo(typeInfo[0])) {
throw new UDFArgumentTypeException(0, "Expected INT for yhat: " + typeInfo[0]);
}
return new RfEvaluatorV1();
}
case 3:
if (!HiveUtils.isFloatingPointTypeInfo(typeInfo[2])) {
throw new UDFArgumentTypeException(2,
"Expected DOUBLE or FLOAT for model_weight: " + typeInfo[2]);
}
/* fall through */
case 2: {// typeInfo.length == 2 || typeInfo.length == 3
if (!HiveUtils.isIntegerTypeInfo(typeInfo[0])) {
throw new UDFArgumentTypeException(0, "Expected INT for yhat: " + typeInfo[0]);
}
if (!HiveUtils.isFloatingPointListTypeInfo(typeInfo[1])) {
throw new UDFArgumentTypeException(1,
"ARRAY is expected for a posteriori: " + typeInfo[1]);
}
return new RfEvaluatorV2();
}
default:
throw new UDFArgumentLengthException(
"Expected 1~3 arguments but got " + typeInfo.length);
}
}
@Deprecated
public static final class RfEvaluatorV1 extends GenericUDAFEvaluator {
// original input
private PrimitiveObjectInspector yhatOI;
// partial aggregation
private StandardMapObjectInspector internalMergeOI;
private IntObjectInspector keyOI;
private IntObjectInspector valueOI;
public RfEvaluatorV1() {
super();
}
@Override
public ObjectInspector init(@Nonnull Mode mode, @Nonnull ObjectInspector[] argOIs)
throws HiveException {
super.init(mode, argOIs);
// initialize input
if (mode == Mode.PARTIAL1 || mode == Mode.COMPLETE) {// from original data
this.yhatOI = HiveUtils.asIntegerOI(argOIs, 0);
} else {// from partial aggregation
this.internalMergeOI = (StandardMapObjectInspector) argOIs[0];
this.keyOI = HiveUtils.asIntOI(internalMergeOI.getMapKeyObjectInspector());
this.valueOI = HiveUtils.asIntOI(internalMergeOI.getMapValueObjectInspector());
}
// initialize output
final ObjectInspector outputOI;
if (mode == Mode.PARTIAL1 || mode == Mode.PARTIAL2) {// terminatePartial
outputOI = ObjectInspectorFactory.getStandardMapObjectInspector(
PrimitiveObjectInspectorFactory.javaIntObjectInspector,
PrimitiveObjectInspectorFactory.javaIntObjectInspector);
} else {// terminate
                List<String> fieldNames = new ArrayList<>(3);
                List<ObjectInspector> fieldOIs = new ArrayList<>(3);
fieldNames.add("label");
fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector);
fieldNames.add("probability");
fieldOIs.add(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector);
fieldNames.add("probabilities");
fieldOIs.add(ObjectInspectorFactory.getStandardListObjectInspector(
PrimitiveObjectInspectorFactory.writableDoubleObjectInspector));
outputOI = ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames,
fieldOIs);
}
return outputOI;
}
@Override
public RfAggregationBufferV1 getNewAggregationBuffer() throws HiveException {
RfAggregationBufferV1 buf = new RfAggregationBufferV1();
buf.reset();
return buf;
}
@Override
public void reset(AggregationBuffer agg) throws HiveException {
RfAggregationBufferV1 buf = (RfAggregationBufferV1) agg;
buf.reset();
}
@Override
public void iterate(AggregationBuffer agg, Object[] parameters) throws HiveException {
RfAggregationBufferV1 buf = (RfAggregationBufferV1) agg;
Preconditions.checkNotNull(parameters[0]);
int yhat = PrimitiveObjectInspectorUtils.getInt(parameters[0], yhatOI);
buf.iterate(yhat);
}
@Override
public Object terminatePartial(AggregationBuffer agg) throws HiveException {
RfAggregationBufferV1 buf = (RfAggregationBufferV1) agg;
return buf.terminatePartial();
}
@Override
public void merge(AggregationBuffer agg, Object partial) throws HiveException {
final RfAggregationBufferV1 buf = (RfAggregationBufferV1) agg;
            Map<?, ?> partialResult = internalMergeOI.getMap(partial);
            for (Map.Entry<?, ?> entry : partialResult.entrySet()) {
putIntoMap(entry.getKey(), entry.getValue(), buf);
}
}
private void putIntoMap(@CheckForNull Object key, @CheckForNull Object value,
@Nonnull RfAggregationBufferV1 dst) {
Preconditions.checkNotNull(key);
Preconditions.checkNotNull(value);
int k = keyOI.get(key);
int v = valueOI.get(value);
dst.merge(k, v);
}
@Override
public Object terminate(AggregationBuffer agg) throws HiveException {
RfAggregationBufferV1 buf = (RfAggregationBufferV1) agg;
return buf.terminate();
}
}
public static final class RfAggregationBufferV1 extends AbstractAggregationBuffer {
@Nonnull
        private Counter<Integer> partial;
public RfAggregationBufferV1() {
super();
reset();
}
void reset() {
            this.partial = new Counter<Integer>();
}
void iterate(final int k) {
partial.increment(k);
}
@Nonnull
        Map<Integer, Integer> terminatePartial() {
return partial.getMap();
}
void merge(final int k, final int v) {
partial.increment(Integer.valueOf(k), v);
}
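        // Worked example (illustrative): with accumulated votes {0: 3, 1: 7},
        // terminate() returns label=1, probability=7/10=0.7 and
        // probabilities=[0.3, 0.7].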
@Nullable
Object[] terminate() {
            final Map<Integer, Integer> counts = partial.getMap();
final int size = counts.size();
if (size == 0) {
return null;
}
final IntArrayList keyList = new IntArrayList(size);
long totalCnt = 0L;
Integer maxKey = null;
int maxCnt = Integer.MIN_VALUE;
            for (Map.Entry<Integer, Integer> e : counts.entrySet()) {
Integer key = e.getKey();
keyList.add(key);
int cnt = e.getValue().intValue();
totalCnt += cnt;
if (cnt >= maxCnt) {
maxCnt = cnt;
maxKey = key;
}
}
final int[] keyArray = keyList.toArray();
Arrays.sort(keyArray);
int last = keyArray[keyArray.length - 1];
double totalCnt_d = (double) totalCnt;
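            // the probabilities array spans class indices 0..last observed label,
            // with at least 2 slots so binary problems always report [p(0), p(1)]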
final double[] probabilities = new double[Math.max(2, last + 1)];
for (int i = 0, len = probabilities.length; i < len; i++) {
final Integer cnt = counts.get(Integer.valueOf(i));
if (cnt == null) {
probabilities[i] = 0.d;
} else {
probabilities[i] = cnt.intValue() / totalCnt_d;
}
}
Object[] result = new Object[3];
result[0] = new IntWritable(maxKey);
double proba = maxCnt / totalCnt_d;
result[1] = new DoubleWritable(proba);
result[2] = WritableUtils.toWritableList(probabilities);
return result;
}
}
@SuppressWarnings("deprecation")
public static final class RfEvaluatorV2 extends GenericUDAFEvaluator {
private PrimitiveObjectInspector yhatOI;
private ListObjectInspector posterioriOI;
private PrimitiveObjectInspector posterioriElemOI;
@Nullable
private PrimitiveObjectInspector weightOI;
private StructObjectInspector internalMergeOI;
private StructField sizeField, posterioriField;
private IntObjectInspector sizeFieldOI;
private StandardListObjectInspector posterioriFieldOI;
public RfEvaluatorV2() {
super();
}
@Override
public ObjectInspector init(@Nonnull Mode mode, @Nonnull ObjectInspector[] parameters)
throws HiveException {
super.init(mode, parameters);
// initialize input
if (mode == Mode.PARTIAL1 || mode == Mode.COMPLETE) {// from original data
this.yhatOI = HiveUtils.asIntegerOI(parameters[0]);
this.posterioriOI = HiveUtils.asListOI(parameters[1]);
this.posterioriElemOI = HiveUtils.asDoubleCompatibleOI(
posterioriOI.getListElementObjectInspector());
if (parameters.length == 3) {
this.weightOI = HiveUtils.asDoubleCompatibleOI(parameters[2]);
}
} else {// from partial aggregation
StructObjectInspector soi = (StructObjectInspector) parameters[0];
this.internalMergeOI = soi;
this.sizeField = soi.getStructFieldRef("size");
this.posterioriField = soi.getStructFieldRef("posteriori");
this.sizeFieldOI = PrimitiveObjectInspectorFactory.writableIntObjectInspector;
this.posterioriFieldOI = ObjectInspectorFactory.getStandardListObjectInspector(
PrimitiveObjectInspectorFactory.writableDoubleObjectInspector);
}
// initialize output
final ObjectInspector outputOI;
if (mode == Mode.PARTIAL1 || mode == Mode.PARTIAL2) {// terminatePartial
                List<String> fieldNames = new ArrayList<>(3);
                List<ObjectInspector> fieldOIs = new ArrayList<>(3);
fieldNames.add("size");
fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector);
fieldNames.add("posteriori");
fieldOIs.add(ObjectInspectorFactory.getStandardListObjectInspector(
PrimitiveObjectInspectorFactory.writableDoubleObjectInspector));
outputOI = ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames,
fieldOIs);
} else {// terminate
                List<String> fieldNames = new ArrayList<>(3);
                List<ObjectInspector> fieldOIs = new ArrayList<>(3);
fieldNames.add("label");
fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector);
fieldNames.add("probability");
fieldOIs.add(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector);
fieldNames.add("probabilities");
fieldOIs.add(ObjectInspectorFactory.getStandardListObjectInspector(
PrimitiveObjectInspectorFactory.writableDoubleObjectInspector));
outputOI = ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames,
fieldOIs);
}
return outputOI;
}
@Override
public RfAggregationBufferV2 getNewAggregationBuffer() throws HiveException {
RfAggregationBufferV2 buf = new RfAggregationBufferV2();
reset(buf);
return buf;
}
@Override
public void reset(AggregationBuffer agg) throws HiveException {
RfAggregationBufferV2 buf = (RfAggregationBufferV2) agg;
buf.reset();
}
@Override
public void iterate(AggregationBuffer agg, Object[] parameters) throws HiveException {
RfAggregationBufferV2 buf = (RfAggregationBufferV2) agg;
Preconditions.checkNotNull(parameters[0]);
int yhat = PrimitiveObjectInspectorUtils.getInt(parameters[0], yhatOI);
Preconditions.checkNotNull(parameters[1]);
double[] posteriori =
HiveUtils.asDoubleArray(parameters[1], posterioriOI, posterioriElemOI);
double weight = 1.0d;
if (parameters.length == 3) {
Preconditions.checkNotNull(parameters[2]);
weight = PrimitiveObjectInspectorUtils.getDouble(parameters[2], weightOI);
}
buf.iterate(yhat, weight, posteriori);
}
@Override
public Object terminatePartial(AggregationBuffer agg) throws HiveException {
RfAggregationBufferV2 buf = (RfAggregationBufferV2) agg;
if (buf._k == -1) {
return null;
}
Object[] partial = new Object[2];
partial[0] = new IntWritable(buf._k);
partial[1] = WritableUtils.toWritableList(buf._posteriori);
return partial;
}
@Override
public void merge(AggregationBuffer agg, Object partial) throws HiveException {
if (partial == null) {
return;
}
RfAggregationBufferV2 buf = (RfAggregationBufferV2) agg;
Object o1 = internalMergeOI.getStructFieldData(partial, sizeField);
int size = sizeFieldOI.get(o1);
Object posteriori = internalMergeOI.getStructFieldData(partial, posterioriField);
// --------------------------------------------------------------
// [workaround]
// java.lang.ClassCastException: org.apache.hadoop.hive.serde2.lazybinary.LazyBinaryArray
// cannot be cast to [Ljava.lang.Object;
if (posteriori instanceof LazyBinaryArray) {
posteriori = ((LazyBinaryArray) posteriori).getList();
}
buf.merge(size, posteriori, posterioriFieldOI);
}
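        // Worked example (illustrative): two trees with model_weight=1.0, one
        // predicting yhat=1 with posteriori [0.2, 0.8] and one yhat=0 with
        // [0.6, 0.4], accumulate _posteriori=[0.6, 0.8]; terminate() then yields
        // label=1 and, after L1 normalization, probability=0.8/1.4≈0.571.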
@Override
public Object terminate(AggregationBuffer agg) throws HiveException {
RfAggregationBufferV2 buf = (RfAggregationBufferV2) agg;
if (buf._k == -1) {
return null;
}
double[] posteriori = buf._posteriori;
int label = smile.math.Math.whichMax(posteriori);
smile.math.Math.unitize1(posteriori);
double proba = posteriori[label];
Object[] result = new Object[3];
result[0] = new IntWritable(label);
result[1] = new DoubleWritable(proba);
result[2] = WritableUtils.toWritableList(posteriori);
return result;
}
}
public static final class RfAggregationBufferV2 extends AbstractAggregationBuffer {
@Nullable
private double[] _posteriori;
private int _k;
public RfAggregationBufferV2() {
super();
reset();
}
void reset() {
this._posteriori = null;
this._k = -1;
}
void iterate(final int yhat, final double weight, @Nonnull final double[] posteriori)
throws HiveException {
if (_posteriori == null) {
this._k = posteriori.length;
this._posteriori = new double[_k];
}
if (yhat >= _k) {
throw new HiveException("Predicted class " + yhat + " is out of bounds: " + _k);
}
if (posteriori.length != _k) {
throw new HiveException("Given |a posteriori| " + posteriori.length
+ " is differs from expected one: " + _k);
}
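            // accumulate only the predicted class's posterior, scaled by the
            // model weight; other classes get no contribution from this row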
_posteriori[yhat] += (posteriori[yhat] * weight);
}
void merge(int size, @Nonnull Object posterioriObj,
@Nonnull StandardListObjectInspector posterioriOI) throws HiveException {
if (size != _k) {
if (_k == -1) {
this._k = size;
this._posteriori = new double[size];
} else {
throw new HiveException(
"Mismatch in the number of elements: _k=" + _k + ", size=" + size);
}
}
final double[] posteriori = _posteriori;
final DoubleObjectInspector doubleOI =
PrimitiveObjectInspectorFactory.writableDoubleObjectInspector;
for (int i = 0, len = _k; i < len; i++) {
Object o2 = posterioriOI.getListElement(posterioriObj, i);
posteriori[i] += doubleOI.get(o2);
}
}
@Override
public int estimate() {
if (_k == -1) {
return 0;
}
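            // rough footprint: the int field _k plus the double[_k] posteriori payload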
return SizeOf.INT + _k * SizeOf.DOUBLE;
}
}
}