// hivemall.ftvec.selection.SignalNoiseRatioUDAF (Maven / Gradle / Ivy artifact listing — newest version)
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package hivemall.ftvec.selection;
import hivemall.utils.hadoop.HiveUtils;
import hivemall.utils.hadoop.WritableUtils;
import hivemall.utils.lang.Preconditions;
import hivemall.utils.lang.SizeOf;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import javax.annotation.Nonnull;
import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.exec.UDFArgumentLengthException;
import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.parse.SemanticException;
import org.apache.hadoop.hive.ql.udf.generic.AbstractGenericUDAFResolver;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFParameterInfo;
import org.apache.hadoop.hive.serde2.io.DoubleWritable;
import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.StructField;
import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.DoubleObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.LongObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
@Description(name = "snr", value = "_FUNC_(array features, array one-hot class label)"
+ " - Returns Signal Noise Ratio for each feature as array")
public class SignalNoiseRatioUDAF extends AbstractGenericUDAFResolver {
@Override
public GenericUDAFEvaluator getEvaluator(GenericUDAFParameterInfo info)
throws SemanticException {
final ObjectInspector[] OIs = info.getParameterObjectInspectors();
if (OIs.length != 2) {
throw new UDFArgumentLengthException("Specify two arguments: " + OIs.length);
}
if (!HiveUtils.isNumberListOI(OIs[0])) {
throw new UDFArgumentTypeException(0,
"Only array type argument is acceptable but " + OIs[0].getTypeName()
+ " was passed as `features`");
}
if (!HiveUtils.isListOI(OIs[1]) || !HiveUtils.isIntegerOI(
((ListObjectInspector) OIs[1]).getListElementObjectInspector())) {
throw new UDFArgumentTypeException(1, "Only array type argument is acceptable but "
+ OIs[1].getTypeName() + " was passed as `labels`");
}
return new SignalNoiseRatioUDAFEvaluator();
}
static class SignalNoiseRatioUDAFEvaluator extends GenericUDAFEvaluator {
// PARTIAL1 and COMPLETE
private ListObjectInspector featuresOI;
private PrimitiveObjectInspector featureOI;
private ListObjectInspector labelsOI;
private PrimitiveObjectInspector labelOI;
// PARTIAL2 and FINAL
private StructObjectInspector structOI;
private StructField countsField, meansField, variancesField;
private ListObjectInspector countsOI;
private LongObjectInspector countOI;
private ListObjectInspector meansOI;
private ListObjectInspector meanListOI;
private DoubleObjectInspector meanElemOI;
private ListObjectInspector variancesOI;
private ListObjectInspector varianceListOI;
private DoubleObjectInspector varianceElemOI;
@AggregationType(estimable = true)
static class SignalNoiseRatioAggregationBuffer extends AbstractAggregationBuffer {
long[] counts;
double[][] means;
double[][] variances;
@Override
public int estimate() {
return counts == null ? 0
: SizeOf.LONG * counts.length
+ SizeOf.DOUBLE * means.length * means[0].length
+ SizeOf.DOUBLE * variances.length * variances[0].length;
}
public void init(int nClasses, int nFeatures) {
this.counts = new long[nClasses];
this.means = new double[nClasses][nFeatures];
this.variances = new double[nClasses][nFeatures];
}
public void reset() {
if (counts != null) {
Arrays.fill(counts, 0);
for (double[] mean : means) {
Arrays.fill(mean, 0.d);
}
for (double[] variance : variances) {
Arrays.fill(variance, 0.d);
}
}
}
}
@Override
public ObjectInspector init(Mode mode, ObjectInspector[] OIs) throws HiveException {
super.init(mode, OIs);
// initialize input
if (mode == Mode.PARTIAL1 || mode == Mode.COMPLETE) {// from original data
this.featuresOI = HiveUtils.asListOI(OIs[0]);
this.featureOI =
HiveUtils.asDoubleCompatibleOI(featuresOI.getListElementObjectInspector());
this.labelsOI = HiveUtils.asListOI(OIs[1]);
this.labelOI = HiveUtils.asIntegerOI(labelsOI.getListElementObjectInspector());
} else {// from partial aggregation
this.structOI = (StructObjectInspector) OIs[0];
this.countsField = structOI.getStructFieldRef("counts");
this.countsOI = HiveUtils.asListOI(countsField.getFieldObjectInspector());
this.countOI = HiveUtils.asLongOI(countsOI.getListElementObjectInspector());
this.meansField = structOI.getStructFieldRef("means");
this.meansOI = HiveUtils.asListOI(meansField.getFieldObjectInspector());
this.meanListOI = HiveUtils.asListOI(meansOI.getListElementObjectInspector());
this.meanElemOI = HiveUtils.asDoubleOI(meanListOI.getListElementObjectInspector());
this.variancesField = structOI.getStructFieldRef("variances");
this.variancesOI = HiveUtils.asListOI(variancesField.getFieldObjectInspector());
this.varianceListOI =
HiveUtils.asListOI(variancesOI.getListElementObjectInspector());
this.varianceElemOI =
HiveUtils.asDoubleOI(varianceListOI.getListElementObjectInspector());
}
// initialize output
if (mode == Mode.PARTIAL1 || mode == Mode.PARTIAL2) {// terminatePartial
List fieldOIs = new ArrayList();
fieldOIs.add(ObjectInspectorFactory.getStandardListObjectInspector(
PrimitiveObjectInspectorFactory.writableLongObjectInspector));
fieldOIs.add(ObjectInspectorFactory.getStandardListObjectInspector(
ObjectInspectorFactory.getStandardListObjectInspector(
PrimitiveObjectInspectorFactory.writableDoubleObjectInspector)));
fieldOIs.add(ObjectInspectorFactory.getStandardListObjectInspector(
ObjectInspectorFactory.getStandardListObjectInspector(
PrimitiveObjectInspectorFactory.writableDoubleObjectInspector)));
return ObjectInspectorFactory.getStandardStructObjectInspector(
Arrays.asList("counts", "means", "variances"), fieldOIs);
} else {// terminate
return ObjectInspectorFactory.getStandardListObjectInspector(
PrimitiveObjectInspectorFactory.writableDoubleObjectInspector);
}
}
@Override
public AbstractAggregationBuffer getNewAggregationBuffer() throws HiveException {
SignalNoiseRatioAggregationBuffer myAgg = new SignalNoiseRatioAggregationBuffer();
reset(myAgg);
return myAgg;
}
@Override
public void reset(@SuppressWarnings("deprecation") AggregationBuffer agg)
throws HiveException {
SignalNoiseRatioAggregationBuffer myAgg = (SignalNoiseRatioAggregationBuffer) agg;
myAgg.reset();
}
@Override
public void iterate(@SuppressWarnings("deprecation") AggregationBuffer agg,
Object[] parameters) throws HiveException {
final Object featuresObj = parameters[0];
final Object labelsObj = parameters[1];
Preconditions.checkNotNull(featuresObj);
Preconditions.checkNotNull(labelsObj);
final SignalNoiseRatioAggregationBuffer myAgg = (SignalNoiseRatioAggregationBuffer) agg;
final List> labels = labelsOI.getList(labelsObj);
final int nClasses = labels.size();
Preconditions.checkArgument(nClasses >= 2, UDFArgumentException.class);
final List> features = featuresOI.getList(featuresObj);
final int nFeatures = features.size();
Preconditions.checkArgument(nFeatures >= 1, UDFArgumentException.class);
if (myAgg.counts == null) {
myAgg.init(nClasses, nFeatures);
} else {
Preconditions.checkArgument(nClasses == myAgg.counts.length,
UDFArgumentException.class);
Preconditions.checkArgument(nFeatures == myAgg.means[0].length,
UDFArgumentException.class);
}
// incrementally calculates means and variance
// http://i.stanford.edu/pub/cstr/reports/cs/tr/79/773/CS-TR-79-773.pdf
final int clazz = hotIndex(labels, labelOI);
final long n = myAgg.counts[clazz];
myAgg.counts[clazz]++;
for (int i = 0; i < nFeatures; i++) {
final double x =
PrimitiveObjectInspectorUtils.getDouble(features.get(i), featureOI);
final double meanN = myAgg.means[clazz][i];
final double varianceN = myAgg.variances[clazz][i];
myAgg.means[clazz][i] = (n * meanN + x) / (n + 1.d);
myAgg.variances[clazz][i] =
(n * varianceN + (x - meanN) * (x - myAgg.means[clazz][i])) / (n + 1.d);
}
}
private static int hotIndex(@Nonnull List> labels, PrimitiveObjectInspector labelOI)
throws UDFArgumentException {
final int nClasses = labels.size();
int clazz = -1;
for (int i = 0; i < nClasses; i++) {
final int label = PrimitiveObjectInspectorUtils.getInt(labels.get(i), labelOI);
if (label == 1) {// assumes one hot encoding
if (clazz != -1) {
throw new UDFArgumentException(
"Specify one-hot vectorized array. Multiple hot elements found.");
}
clazz = i;
} else {
if (label != 0) {
throw new UDFArgumentException(
"Assumed one-hot encoding (0/1) but found an invalid label: " + label);
}
}
}
if (clazz == -1) {
throw new UDFArgumentException(
"Specify one-hot vectorized array for label. Hot element not found.");
}
return clazz;
}
@Override
public void merge(@SuppressWarnings("deprecation") AggregationBuffer agg, Object other)
throws HiveException {
if (other == null) {
return;
}
final SignalNoiseRatioAggregationBuffer myAgg = (SignalNoiseRatioAggregationBuffer) agg;
final List> counts =
countsOI.getList(structOI.getStructFieldData(other, countsField));
final List> means = meansOI.getList(structOI.getStructFieldData(other, meansField));
final List> variances =
variancesOI.getList(structOI.getStructFieldData(other, variancesField));
final int nClasses = counts.size();
final int nFeatures = meanListOI.getListLength(means.get(0));
if (myAgg.counts == null) {
myAgg.init(nClasses, nFeatures);
}
for (int i = 0; i < nClasses; i++) {
final long n = myAgg.counts[i];
final long cnt = PrimitiveObjectInspectorUtils.getLong(counts.get(i), countOI);
// no need to merge class `i`
if (cnt == 0) {
continue;
}
final List> mean = meanListOI.getList(means.get(i));
final List> variance = varianceListOI.getList(variances.get(i));
myAgg.counts[i] += cnt;
for (int j = 0; j < nFeatures; j++) {
final double meanN = myAgg.means[i][j];
final double meanM =
PrimitiveObjectInspectorUtils.getDouble(mean.get(j), meanElemOI);
final double varianceN = myAgg.variances[i][j];
final double varianceM = PrimitiveObjectInspectorUtils.getDouble(
variance.get(j), varianceElemOI);
if (n == 0) {// only assign `other` into `myAgg`
myAgg.means[i][j] = meanM;
myAgg.variances[i][j] = varianceM;
} else {
// merge by Chan's method
// http://i.stanford.edu/pub/cstr/reports/cs/tr/79/773/CS-TR-79-773.pdf
myAgg.means[i][j] = (n * meanN + cnt * meanM) / (double) (n + cnt);
myAgg.variances[i][j] = (varianceN * (n - 1) + varianceM * (cnt - 1)
+ Math.pow(meanN - meanM, 2) * n * cnt / (n + cnt)) / (n + cnt - 1);
}
}
}
}
@Override
public Object terminatePartial(@SuppressWarnings("deprecation") AggregationBuffer agg)
throws HiveException {
final SignalNoiseRatioAggregationBuffer myAgg = (SignalNoiseRatioAggregationBuffer) agg;
final Object[] partialResult = new Object[3];
partialResult[0] = WritableUtils.toWritableList(myAgg.counts);
final List> means = new ArrayList>();
for (double[] mean : myAgg.means) {
means.add(WritableUtils.toWritableList(mean));
}
partialResult[1] = means;
final List> variances = new ArrayList>();
for (double[] variance : myAgg.variances) {
variances.add(WritableUtils.toWritableList(variance));
}
partialResult[2] = variances;
return partialResult;
}
@Override
public Object terminate(@SuppressWarnings("deprecation") AggregationBuffer agg)
throws HiveException {
final SignalNoiseRatioAggregationBuffer myAgg = (SignalNoiseRatioAggregationBuffer) agg;
final int nClasses = myAgg.counts.length;
final int nFeatures = myAgg.means[0].length;
// compute SNR among classes for each feature
final double[] result = new double[nFeatures];
final double[] sds = new double[nClasses]; // for memorization
for (int i = 0; i < nFeatures; i++) {
sds[0] = Math.sqrt(myAgg.variances[0][i]);
for (int j = 1; j < nClasses; j++) {
sds[j] = Math.sqrt(myAgg.variances[j][i]);
// `ns[j] == 0` means no feature entry belongs to class `j`. Then, skip the entry.
if (myAgg.counts[j] == 0) {
continue;
}
for (int k = 0; k < j; k++) {
// avoid comparing between classes having only single entry
if (myAgg.counts[k] == 0
|| (myAgg.counts[j] == 1 && myAgg.counts[k] == 1)) {
continue;
}
// SUM(snr) GROUP BY feature
final double snr =
Math.abs(myAgg.means[j][i] - myAgg.means[k][i]) / (sds[j] + sds[k]);
// if `NaN`(when diff between means and both sds are zero, IOW, all related values are equal),
// regard feature `i` as meaningless between class `j` and `k`. So, skip the entry.
if (!Double.isNaN(snr)) {
result[i] += snr; // accept `Infinity`
}
}
}
}
return WritableUtils.toWritableList(result);
}
}
}
// © 2015 - 2025 Weber Informatics LLC | Privacy Policy