hivemall.tools.map.MapRouletteUDF Maven / Gradle / Ivy
The newest version!
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package hivemall.tools.map;
import static hivemall.utils.lang.StringUtils.join;
import hivemall.utils.hadoop.HiveUtils;
import hivemall.utils.lang.Preconditions;
import java.util.HashMap;
import java.util.Map;
import java.util.Random;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.exec.UDFArgumentLengthException;
import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.udf.UDFType;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDF;
import org.apache.hadoop.hive.serde2.objectinspector.MapObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils;
import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
/**
* The map_roulette returns a map key based on weighted random sampling of map values.
*/
// @formatter:off
@Description(name = "map_roulette",
value = "_FUNC_(Map map [, (const) int/bigint seed])"
+ " - Returns a map key based on weighted random sampling of map values."
+ " Average of values is used for null values",
extended = "-- `map_roulette(map [, integer seed])` returns key by weighted random selection\n" +
"SELECT \n" +
" map_roulette(to_map(a, b)) -- 25% Tom, 21% Zhang, 54% Wang\n" +
"FROM ( -- see https://issues.apache.org/jira/browse/HIVE-17406\n" +
" select 'Wang' as a, 54 as b\n" +
" union all\n" +
" select 'Zhang' as a, 21 as b\n" +
" union all\n" +
" select 'Tom' as a, 25 as b\n" +
") tmp;\n" +
"> Wang\n" +
"\n" +
"-- Weight random selection with using filling nulls with the average value\n" +
"SELECT\n" +
" map_roulette(map(1, 0.5, 'Wang', null)), -- 50% Wang, 50% 1\n" +
" map_roulette(map(1, 0.5, 'Wang', null, 'Zhang', null)) -- 1/3 Wang, 1/3 1, 1/3 Zhang\n" +
";\n" +
"\n" +
"-- NULL will be returned if every key is null\n" +
"SELECT \n" +
" map_roulette(map()),\n" +
" map_roulette(map(null, null, null, null));\n" +
"> NULL NULL\n" +
"\n" +
"-- Return NULL if all weights are zero\n" +
"SELECT\n" +
" map_roulette(map(1, 0)),\n" +
" map_roulette(map(1, 0, '5', 0))\n" +
";\n" +
"> NULL NULL\n" +
"\n" +
"-- map_roulette does not support non-numeric weights or negative weights.\n" +
"SELECT map_roulette(map('Wong', 'A string', 'Zhao', 2));\n" +
"> HiveException: Error evaluating map_roulette(map('Wong':'A string','Zhao':2))\n" +
"SELECT map_roulette(map('Wong', 'A string', 'Zhao', 2));\n" +
"> UDFArgumentException: Map value must be greather than or equals to zero: -2")
// @formatter:on
@UDFType(deterministic = false, stateful = false) // deterministic is false because the returned key is chosen by random sampling
public final class MapRouletteUDF extends GenericUDF {
// Inspector for the map passed as the first argument; assigned in initialize().
private transient MapObjectInspector mapOI;
// Inspector for the map's values, obtained through HiveUtils.asDoubleCompatibleOI in initialize().
private transient PrimitiveObjectInspector valueOI;
// Inspector for a non-constant seed argument; assigned in initialize() only in that case,
// i.e. not when the seed is omitted or is a constant.
@Nullable
private transient PrimitiveObjectInspector seedOI;
// Generator created in initialize() when the seed is omitted (unseeded) or constant (seeded once).
// Not assigned there for a non-constant seed; evaluate() then builds a generator from the seed argument on each call.
@Nullable
private transient Random _rand;
/**
 * Validates the argument inspectors and prepares the state used by {@code evaluate}.
 *
 * <p>The first argument must be a map. The optional second argument is an integer seed: when
 * it is a constant, a single {@link Random} is created from it here; when it is not a
 * constant, only its inspector is kept so that the seed can be read at evaluation time. When
 * no seed is given, an unseeded {@link Random} is created.
 *
 * @param argOIs inspectors of the call arguments: the map and, optionally, the seed
 * @return the inspector of the map's keys, because the function returns one of the keys
 * @throws UDFArgumentLengthException if the number of arguments is neither one nor two
 * @throws UDFArgumentTypeException if the first argument is not a map
 * @throws UDFArgumentException if the second argument is not of an integer type, or if a
 *         {@code HiveUtils} conversion helper rejects an inspector
 */
@Override
public ObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException {
    if (argOIs.length != 1 && argOIs.length != 2) {
        // One or two arguments are accepted, so the message must not claim "exactly one".
        throw new UDFArgumentLengthException(
            "Expected one or two arguments for map_roulette but got: " + argOIs.length);
    }
    if (argOIs[0].getCategory() != ObjectInspector.Category.MAP) {
        throw new UDFArgumentTypeException(0,
            "Only map type argument is accepted but got " + argOIs[0].getTypeName());
    }
    this.mapOI = HiveUtils.asMapOI(argOIs[0]);
    // Map values act as sampling weights, so they are read through a double-compatible inspector.
    this.valueOI = HiveUtils.asDoubleCompatibleOI(mapOI.getMapValueObjectInspector());

    if (argOIs.length == 2) {
        ObjectInspector argOI1 = argOIs[1];
        if (!HiveUtils.isIntegerOI(argOI1)) {
            throw new UDFArgumentException(
                "The second argument of map_roulette must be integer type: "
                        + argOI1.getTypeName());
        }
        if (ObjectInspectorUtils.isConstantObjectInspector(argOI1)) {
            long seed = HiveUtils.getAsConstLong(argOI1);
            this._rand = new Random(seed); // fixed seed
        } else {
            // Non-constant seed: keep the inspector; evaluate() reads the seed value itself.
            this.seedOI = HiveUtils.asLongCompatibleOI(argOI1);
        }
    } else {
        this._rand = new Random(); // random seed
    }

    return mapOI.getMapKeyObjectInspector();
}
@Nullable
@Override
public Object evaluate(DeferredObject[] arguments) throws HiveException {
Random rand = _rand;
if (rand == null) {
Object arg1 = arguments[1].get();
if (arg1 == null) {
rand = new Random();
} else {
long seed = HiveUtils.getLong(arg1, seedOI);
rand = new Random(seed);
}
}
Map
© 2015 - 2025 Weber Informatics LLC | Privacy Policy