// hivemall.tools.array.SelectKBestUDF Maven / Gradle / Ivy
// The newest version!
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package hivemall.tools.array;
import hivemall.utils.hadoop.HiveUtils;
import hivemall.utils.lang.Preconditions;
import java.io.IOException;
import java.util.AbstractMap;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import java.util.Map;
import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.exec.UDFArgumentLengthException;
import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.udf.UDFType;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDF;
import org.apache.hadoop.hive.serde2.io.DoubleWritable;
import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
@Description(name = "select_k_best",
value = "_FUNC_(array array, const array importance, const int k)"
+ " - Returns selected top-k elements as array")
@UDFType(deterministic = true, stateful = false)
public final class SelectKBestUDF extends GenericUDF {
private ListObjectInspector featuresOI;
private PrimitiveObjectInspector featureOI;
private ListObjectInspector importanceListOI;
private PrimitiveObjectInspector importanceElemOI;
private int _k;
private List _result;
private int[] _topKIndices;
@Override
public ObjectInspector initialize(ObjectInspector[] OIs) throws UDFArgumentException {
if (OIs.length != 3) {
throw new UDFArgumentLengthException("Specify three arguments: " + OIs.length);
}
if (!HiveUtils.isNumberListOI(OIs[0])) {
throw new UDFArgumentTypeException(0,
"Only array type argument is acceptable but " + OIs[0].getTypeName()
+ " was passed as `features`");
}
if (!HiveUtils.isNumberListOI(OIs[1])) {
throw new UDFArgumentTypeException(1,
"Only array type argument is acceptable but " + OIs[1].getTypeName()
+ " was passed as `importance_list`");
}
if (!HiveUtils.isIntegerOI(OIs[2])) {
throw new UDFArgumentTypeException(2, "Only int type argument is acceptable but "
+ OIs[2].getTypeName() + " was passed as `k`");
}
this.featuresOI = HiveUtils.asListOI(OIs[0]);
this.featureOI = HiveUtils.asDoubleCompatibleOI(featuresOI.getListElementObjectInspector());
this.importanceListOI = HiveUtils.asListOI(OIs[1]);
this.importanceElemOI =
HiveUtils.asDoubleCompatibleOI(importanceListOI.getListElementObjectInspector());
this._k = HiveUtils.getConstInt(OIs[2]);
Preconditions.checkArgument(_k >= 1, UDFArgumentException.class);
final List result = new ArrayList<>(_k);
for (int i = 0; i < _k; i++) {
result.add(new DoubleWritable());
}
this._result = result;
return ObjectInspectorFactory.getStandardListObjectInspector(
PrimitiveObjectInspectorFactory.writableDoubleObjectInspector);
}
@Override
public List evaluate(DeferredObject[] dObj) throws HiveException {
final double[] features = HiveUtils.asDoubleArray(dObj[0].get(), featuresOI, featureOI);
final double[] importanceList =
HiveUtils.asDoubleArray(dObj[1].get(), importanceListOI, importanceElemOI);
Preconditions.checkNotNull(features, UDFArgumentException.class);
Preconditions.checkNotNull(importanceList, UDFArgumentException.class);
Preconditions.checkArgument(features.length == importanceList.length,
UDFArgumentException.class);
Preconditions.checkArgument(features.length >= _k, UDFArgumentException.class);
int[] topKIndices = _topKIndices;
if (topKIndices == null) {
final List> list =
new ArrayList>();
for (int i = 0; i < importanceList.length; i++) {
list.add(new AbstractMap.SimpleEntry(i, importanceList[i]));
}
Collections.sort(list, new Comparator>() {
@Override
public int compare(Map.Entry o1, Map.Entry o2) {
return o1.getValue() > o2.getValue() ? -1 : 1;
}
});
topKIndices = new int[_k];
for (int i = 0; i < topKIndices.length; i++) {
topKIndices[i] = list.get(i).getKey();
}
this._topKIndices = topKIndices;
}
final List result = _result;
for (int i = 0; i < topKIndices.length; i++) {
int idx = topKIndices[i];
DoubleWritable d = result.get(i);
double f = features[idx];
d.set(f);
}
return result;
}
@Override
public void close() throws IOException {
// help GC
this._result = null;
this._topKIndices = null;
}
@Override
public String getDisplayString(String[] children) {
final StringBuilder sb = new StringBuilder();
sb.append("select_k_best");
sb.append("(");
if (children.length > 0) {
sb.append(children[0]);
for (int i = 1; i < children.length; i++) {
sb.append(", ");
sb.append(children[i]);
}
}
sb.append(")");
return sb.toString();
}
}
// © 2015 - 2025 Weber Informatics LLC | Privacy Policy