// hivemall.evaluation.FMeasureUDAF (Maven / Gradle / Ivy)
// The newest version!
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package hivemall.evaluation;
import hivemall.UDAFEvaluatorWithOptions;
import hivemall.utils.hadoop.HiveUtils;
import hivemall.utils.lang.Primitives;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import javax.annotation.Nonnull;
import org.apache.commons.cli.CommandLine;
import org.apache.commons.cli.Options;
import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.parse.SemanticException;
import org.apache.hadoop.hive.ql.udf.generic.AbstractGenericUDAFResolver;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator.AbstractAggregationBuffer;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator.AggregationType;
import org.apache.hadoop.hive.ql.util.JavaDataModel;
import org.apache.hadoop.hive.serde2.io.DoubleWritable;
import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.StructField;
import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.BooleanObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo;
import org.apache.hadoop.io.LongWritable;
@Description(name = "fmeasure",
value = "_FUNC_(array|int|boolean actual, array|int| boolean predicted [, const string options])"
+ " - Return a F-measure (f1score is the special with beta=1.0)")
public final class FMeasureUDAF extends AbstractGenericUDAFResolver {
/**
 * Checks the argument types of the call and hands back the evaluator for this UDAF.
 *
 * <p>Two or three arguments are accepted. The first ({@code actual}) and the second
 * ({@code predicted}) must each pass one of the {@code HiveUtils} list / integer / boolean
 * type checks, and the two types must be equal. A third argument, when present, is not
 * inspected here.
 *
 * @param typeInfo types of the arguments given to the function
 * @return a new {@link Evaluator}
 * @throws SemanticException (raised as {@link UDFArgumentTypeException}) when the number of
 *         arguments or the type of one of the first two arguments is rejected
 */
@Override
public GenericUDAFEvaluator getEvaluator(@Nonnull TypeInfo[] typeInfo)
        throws SemanticException {
    final int numArgs = typeInfo.length;
    if (numArgs < 2 || numArgs > 3) {
        throw new UDFArgumentTypeException(numArgs - 1,
            "_FUNC_ takes two or three arguments");
    }

    final TypeInfo actualType = typeInfo[0];
    final TypeInfo predictedType = typeInfo[1];

    // `actual` must be a list, an integer or a boolean
    if (!HiveUtils.isListTypeInfo(actualType) && !HiveUtils.isIntegerTypeInfo(actualType)
            && !HiveUtils.isBooleanTypeInfo(actualType)) {
        throw new UDFArgumentTypeException(0,
            "The first argument `array/int/boolean actual` is invalid form: " + actualType);
    }

    // `predicted` is held to the same rule
    if (!HiveUtils.isListTypeInfo(predictedType)
            && !HiveUtils.isIntegerTypeInfo(predictedType)
            && !HiveUtils.isBooleanTypeInfo(predictedType)) {
        throw new UDFArgumentTypeException(1,
            "The second argument `array/int/boolean predicted` is invalid form: "
                    + predictedType);
    }

    // both sides have to be described by equal type information
    if (actualType.equals(predictedType)) {
        return new Evaluator();
    }
    throw new UDFArgumentTypeException(1,
        "The first argument `actual`'s type is " + actualType
                + ", but the second argument `predicted`'s type is not match: "
                + predictedType);
}
public static class Evaluator extends UDAFEvaluatorWithOptions {
private ObjectInspector actualOI;
private ObjectInspector predictedOI;
private StructObjectInspector internalMergeOI;
private StructField tpField;
private StructField totalActualField;
private StructField totalPredictedField;
private StructField betaOptionField;
private StructField averageOptionFiled;
private double beta;
private String average;
public Evaluator() {}
@Override
protected Options getOptions() {
Options opts = new Options();
opts.addOption("beta", true, "The weight of precision [default: 1.]");
opts.addOption("average", true, "The way of average calculation [default: micro]");
return opts;
}
@Override
protected CommandLine processOptions(ObjectInspector[] argOIs) throws UDFArgumentException {
CommandLine cl = null;
double beta = 1.0d;
String average = "micro";
if (argOIs.length >= 3) {
String rawArgs = HiveUtils.getConstString(argOIs[2]);
cl = parseOptions(rawArgs);
beta = Primitives.parseDouble(cl.getOptionValue("beta"), beta);
if (beta <= 0.d) {
throw new UDFArgumentException(
"The third argument `double beta` must be greater than 0.0: " + beta);
}
average = cl.getOptionValue("average", average);
if (average.equals("macro")) {
throw new UDFArgumentException("\"-average macro\" is not supported");
}
if (!(average.equals("binary") || average.equals("micro"))) {
throw new UDFArgumentException(
"The third argument `String average` must be one of the {binary, micro, macro}: "
+ average);
}
}
this.beta = beta;
this.average = average;
return cl;
}
@Override
public ObjectInspector init(Mode mode, ObjectInspector[] parameters) throws HiveException {
assert (parameters.length == 2 || parameters.length == 3) : parameters.length;
super.init(mode, parameters);
// initialize input
if (mode == Mode.PARTIAL1 || mode == Mode.COMPLETE) {// from original data
processOptions(parameters);
this.actualOI = parameters[0];
this.predictedOI = parameters[1];
} else {// from partial aggregation
StructObjectInspector soi = (StructObjectInspector) parameters[0];
this.internalMergeOI = soi;
this.tpField = soi.getStructFieldRef("tp");
this.totalActualField = soi.getStructFieldRef("totalActual");
this.totalPredictedField = soi.getStructFieldRef("totalPredicted");
this.betaOptionField = soi.getStructFieldRef("beta");
this.averageOptionFiled = soi.getStructFieldRef("average");
}
// initialize output
final ObjectInspector outputOI;
if (mode == Mode.PARTIAL1 || mode == Mode.PARTIAL2) {// terminatePartial
outputOI = internalMergeOI();
} else {// terminate
outputOI = PrimitiveObjectInspectorFactory.writableDoubleObjectInspector;
}
return outputOI;
}
@Nonnull
private static StructObjectInspector internalMergeOI() {
List fieldNames = new ArrayList<>();
List fieldOIs = new ArrayList<>();
fieldNames.add("tp");
fieldOIs.add(PrimitiveObjectInspectorFactory.writableLongObjectInspector);
fieldNames.add("totalActual");
fieldOIs.add(PrimitiveObjectInspectorFactory.writableLongObjectInspector);
fieldNames.add("totalPredicted");
fieldOIs.add(PrimitiveObjectInspectorFactory.writableLongObjectInspector);
fieldNames.add("beta");
fieldOIs.add(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector);
fieldNames.add("average");
fieldOIs.add(PrimitiveObjectInspectorFactory.javaStringObjectInspector);
return ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs);
}
@Override
public FMeasureAggregationBuffer getNewAggregationBuffer() throws HiveException {
FMeasureAggregationBuffer myAggr = new FMeasureAggregationBuffer();
reset(myAggr);
return myAggr;
}
@Override
public void reset(@SuppressWarnings("deprecation") AggregationBuffer agg)
throws HiveException {
FMeasureAggregationBuffer myAggr = (FMeasureAggregationBuffer) agg;
myAggr.reset();
myAggr.setOptions(beta, average);
}
@Override
public void iterate(@SuppressWarnings("deprecation") AggregationBuffer agg,
Object[] parameters) throws HiveException {
FMeasureAggregationBuffer myAggr = (FMeasureAggregationBuffer) agg;
boolean isList = HiveUtils.isListOI(actualOI) && HiveUtils.isListOI(predictedOI);
final List> actual;
final List> predicted;
if (isList) {// array case
if ("binary".equals(average)) {
throw new UDFArgumentException(
"\"-average binary\" is not supported when `predict` is array");
}
actual = ((ListObjectInspector) actualOI).getList(parameters[0]);
predicted = ((ListObjectInspector) predictedOI).getList(parameters[1]);
} else {//binary case
if (HiveUtils.isBooleanOI(actualOI)) { // boolean case
actual = Arrays.asList(
asIntLabel(parameters[0], (BooleanObjectInspector) actualOI));
predicted = Arrays.asList(
asIntLabel(parameters[1], (BooleanObjectInspector) predictedOI));
} else { // integer case
final int actualLabel =
asIntLabel(parameters[0], HiveUtils.asIntegerOI(actualOI));
if (actualLabel == 0 && "binary".equals(average)) {
actual = Collections.emptyList();
} else {
actual = Arrays.asList(actualLabel);
}
final int predictedLabel =
asIntLabel(parameters[1], HiveUtils.asIntegerOI(predictedOI));
if (predictedLabel == 0 && "binary".equals(average)) {
predicted = Collections.emptyList();
} else {
predicted = Arrays.asList(predictedLabel);
}
}
}
myAggr.iterate(actual, predicted);
}
private static int asIntLabel(@Nonnull final Object o,
@Nonnull final BooleanObjectInspector booleanOI) {
if (booleanOI.get(o)) {
return 1;
} else {
return 0;
}
}
private static int asIntLabel(@Nonnull final Object o,
@Nonnull final PrimitiveObjectInspector intOI) throws UDFArgumentException {
final int value = PrimitiveObjectInspectorUtils.getInt(o, intOI);
switch (value) {
case 1:
return 1;
case 0:
case -1:
return 0;
default:
throw new UDFArgumentException("Int label must be 1, 0 or -1: " + value);
}
}
@Override
public Object terminatePartial(@SuppressWarnings("deprecation") AggregationBuffer agg)
throws HiveException {
FMeasureAggregationBuffer myAggr = (FMeasureAggregationBuffer) agg;
Object[] partialResult = new Object[5];
partialResult[0] = new LongWritable(myAggr.tp);
partialResult[1] = new LongWritable(myAggr.totalActual);
partialResult[2] = new LongWritable(myAggr.totalPredicted);
partialResult[3] = new DoubleWritable(myAggr.beta);
partialResult[4] = myAggr.average;
return partialResult;
}
@Override
public void merge(@SuppressWarnings("deprecation") AggregationBuffer agg, Object partial)
throws HiveException {
if (partial == null) {
return;
}
Object tpObj = internalMergeOI.getStructFieldData(partial, tpField);
Object totalActualObj = internalMergeOI.getStructFieldData(partial, totalActualField);
Object totalPredictedObj =
internalMergeOI.getStructFieldData(partial, totalPredictedField);
Object betaObj = internalMergeOI.getStructFieldData(partial, betaOptionField);
Object averageObj = internalMergeOI.getStructFieldData(partial, averageOptionFiled);
long tp = PrimitiveObjectInspectorFactory.writableLongObjectInspector.get(tpObj);
long totalActual =
PrimitiveObjectInspectorFactory.writableLongObjectInspector.get(totalActualObj);
long totalPredicted = PrimitiveObjectInspectorFactory.writableLongObjectInspector.get(
totalPredictedObj);
double beta =
PrimitiveObjectInspectorFactory.writableDoubleObjectInspector.get(betaObj);
String average =
PrimitiveObjectInspectorFactory.writableStringObjectInspector.getPrimitiveJavaObject(
averageObj);
FMeasureAggregationBuffer myAggr = (FMeasureAggregationBuffer) agg;
myAggr.merge(tp, totalActual, totalPredicted, beta, average);
}
@Override
public DoubleWritable terminate(@SuppressWarnings("deprecation") AggregationBuffer agg)
throws HiveException {
FMeasureAggregationBuffer myAggr = (FMeasureAggregationBuffer) agg;
double result = myAggr.get();
return new DoubleWritable(result);
}
}
@AggregationType(estimable = true)
public static class FMeasureAggregationBuffer extends AbstractAggregationBuffer {
long tp;
/** tp + fn */
long totalActual;
/** tp + fp */
long totalPredicted;
double beta;
String average;
public FMeasureAggregationBuffer() {
super();
}
@Override
public int estimate() {
JavaDataModel model = JavaDataModel.get();
return model.primitive2() * 4 + model.lengthFor(average);
}
void setOptions(double beta, String average) {
this.beta = beta;
this.average = average;
}
void reset() {
this.tp = 0L;
this.totalActual = 0L;
this.totalPredicted = 0L;
}
void merge(final long o_tp, final long o_actual, final long o_predicted, final double beta,
final String average) {
tp += o_tp;
totalActual += o_actual;
totalPredicted += o_predicted;
this.beta = beta;
this.average = average;
}
double get() {
final double squareBeta = beta * beta;
final double divisor;
final double numerator;
if ("micro".equals(average)) {
divisor = denom(tp, totalActual, totalPredicted, squareBeta);
numerator = (1.d + squareBeta) * tp;
} else { // binary
double precision = precision(tp, totalPredicted);
double recall = recall(tp, totalActual);
divisor = squareBeta * precision + recall;
numerator = (1.d + squareBeta) * precision * recall;
}
if (divisor > 0) {
return (numerator / divisor);
} else {
return 0.d;
}
}
private static double denom(final long tp, final long totalActual,
final long totalPredicted, double squareBeta) {
long lp = totalActual - tp;
long pl = totalPredicted - tp;
return squareBeta * (tp + lp) + tp + pl;
}
private static double precision(final long tp, final long totalPredicted) {
return (totalPredicted == 0L) ? 0.d : tp / (double) totalPredicted;
}
private static double recall(final long tp, final long totalActual) {
return (totalActual == 0L) ? 0.d : tp / (double) totalActual;
}
void iterate(@Nonnull final List> actual, @Nonnull final List> predicted) {
final int numActual = actual.size();
final int numPredicted = predicted.size();
int countTp = 0;
for (Object p : predicted) {
if (actual.contains(p)) {
countTp++;
}
}
this.tp += countTp;
this.totalActual += numActual;
this.totalPredicted += numPredicted;
}
}
}
// © 2015 - 2025 Weber Informatics LLC | Privacy Policy