hivemall.ftvec.ranking.BprSamplingUDTF Maven / Gradle / Ivy
The newest version!
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package hivemall.ftvec.ranking;
import java.util.ArrayList;
import java.util.BitSet;
import java.util.Random;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import org.apache.commons.cli.CommandLine;
import org.apache.commons.cli.Options;
import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
import org.apache.hadoop.io.IntWritable;
import hivemall.UDTFWithOptions;
import hivemall.utils.collections.lists.IntArrayList;
import hivemall.utils.hadoop.HiveUtils;
import hivemall.utils.lang.BitUtils;
import hivemall.utils.lang.Primitives;
@Description(name = "bpr_sampling",
        value = "_FUNC_(int userId, array<int> posItems [, const string options])"
                + " - Returns a relation consisting of <int user, int pos_item, int neg_item>")
/**
 * Collects positive-only feedback (user, list of positive items) row by row in
 * {@link #process(Object[])} and, in {@link #close()}, emits sampled (user, pos_item, neg_item)
 * triples, e.g. as training input for BPR-style pairwise ranking (per the function name).
 *
 * A negative item is an item id in [0, maxItemId] that is not among the user's positive items.
 * Sampling uses a fixed seed, so the output is deterministic for a given input order.
 */
public final class BprSamplingUDTF extends UDTFWithOptions {

    private PrimitiveObjectInspector userOI;
    private ListObjectInspector itemListOI;
    private PrimitiveObjectInspector itemElemOI;

    // Marked transient to avoid
    // org.apache.hive.com.esotericsoftware.kryo.KryoException: java.lang.ArrayIndexOutOfBoundsException: 1
    // Created lazily by the first call to process(); it is still null in close() when no row
    // was processed (e.g., an empty input split).
    @Nullable
    private transient PositiveOnlyFeedback feedback;

    // sampling options
    private int maxItemId;
    private float samplingRate;
    private boolean withoutReplacement;
    private boolean pairSampling;

    // reusable output row: {user, pos_item, neg_item}
    private Object[] forwardObjs;
    private IntWritable userId;
    private IntWritable posItemId;
    private IntWritable negItemId;

    public BprSamplingUDTF() {}

    /**
     * Declares the options accepted by the optional third (constant string) argument.
     */
    @Override
    protected Options getOptions() {
        Options opts = new Options();
        opts.addOption("sampling", "sampling_rate", true,
            "Sampling rates of positive items [default: 1.0]");
        opts.addOption("without_replacement", false,
            "Do sampling without-replacement sampling [default: false]");
        opts.addOption("uniform_pair_sampling", "pair_sampling", false,
            "Sampling pairs uniform from feedbacks [default: false]");
        opts.addOption("maxcol", "max_itemid", true, "Max item id index [default: -1]");
        return opts;
    }

    /**
     * Parses the optional options argument and stores the sampling settings in fields.
     *
     * @return the parsed command line, or null when no options argument was given
     * @throws UDFArgumentException if sampling_rate exceeds 1 while without-replacement sampling
     *         is requested
     */
    @Override
    protected CommandLine processOptions(@Nonnull ObjectInspector[] argOIs)
            throws UDFArgumentException {
        CommandLine cl = null;

        int maxItemId = -1;
        float samplingRate = 1.f;
        boolean withoutReplacement = false;
        boolean pairSampling = false;

        if (argOIs.length >= 3) {
            String args = HiveUtils.getConstString(argOIs[2]);
            cl = parseOptions(args);

            maxItemId = Primitives.parseInt(cl.getOptionValue("max_itemid"), maxItemId);
            withoutReplacement = cl.hasOption("without_replacement");
            pairSampling = cl.hasOption("uniform_pair_sampling");
            samplingRate = Primitives.parseFloat(cl.getOptionValue("sampling_rate"), samplingRate);
            if (withoutReplacement && samplingRate > 1.f) {
                throw new UDFArgumentException("sampling_rate MUST be in less than or equals to 1"
                        + " where without-replacement is true: " + samplingRate);
            }
        }

        this.maxItemId = maxItemId;
        this.samplingRate = samplingRate;
        this.withoutReplacement = withoutReplacement;
        this.pairSampling = pairSampling;
        return cl;
    }

