/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package hivemall.factorization.mf;
import hivemall.UDTFWithOptions;
import hivemall.common.ConversionState;
import hivemall.factorization.mf.FactorizedModel.RankInitScheme;
import hivemall.utils.hadoop.HiveUtils;
import hivemall.utils.io.FileUtils;
import hivemall.utils.io.NioFixedSegment;
import hivemall.utils.lang.NumberUtils;
import hivemall.utils.lang.Primitives;
import java.io.File;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Arrays;
import javax.annotation.Nonnegative;
import javax.annotation.Nonnull;
import org.apache.commons.cli.CommandLine;
import org.apache.commons.cli.Options;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
import org.apache.hadoop.io.FloatWritable;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.mapred.Counters.Counter;
import org.apache.hadoop.mapred.Reporter;
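/**
 * Base class for online (SGD-style) matrix factorization UDTFs.
 *
 * Consumes (INT user, INT item, FLOAT rating [, CONSTANT STRING options]) rows
 * and, on close, forwards the learned model as (idx, Pu, Qi [, Bu, Bi [, mu]]) rows.
 *
 * A usage sketch in HiveQL (the function name depends on how a concrete
 * subclass is registered, e.g. `train_mf_sgd` in Hivemall's DDL; the table and
 * column names below are hypothetical):
 *
 * <pre>{@code
 * SELECT train_mf_sgd(userid, itemid, rating, '-factors 10 -iterations 10')
 * FROM training;
 * }</pre>
 */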
public abstract class OnlineMatrixFactorizationUDTF extends UDTFWithOptions
implements RatingInitializer {
private static final Log logger = LogFactory.getLog(OnlineMatrixFactorizationUDTF.class);
private static final int RECORD_BYTES = (Integer.SIZE + Integer.SIZE + Double.SIZE) / 8; // user(int) + item(int) + rating(double) = 16 bytes
// Option variables
/** The number of latent factors */
protected int factor;
/** The regularization factor */
protected float lambda;
/** The initial mean rating */
protected float meanRating;
/** Whether to update (and return) the mean rating */
protected boolean updateMeanRating;
/** The number of iterations */
protected int iterations;
/** Whether to use the bias clause */
protected boolean useBiasClause;
/** Initialization strategy of rank matrix */
protected RankInitScheme rankInit;
// Model itself
protected FactorizedModel model;
// Variables managing the learning status
/** The number of processed training examples */
protected long count;
protected ConversionState cvState;
// Input OIs and Context
protected PrimitiveObjectInspector userOI;
protected PrimitiveObjectInspector itemOI;
protected PrimitiveObjectInspector ratingOI;
// Used for iterations
protected NioFixedSegment fileIO;
protected ByteBuffer inputBuf;
private long lastWritePos;
private float[] userProbe, itemProbe;
public OnlineMatrixFactorizationUDTF() {
this.factor = 10;
this.lambda = 0.03f;
this.meanRating = 0.f;
this.updateMeanRating = false;
this.iterations = 1;
this.useBiasClause = true;
}
@Override
protected Options getOptions() {
Options opts = new Options();
opts.addOption("k", "factor", true, "The number of latent factor [default: 10] "
+ " Note this is alias for `factors` option.");
opts.addOption("f", "factors", true, "The number of latent factor [default: 10]");
opts.addOption("r", "lambda", true, "The regularization factor [default: 0.03]");
opts.addOption("mu", "mean_rating", true, "The mean rating [default: 0.0]");
opts.addOption("update_mean", "update_mu", false,
"Whether update (and return) the mean rating or not");
opts.addOption("rankinit", true,
"Initialization strategy of rank matrix [random, gaussian] (default: random)");
opts.addOption("maxval", "max_init_value", true,
"The maximum initial value in the rank matrix [default: 1.0]");
opts.addOption("min_init_stddev", true,
"The minimum standard deviation of initial rank matrix [default: 0.1]");
opts.addOption("iters", "iterations", true, "The number of iterations [default: 1]");
opts.addOption("iter", true,
"The number of iterations [default: 1] Alias for `-iterations`");
opts.addOption("disable_cv", "disable_cvtest", false,
"Whether to disable convergence check [default: enabled]");
opts.addOption("cv_rate", "convergence_rate", true,
"Threshold to determine convergence [default: 0.005]");
opts.addOption("disable_bias", "no_bias", false, "Turn off bias clause");
return opts;
}
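// A sketch of the constant options string accepted as the 4th argument
// (flags as registered above; the values here are illustrative only):
//   '-factors 20 -lambda 0.05 -mu 3.5 -rankinit gaussian -iterations 10'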
@Override
protected CommandLine processOptions(ObjectInspector[] argOIs) throws UDFArgumentException {
CommandLine cl = null;
String rankInitOpt = null;
float maxInitValue = 1.f;
double initStdDev = 0.1d;
boolean conversionCheck = true;
double convergenceRate = 0.005d;
if (argOIs.length >= 4) {
String rawArgs = HiveUtils.getConstString(argOIs, 3);
cl = parseOptions(rawArgs);
if (cl.hasOption("factors")) {
this.factor = Primitives.parseInt(cl.getOptionValue("factors"), 10);
} else {
this.factor = Primitives.parseInt(cl.getOptionValue("factor"), 10);
}
this.lambda = Primitives.parseFloat(cl.getOptionValue("lambda"), 0.03f);
this.meanRating = Primitives.parseFloat(cl.getOptionValue("mu"), 0.f);
this.updateMeanRating = cl.hasOption("update_mean");
rankInitOpt = cl.getOptionValue("rankinit");
maxInitValue = Primitives.parseFloat(cl.getOptionValue("max_init_value"), 1.f);
initStdDev = Primitives.parseDouble(cl.getOptionValue("min_init_stddev"), 0.1d);
if (cl.hasOption("iter")) {
this.iterations = Primitives.parseInt(cl.getOptionValue("iter"), 1);
} else {
this.iterations = Primitives.parseInt(cl.getOptionValue("iterations"), 1);
}
if (iterations < 1) {
throw new UDFArgumentException(
"'-iterations' must be greater than or equal to 1: " + iterations);
}
conversionCheck = !cl.hasOption("disable_cvtest");
convergenceRate = Primitives.parseDouble(cl.getOptionValue("cv_rate"), convergenceRate);
boolean noBias = cl.hasOption("no_bias");
this.useBiasClause = !noBias;
if (noBias && updateMeanRating) {
throw new UDFArgumentException(
"Cannot set both `update_mean` and `no_bias` option");
}
}
this.rankInit = RankInitScheme.resolve(rankInitOpt);
rankInit.setMaxInitValue(maxInitValue);
initStdDev = Math.max(initStdDev, 1.0d / factor);
rankInit.setInitStdDev(initStdDev);
this.cvState = new ConversionState(conversionCheck, convergenceRate);
return cl;
}
@Override
public StructObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException {
if (argOIs.length < 3) {
showHelp(String.format(
"%s takes 3 or more arguments: INT user, INT item, FLOAT rating [, CONSTANT STRING options]: %s",
getClass().getSimpleName(), Arrays.toString(argOIs)));
}
this.userOI = HiveUtils.asIntCompatibleOI(argOIs, 0);
this.itemOI = HiveUtils.asIntCompatibleOI(argOIs, 1);
this.ratingOI = HiveUtils.asDoubleCompatibleOI(argOIs, 2);
processOptions(argOIs);
this.model = new FactorizedModel(this, factor, meanRating, rankInit);
this.count = 0L;
this.lastWritePos = 0L;
this.userProbe = new float[factor];
this.itemProbe = new float[factor];
if (mapredContext != null && iterations > 1) {
// invoked only at task nodes (initialize is also called during query compilation)
final File file;
try {
file = File.createTempFile("hivemall_mf", ".sgmt");
file.deleteOnExit();
if (!file.canWrite()) {
throw new UDFArgumentException(
"Cannot write a temporary file: " + file.getAbsolutePath());
}
} catch (IOException ioe) {
throw new UDFArgumentException(ioe);
} catch (Throwable e) {
throw new UDFArgumentException(e);
}
this.fileIO = new NioFixedSegment(file, RECORD_BYTES, false);
this.inputBuf = ByteBuffer.allocateDirect(65536); // 64 KiB
}
ArrayList<String> fieldNames = new ArrayList<String>();
ArrayList<ObjectInspector> fieldOIs = new ArrayList<ObjectInspector>();
fieldNames.add("idx");
fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector);
fieldNames.add("Pu");
fieldOIs.add(ObjectInspectorFactory.getStandardListObjectInspector(
PrimitiveObjectInspectorFactory.writableFloatObjectInspector));
fieldNames.add("Qi");
fieldOIs.add(ObjectInspectorFactory.getStandardListObjectInspector(
PrimitiveObjectInspectorFactory.writableFloatObjectInspector));
if (useBiasClause) {
fieldNames.add("Bu");
fieldOIs.add(PrimitiveObjectInspectorFactory.writableFloatObjectInspector);
fieldNames.add("Bi");
fieldOIs.add(PrimitiveObjectInspectorFactory.writableFloatObjectInspector);
if (updateMeanRating) {
fieldNames.add("mu");
fieldOIs.add(PrimitiveObjectInspectorFactory.writableFloatObjectInspector);
}
}
return ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs);
}
@Override
public Rating newRating(float v) {
return new Rating(v);
}
@Override
public final void process(Object[] args) throws HiveException {
assert (args.length >= 3) : args.length;
int user = PrimitiveObjectInspectorUtils.getInt(args[0], userOI);
if (user < 0) {
throw new HiveException("Illegal user index: " + user);
}
int item = PrimitiveObjectInspectorUtils.getInt(args[1], itemOI);
if (item < 0) {
throw new HiveException("Illegal item index: " + item);
}
double rating = PrimitiveObjectInspectorUtils.getDouble(args[2], ratingOI);
beforeTrain(count, user, item, rating);
count++;
train(user, item, rating);
}
@Nonnull
protected final float[] copyToUserProbe(@Nonnull final Rating[] rating) {
for (int k = 0, size = factor; k < size; k++) {
userProbe[k] = rating[k].getWeight();
}
return userProbe;
}
@Nonnull
protected final float[] copyToItemProbe(@Nonnull final Rating[] rating) {
for (int k = 0, size = factor; k < size; k++) {
itemProbe[k] = rating[k].getWeight();
}
return itemProbe;
}
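/**
 * Performs one SGD step for a single (user, item, rating) example. With the
 * prediction error err = rating - predict(user, item), each latent dimension
 * k is updated as
 *   Pu[k] += eta * (err * Qi[k] - lambda * Pu[k])
 *   Qi[k] += eta * (err * Pu[k] - lambda * Qi[k])
 * Both updates read from the probe copies taken before any mutation, so the
 * user and item updates within one step see consistent pre-update weights.
 */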
protected void train(final int user, final int item, final double rating) throws HiveException {
final Rating[] users = model.getUserVector(user, true);
assert (users != null);
final Rating[] items = model.getItemVector(item, true);
assert (items != null);
final float[] userProbe = copyToUserProbe(users);
final float[] itemProbe = copyToItemProbe(items);
final double err = rating - predict(user, item, userProbe, itemProbe);
cvState.incrError(Math.abs(err));
cvState.incrLoss(err * err);
final float eta = eta();
for (int k = 0, size = factor; k < size; k++) {
float Pu = userProbe[k];
float Qi = itemProbe[k];
updateItemRating(items[k], Pu, Qi, err, eta);
updateUserRating(users[k], Pu, Qi, err, eta);
}
if (useBiasClause) {
updateBias(user, item, err, eta);
if (updateMeanRating) {
updateMeanRating(err, eta);
}
}
onUpdate(user, item, users, items, err);
}
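/**
 * Buffers the training example for replay by later iterations (inputBuf is
 * non-null only when `-iterations` is greater than 1). Each record occupies
 * RECORD_BYTES = 16 bytes: int user, int item, double rating. When the direct
 * buffer is full, it is spilled to the backing segment file first.
 */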
protected void beforeTrain(final long rowNum, final int user, final int item,
final double rating) throws HiveException {
if (inputBuf != null) {
assert (fileIO != null);
final ByteBuffer buf = inputBuf;
int remain = buf.remaining();
if (remain < RECORD_BYTES) {
writeBuffer(buf, fileIO, lastWritePos);
this.lastWritePos = rowNum;
}
buf.putInt(user);
buf.putInt(item);
buf.putDouble(rating);
}
}
protected void onUpdate(final int user, final int item, final Rating[] users,
final Rating[] items, final double err) throws HiveException {}
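/**
 * Predicted rating: bias(user, item) + dot(Pu, Qi), where the bias is the
 * mean rating mu, plus the user and item biases when the bias clause is enabled.
 */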
protected double predict(final int user, final int item, final float[] userProbe,
final float[] itemProbe) {
double ret = bias(user, item);
for (int k = 0, size = factor; k < size; k++) {
ret += userProbe[k] * itemProbe[k];
}
return ret;
}
protected double predict(final int user, final int item) throws HiveException {
final Rating[] users = model.getUserVector(user);
if (users == null) {
throw new HiveException("User rating is not found: " + user);
}
final Rating[] items = model.getItemVector(item);
if (items == null) {
throw new HiveException("Item rating is not found: " + item);
}
double ret = bias(user, item);
for (int k = 0, size = factor; k < size; k++) {
ret += users[k].getWeight() * items[k].getWeight();
}
return ret;
}
protected double bias(final int user, final int item) {
if (!useBiasClause) {
return model.getMeanRating();
}
return model.getMeanRating() + model.getUserBias(user) + model.getItemBias(item);
}
protected float eta() {
return 1.f; // dummy
}
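/**
 * Regularized SGD step on one item-factor weight: with grad = err * Pu -
 * lambda * Qi, sets Qi += eta * grad. The L2 penalty lambda * Qi^2 is
 * accumulated into the convergence-check loss. updateUserRating below is the
 * symmetric update for the user factor.
 */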
protected void updateItemRating(final Rating rating, final float Pu, final float Qi,
final double err, final float eta) {
double grad = err * Pu - lambda * Qi;
float newQi = Qi + (float) (eta * grad);
rating.setWeight(newQi);
cvState.incrLoss(lambda * Qi * Qi);
}
protected void updateUserRating(final Rating rating, final float Pu, final float Qi,
final double err, final float eta) {
double grad = err * Qi - lambda * Pu;
float newPu = Pu + (float) (eta * grad);
rating.setWeight(newPu);
cvState.incrLoss(lambda * Pu * Pu);
}
protected void updateMeanRating(final double err, final float eta) {
assert updateMeanRating;
float mean = model.getMeanRating();
mean += eta * err;
model.setMeanRating(mean);
}
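/**
 * Regularized SGD steps on the per-user and per-item bias terms:
 *   Bu += eta * (err - lambda * Bu)
 *   Bi += eta * (err - lambda * Bi)
 * with the corresponding L2 penalties added to the convergence-check loss.
 */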
protected void updateBias(final int user, final int item, final double err, final float eta) {
assert useBiasClause;
float Bu = model.getUserBias(user);
double Gu = err - lambda * Bu;
Bu += eta * Gu;
model.setUserBias(user, Bu);
cvState.incrLoss(lambda * Bu * Bu);
float Bi = model.getItemBias(item);
double Gi = err - lambda * Bi;
Bi += eta * Gi;
model.setItemBias(item, Bi);
cvState.incrLoss(lambda * Bi * Bi);
}
@Override
public void close() throws HiveException {
if (model != null) {
if (count == 0) {
this.model = null; // help GC
return;
}
if (iterations > 1) {
runIterativeTraining(iterations);
}
final IntWritable idx = new IntWritable();
final FloatWritable[] Pu = HiveUtils.newFloatArray(factor, 0.f);
final FloatWritable[] Qi = HiveUtils.newFloatArray(factor, 0.f);
final FloatWritable Bu = new FloatWritable();
final FloatWritable Bi = new FloatWritable();
final Object[] forwardObj;
if (updateMeanRating) {
assert useBiasClause;
float meanRating = model.getMeanRating();
FloatWritable mu = new FloatWritable(meanRating);
forwardObj = new Object[] {idx, Pu, Qi, Bu, Bi, mu};
} else {
if (useBiasClause) {
forwardObj = new Object[] {idx, Pu, Qi, Bu, Bi};
} else {
forwardObj = new Object[] {idx, Pu, Qi};
}
}
int numForwarded = 0;
for (int i = model.getMinIndex(), maxIdx = model.getMaxIndex(); i <= maxIdx; i++) {
idx.set(i);
Rating[] userRatings = model.getUserVector(i);
if (userRatings == null) {
forwardObj[1] = null;
} else {
forwardObj[1] = Pu;
copyTo(userRatings, Pu);
}
Rating[] itemRatings = model.getItemVector(i);
if (itemRatings == null) {
forwardObj[2] = null;
} else {
forwardObj[2] = Qi;
copyTo(itemRatings, Qi);
}
if (useBiasClause) {
Bu.set(model.getUserBias(i));
Bi.set(model.getItemBias(i));
}
forward(forwardObj);
numForwarded++;
}
this.model = null; // help GC
logger.info("Forwarded the prediction model of " + numForwarded + " rows. [totalErrors="
+ cvState.getTotalErrors() + ", lastLosses=" + cvState.getCumulativeLoss()
+ ", #trainingExamples=" + count + "]");
}
}
protected static void writeBuffer(@Nonnull final ByteBuffer srcBuf,
@Nonnull final NioFixedSegment dst, final long lastWritePos) throws HiveException {
// TODO asynchronous write in the background
srcBuf.flip();
try {
dst.writeRecords(lastWritePos, srcBuf);
} catch (IOException e) {
throw new HiveException("Exception causes while writing records to : " + lastWritePos,
e);
}
srcBuf.clear();
}
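/**
 * Replays the buffered training examples for the remaining (iterations - 1)
 * passes. If every example fit in the direct buffer (lastWritePos == 0), the
 * replay runs entirely in memory; otherwise the remaining buffer is flushed
 * and each pass re-reads the spilled segment file chunk by chunk. Training
 * stops early once the convergence check in cvState is satisfied.
 */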
protected final void runIterativeTraining(@Nonnegative final int iterations)
throws HiveException {
final ByteBuffer inputBuf = this.inputBuf;
final NioFixedSegment fileIO = this.fileIO;
assert (inputBuf != null);
assert (fileIO != null);
final long numTrainingExamples = count;
final Reporter reporter = getReporter();
final Counter iterCounter = (reporter == null) ? null
: reporter.getCounter("hivemall.factorization.mf.MatrixFactorization$Counter",
"iteration");
try {
if (lastWritePos == 0) {// run iterations w/o temporary file
if (inputBuf.position() == 0) {
return; // no training example
}
inputBuf.flip();
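// the first pass over the data already ran in process(), so resume from iteration 2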
for (int iter = 2; iter <= iterations; iter++) {
cvState.next();
reportProgress(reporter);
setCounterValue(iterCounter, iter);
while (inputBuf.remaining() > 0) {
int user = inputBuf.getInt();
int item = inputBuf.getInt();
double rating = inputBuf.getDouble();
// invoke train
count++;
train(user, item, rating);
}
cvState.multiplyLoss(0.5d);
if (cvState.isConverged(numTrainingExamples)) {
break;
}
inputBuf.rewind();
}
logger.info("Performed " + cvState.getCurrentIteration() + " iterations of "
+ NumberUtils.formatNumber(numTrainingExamples)
+ " training examples on memory (thus " + NumberUtils.formatNumber(count)
+ " training updates in total) ");
} else {// read training examples back from the temporary file and invoke train on each
// flush any training examples remaining in the buffer to the temporary file
if (inputBuf.position() > 0) {
writeBuffer(inputBuf, fileIO, lastWritePos);
}
try {
fileIO.flush();
} catch (IOException e) {
throw new HiveException(
"Failed to flush a file: " + fileIO.getFile().getAbsolutePath(), e);
}
if (logger.isInfoEnabled()) {
File tmpFile = fileIO.getFile();
logger.info("Wrote " + numTrainingExamples
+ " records to a temporary file for iterative training: "
+ tmpFile.getAbsolutePath() + " (" + FileUtils.prettyFileSize(tmpFile)
+ ")");
}
// run iterations
for (int iter = 2; iter <= iterations; iter++) {
cvState.next();
setCounterValue(iterCounter, iter);
inputBuf.clear();
long seekPos = 0L;
while (true) {
reportProgress(reporter);
// TODO prefetch
// read a chunk of training examples from the temporary file into the buffer
final int bytesRead;
try {
bytesRead = fileIO.read(seekPos, inputBuf);
} catch (IOException e) {
throw new HiveException(
"Failed to read a file: " + fileIO.getFile().getAbsolutePath(), e);
}
if (bytesRead == 0) { // reached file EOF
break;
}
assert (bytesRead > 0) : bytesRead;
seekPos += bytesRead;
// reads training examples from a buffer
inputBuf.flip();
int remain = inputBuf.remaining();
assert (remain > 0) : remain;
for (; remain >= RECORD_BYTES; remain -= RECORD_BYTES) {
int user = inputBuf.getInt();
int item = inputBuf.getInt();
double rating = inputBuf.getDouble();
// invoke train
count++;
train(user, item, rating);
}
inputBuf.compact();
}
cvState.multiplyLoss(0.5d);
if (cvState.isConverged(numTrainingExamples)) {
break;
}
}
logger.info("Performed " + cvState.getCurrentIteration() + " iterations of "
+ NumberUtils.formatNumber(numTrainingExamples)
+ " training examples using a secondary storage (thus "
+ NumberUtils.formatNumber(count) + " training updates in total)");
}
} finally {
// delete the temporary file and release resources
try {
fileIO.close(true);
} catch (IOException e) {
throw new HiveException(
"Failed to close a file: " + fileIO.getFile().getAbsolutePath(), e);
}
this.inputBuf = null;
this.fileIO = null;
}
}
private static void copyTo(@Nonnull final Rating[] rating, @Nonnull final FloatWritable[] dst) {
for (int k = 0, size = rating.length; k < size; k++) {
float w = rating[k].getWeight();
dst[k].set(w);
}
}
}
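// A minimal concrete subclass sketch (hypothetical, for illustration only):
// the main knob a subclass supplies is the learning rate via eta(), which the
// base class stubs out as a constant 1.0.
//
//   public final class ConstantEtaMF extends OnlineMatrixFactorizationUDTF {
//       @Override
//       protected float eta() {
//           return 0.001f; // fixed learning rate; real subclasses may decay it over time
//       }
//   }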