    /**
     * Validates the argument types and returns the output schema
     * {@code struct<user:int, pos_item:int, neg_item:int>}.
     *
     * @throws UDFArgumentException if the number of arguments is not 2 or 3, or their types are
     *         not (int, array of int[, const string])
     */
    @Override
    public StructObjectInspector initialize(@Nonnull ObjectInspector[] argOIs)
            throws UDFArgumentException {
        if (argOIs.length != 2 && argOIs.length != 3) {
            throw new UDFArgumentException(
                "_FUNC_(int userid, array<int> itemid [, const string options])"
                        + " takes two or three arguments: " + argOIs.length);
        }
        this.userOI = HiveUtils.asIntegerOI(argOIs[0]);
        this.itemListOI = HiveUtils.asListOI(argOIs[1]);
        this.itemElemOI = HiveUtils.asIntegerOI(itemListOI.getListElementObjectInspector());

        processOptions(argOIs);

        this.userId = new IntWritable();
        this.posItemId = new IntWritable();
        this.negItemId = new IntWritable();
        this.forwardObjs = new Object[] {userId, posItemId, negItemId};

        ArrayList<String> fieldNames = new ArrayList<String>(3);
        ArrayList<ObjectInspector> fieldOIs = new ArrayList<ObjectInspector>(3);
        fieldNames.add("user");
        fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector);
        fieldNames.add("pos_item");
        fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector);
        fieldNames.add("neg_item");
        fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector);

        return ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs);
    }

    /**
     * Buffers one row of positive feedback. Nothing is emitted here; sampling happens in
     * {@link #close()}. Rows with a null user or a null item list carry no feedback and are
     * skipped.
     *
     * @throws UDFArgumentException if the user id or an item id is negative
     */
    @Override
    public void process(@Nonnull Object[] args) throws HiveException {
        if (feedback == null) {
            this.feedback = pairSampling ? new PerEventPositiveOnlyFeedback(maxItemId)
                    : new PositiveOnlyFeedback(maxItemId);
        }

        final Object userObj = args[0];
        final Object itemsObj = args[1];
        if (userObj == null || itemsObj == null) {
            // getInt() would throw a NullPointerException on a null user, and getListLength()
            // returns -1 for a null list.
            return;
        }

        int userId = PrimitiveObjectInspectorUtils.getInt(userObj, userOI);
        validateIndex(userId);
        addFeedback(userId, itemsObj);
    }

    /**
     * Adds the non-null item ids of {@code arg} as positive feedback of {@code userId} and
     * raises the tracked max item id if needed. Lists with no non-null elements are ignored,
     * so that no user is registered with an empty positive list.
     */
    private void addFeedback(final int userId, @Nonnull final Object arg)
            throws UDFArgumentException {
        final int size = itemListOI.getListLength(arg);
        if (size <= 0) {
            return;
        }

        int maxItemId = feedback.getMaxItemId();
        final IntArrayList posItems = new IntArrayList(size);
        for (int i = 0; i < size; i++) {
            Object elem = itemListOI.getListElement(arg, i);
            if (elem == null) {
                continue;
            }
            int index = PrimitiveObjectInspectorUtils.getInt(elem, itemElemOI);
            validateIndex(index);
            maxItemId = Math.max(index, maxItemId);
            posItems.add(index);
        }
        if (posItems.isEmpty()) {
            return;
        }

        feedback.addFeedback(userId, posItems);
        feedback.setMaxItemId(maxItemId);
    }

    /**
     * Runs the configured sampling strategy over all buffered feedback and forwards the sampled
     * triples. The requested number of samples is {@code (int) (totalFeedbacks * samplingRate)}.
     */
    @Override
    public void close() throws HiveException {
        if (feedback == null) {
            return; // process() was never called
        }
        final int feedbacks = feedback.getTotalFeedbacks();
        if (feedbacks == 0) {
            return;
        }

        final int numSamples = (int) (feedbacks * samplingRate);
        if (pairSampling) {
            PerEventPositiveOnlyFeedback evFeedback = (PerEventPositiveOnlyFeedback) feedback;
            if (withoutReplacement) {
                uniformPairSamplingWithoutReplacement(evFeedback, numSamples);
            } else {
                uniformPairSamplingWithReplacement(evFeedback, numSamples);
            }
        } else {
            if (withoutReplacement) {
                uniformUserSamplingWithoutReplacement(feedback, numSamples);
            } else {
                uniformUserSamplingWithReplacement(feedback, numSamples);
            }
        }
    }

    /** Writes one (user, pos_item, neg_item) triple into the reusable row and forwards it. */
    private void forward(final int user, final int posItem, final int negItem)
            throws HiveException {
        assert (user >= 0) : user;
        assert (posItem >= 0) : posItem;
        assert (negItem >= 0) : negItem;

        userId.set(user);
        posItemId.set(posItem);
        negItemId.set(negItem);
        forward(forwardObjs);
    }

    /**
     * Sampling pairs uniform for each user with replacement. Sample a user. Then, sample a pair.
     *
     * Users whose positive items already cover every item id cannot yield a negative item and
     * are excluded up front. Otherwise the sampling loop could retry forever.
     */
    private void uniformUserSamplingWithReplacement(@Nonnull final PositiveOnlyFeedback feedback,
            final int numSamples) throws HiveException {
        final int numUsers = feedback.getNumUsers();
        if (numUsers == 0) {
            return;
        }
        final int maxItemId = feedback.getMaxItemId();
        if (maxItemId <= 0) {
            throw new HiveException("Invalid maxItemId: " + maxItemId);
        }
        final int numItems = maxItemId + 1;

        final int[] users = feedback.getUsers();
        assert (users.length == numUsers);

        // Keep only users that have at least one positive item and at least one negative item.
        final int[] candUsers = new int[users.length];
        final IntArrayList[] candItems = new IntArrayList[users.length];
        int numCands = 0;
        for (int user : users) {
            IntArrayList posItems = feedback.getItems(user, true);
            if (posItems != null && !posItems.isEmpty() && canDrawNegative(posItems, numItems)) {
                candUsers[numCands] = user;
                candItems[numCands] = posItems;
                numCands++;
            }
        }
        if (numCands == 0) {
            return;
        }

        final Random rand = new Random(31L);
        for (int i = 0; i < numSamples; i++) {
            int cand = rand.nextInt(numCands);
            int user = candUsers[cand];
            IntArrayList posItems = candItems[cand];

            int posItem = posItems.fastGet(rand.nextInt(posItems.size()));
            int negItem = sampleNegativeItem(rand, posItems, numItems);
            forward(user, posItem, negItem);
        }
    }

    /**
     * Sampling pairs uniform for each user without replacement. Sample a user. Then, sample a pair.
     *
     * Caution: only positive feedbacks are sampled without replacement; negative items are drawn
     * with replacement. NOTE(review): a sampled positive item is removed from posItems, so it
     * presumably may later be drawn as a negative item for the same user. Confirm this is
     * intended.
     *
     * Users that cannot yield a negative item are dropped from the candidates, which guarantees
     * termination.
     */
    private void uniformUserSamplingWithoutReplacement(@Nonnull final PositiveOnlyFeedback feedback,
            final int numSamples) throws HiveException {
        int numUsers = feedback.getNumUsers();
        if (numUsers == 0) {
            return;
        }
        final int maxItemId = feedback.getMaxItemId();
        if (maxItemId <= 0) {
            throw new HiveException("Invalid maxItemId: " + maxItemId);
        }
        final int numItems = maxItemId + 1;

        final BitSet userBits = new BitSet(numUsers);
        feedback.getUsers(userBits);

        final Random rand = new Random(31L);
        for (int i = 0; i < numSamples && numUsers > 0; i++) {
            int nthUser = rand.nextInt(numUsers);
            int user = BitUtils.indexOfSetBit(userBits, nthUser);
            if (user == -1) {
                throw new HiveException(
                    "Cannot find " + nthUser + "-th user among " + numUsers + " users");
            }
            IntArrayList posItems = feedback.getItems(user, true);
            assert (posItems != null) : user;
            if (posItems == null || posItems.isEmpty() || !canDrawNegative(posItems, numItems)) {
                // This user can never yield a (positive, negative) pair: drop it and retry.
                // numUsers strictly decreases here, so the loop terminates.
                userBits.clear(user);
                --numUsers;
                --i;
                continue;
            }

            int posItemIndex = rand.nextInt(posItems.size());
            int posItem = posItems.fastGet(posItemIndex);
            int negItem = sampleNegativeItem(rand, posItems, numItems);

            posItems.remove(posItemIndex);
            if (posItems.isEmpty()) {
                feedback.removeFeedback(user);
                userBits.clear(user);
                --numUsers;
            }

            forward(user, posItem, negItem);
        }
    }

    /**
     * Sampling pairs uniform from feedbacks with replacement.
     *
     * Feedback events whose user cannot yield a negative item are rejected and resampled. If no
     * user can yield a negative item at all, nothing is emitted.
     */
    private void uniformPairSamplingWithReplacement(
            @Nonnull final PerEventPositiveOnlyFeedback feedback, final int numSamples)
            throws HiveException {
        final int numFeedbacks = feedback.getTotalFeedbacks();
        if (numFeedbacks == 0) {
            return;
        }
        final int maxItemId = feedback.getMaxItemId();
        if (maxItemId <= 0) {
            throw new HiveException("Invalid maxItemId: " + maxItemId);
        }
        final int numItems = maxItemId + 1;
        // Without this check the rejection loop below would never terminate.
        // NOTE(review): assumes every user returned by getUsers() owns at least one event.
        if (!hasUserWithNegativeCandidate(feedback, numItems)) {
            return;
        }

        final Random rand = new Random(31L);
        for (int i = 0; i < numSamples; i++) {
            int index = rand.nextInt(numFeedbacks);
            int user = feedback.getUser(index);
            IntArrayList posItems = feedback.getItems(user, true);
            assert (posItems != null) : user;
            if (!canDrawNegative(posItems, numItems)) {
                --i; // resample another feedback event
                continue;
            }
            int posItem = feedback.getPositiveItem(index);
            int negItem = sampleNegativeItem(rand, posItems, numItems);
            forward(user, posItem, negItem);
        }
    }

    /**
     * Sampling pairs uniform from feedbacks without replacement.
     *
     * Feedback events are visited in the order given by {@code getRandomIndex(rand)}, presumably
     * a random permutation of the event indices. Visiting stops once {@code numSamples} triples
     * have been emitted, so sampling_rate is honored. Events whose user cannot yield a negative
     * item are skipped.
     *
     * Caution: only positive feedbacks are sampled without replacement; negative items are drawn
     * with replacement.
     */
    private void uniformPairSamplingWithoutReplacement(
            @Nonnull final PerEventPositiveOnlyFeedback feedback, final int numSamples)
            throws HiveException {
        final int numFeedbacks = feedback.getTotalFeedbacks();
        if (numFeedbacks == 0) {
            return;
        }
        final int maxItemId = feedback.getMaxItemId();
        if (maxItemId <= 0) {
            throw new HiveException("Invalid maxItemId: " + maxItemId);
        }
        final int numItems = maxItemId + 1;

        final Random rand = new Random(31L);
        final int[] perm = feedback.getRandomIndex(rand);
        int numEmitted = 0;
        for (int index : perm) {
            if (numEmitted >= numSamples) {
                break;
            }
            int user = feedback.getUser(index);
            IntArrayList posItems = feedback.getItems(user, true);
            assert (posItems != null) : user;
            if (!canDrawNegative(posItems, numItems)) {
                continue;
            }
            int posItem = feedback.getPositiveItem(index);
            int negItem = sampleNegativeItem(rand, posItems, numItems);
            forward(user, posItem, negItem);
            numEmitted++;
        }
    }

    /**
     * Draws an item id uniformly from [0, numItems) that is not in {@code posItems}, by
     * rejection sampling. The caller must ensure that such an id exists (see
     * {@link #canDrawNegative(IntArrayList, int)}); otherwise this method never returns.
     */
    private static int sampleNegativeItem(@Nonnull final Random rand,
            @Nonnull final IntArrayList posItems, final int numItems) {
        int negItem;
        do {
            // numItems = maxItemId + 1, so maxItemId itself is also a valid negative item
            negItem = rand.nextInt(numItems);
        } while (posItems.contains(negItem));
        return negItem;
    }

    /**
     * Returns true iff some item id in [0, numItems) is missing from {@code posItems}, i.e., a
     * negative item can be drawn. Duplicate entries in {@code posItems} are handled correctly.
     */
    private static boolean canDrawNegative(@Nonnull final IntArrayList posItems,
            final int numItems) {
        final int size = posItems.size();
        if (size < numItems) {
            return true; // fewer entries than ids, so at least one id is missing
        }
        final BitSet covered = new BitSet(numItems);
        for (int i = 0; i < size; i++) {
            int item = posItems.fastGet(i);
            if (item >= 0 && item < numItems) {
                covered.set(item);
            }
        }
        return covered.cardinality() < numItems;
    }

    /** Returns true iff at least one user of {@code feedback} can yield a negative item. */
    private static boolean hasUserWithNegativeCandidate(
            @Nonnull final PositiveOnlyFeedback feedback, final int numItems) {
        for (int user : feedback.getUsers()) {
            IntArrayList posItems = feedback.getItems(user, true);
            if (posItems != null && canDrawNegative(posItems, numItems)) {
                return true;
            }
        }
        return false;
    }

    /** Rejects negative user/item ids. */
    private static void validateIndex(final int index) throws UDFArgumentException {
        if (index < 0) {
            throw new UDFArgumentException("Negative index is not allowed: " + index);
        }
    }

}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy