* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*   http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
package hivemall.recommend;
import hivemall.UDTFWithOptions;
import hivemall.annotations.VisibleForTesting;
import hivemall.common.ConversionState;
import matrix4j.matrix.FloatMatrix;
import matrix4j.matrix.sparse.floats.DoKFloatMatrix;
import matrix4j.vector.VectorProcedure;
import hivemall.utils.collections.Fastutil;
import hivemall.utils.hadoop.HiveUtils;
import hivemall.utils.lang.NumberUtils;
import hivemall.utils.lang.Primitives;
import hivemall.utils.lang.SizeOf;
import hivemall.utils.lang.mutable.MutableDouble;
import hivemall.utils.lang.mutable.MutableInt;
import hivemall.utils.lang.mutable.MutableObject;
import it.unimi.dsi.fastutil.ints.Int2FloatMap;
import it.unimi.dsi.fastutil.ints.Int2FloatOpenHashMap;
import it.unimi.dsi.fastutil.ints.Int2ObjectMap;
import it.unimi.dsi.fastutil.ints.Int2ObjectOpenHashMap;
import it.unimi.dsi.fastutil.ints.IntOpenHashSet;
import it.unimi.dsi.fastutil.ints.IntSet;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import org.apache.commons.cli.CommandLine;
import org.apache.commons.cli.Options;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.serde2.objectinspector.MapObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
import org.apache.hadoop.mapred.Counters;
import org.apache.hadoop.mapred.Reporter;
* Sparse Linear Methods (SLIM) for Top-N Recommender Systems.
* Xia Ning and George Karypis, SLIM: Sparse Linear Methods for Top-N Recommender Systems, Proc. ICDM, 2011.
@Description(name = "train_slim",
value = "_FUNC_( int i, map r_i, map> topKRatesOfI, int j, map r_j [, constant string options]) "
+ "- Returns row index, column index and non-zero weight value of prediction model")
public class SlimUDTF extends UDTFWithOptions {
// Commons-logging logger for this class.
private static final Log logger = LogFactory.getLog(SlimUDTF.class);
// input OIs
// Assigned in initialize(); process() and kNNentries() use them to decode the item ids and the
// rating maps of each input row.
private PrimitiveObjectInspector itemIOI;
private PrimitiveObjectInspector itemJOI;
private MapObjectInspector riOI;
private MapObjectInspector rjOI;
private MapObjectInspector knnItemsOI;
private PrimitiveObjectInspector knnItemsKeyOI;
private MapObjectInspector knnItemsValueOI;
private PrimitiveObjectInspector knnItemsValueKeyOI;
private PrimitiveObjectInspector knnItemsValueValueOI;
private PrimitiveObjectInspector riKeyOI;
private PrimitiveObjectInspector riValueOI;
private PrimitiveObjectInspector rjKeyOI;
private PrimitiveObjectInspector rjValueOI;
// hyperparameters
// Set by processOptions(): coefficients of the l1 and l2 regularizers and the number of
// coordinate-descent iterations.
private double l1;
private double l2;
private int numIterations;
// model parameters and else
/** item-item weight matrix */
private transient DoKFloatMatrix _weightMatrix;
// caching for each item i
// Id of the item i whose r_i and kNN_i are currently cached; initialize() resets it to
// Integer.MIN_VALUE and process() updates it whenever the cache is refreshed.
private int _previousItemId;
private transient Int2FloatMap _ri;
private transient Int2ObjectMap _kNNi;
/** The number of elements in kNNi */
private transient MutableInt _nnzKNNi;
// variables for iteration supports
/** item-user matrix holding the input data */
private transient FloatMatrix _dataMatrix;
// used to store KNN data into temporary file for iterative training
// Both are created lazily by recordTrainingInput() and cleared by runIterativeTraining().
private transient NioStatefulSegment _fileIO;
private transient ByteBuffer _inputBuf;
// Convergence tracker created by processOptions() from the disable_cv / cv_rate options.
private ConversionState _cvState;
// Reset to 0 in initialize(); passed to _cvState.isConverged() and reported in the log messages.
// NOTE(review): the statement that increments this counter is not visible in this view.
private long _observedTrainingExamples;
// No-argument constructor; does no work of its own.
public SlimUDTF() {}
/**
 * Checks the number of arguments, captures the object inspectors of the five data arguments
 * and resets the training state held in this instance.
 *
 * @param argOIs object inspectors of the call arguments; 5 or 6 are accepted
 * @return a standard struct object inspector built from the field name / field inspector lists
 * @throws UDFArgumentException if the number of arguments is neither 5 nor 6
 */
public StructObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException {
final int numArgs = argOIs.length;
if (numArgs == 1 && HiveUtils.isConstString(argOIs[0])) {// for -help option
String rawArgs = HiveUtils.getConstString(argOIs[0]);
// NOTE(review): what is done with rawArgs, and the end of the help branch, is not visible here.
if (numArgs != 5 && numArgs != 6) {
throw new UDFArgumentException(
"_FUNC_ takes 5 or 6 arguments: int i, map r_i, map> topKRatesOfI, int j, map r_j [, constant string options]: "
+ Arrays.toString(argOIs));
// arg 0: item i, arg 1: map r_i (int-compatible keys, primitive values)
this.itemIOI = HiveUtils.asIntCompatibleOI(argOIs[0]);
this.riOI = HiveUtils.asMapOI(argOIs[1]);
this.riKeyOI = HiveUtils.asIntCompatibleOI((riOI.getMapKeyObjectInspector()));
this.riValueOI = HiveUtils.asPrimitiveObjectInspector((riOI.getMapValueObjectInspector()));
// arg 2: nested map (int-compatible key -> map) holding the kNN ratings
this.knnItemsOI = HiveUtils.asMapOI(argOIs[2]);
this.knnItemsKeyOI = HiveUtils.asIntCompatibleOI(knnItemsOI.getMapKeyObjectInspector());
this.knnItemsValueOI = HiveUtils.asMapOI(knnItemsOI.getMapValueObjectInspector());
// NOTE(review): no right-hand side is visible for the next two assignments. As the text stands
// they chain into the itemJOI assignment, which would bind all three fields to the inspector of
// argOIs[3]; presumably the inner map's key/value inspectors were meant — confirm.
this.knnItemsValueKeyOI =
this.knnItemsValueValueOI =
// arg 3: item j, arg 4: map r_j
this.itemJOI = HiveUtils.asIntCompatibleOI(argOIs[3]);
this.rjOI = HiveUtils.asMapOI(argOIs[4]);
this.rjKeyOI = HiveUtils.asIntCompatibleOI((rjOI.getMapKeyObjectInspector()));
this.rjValueOI = HiveUtils.asPrimitiveObjectInspector((rjOI.getMapValueObjectInspector()));
// reset state so that process() re-creates the matrices on the first row
this._observedTrainingExamples = 0L;
this._previousItemId = Integer.MIN_VALUE;
this._weightMatrix = null;
this._dataMatrix = null;
// NOTE(review): nothing is visibly added to these lists and no call to processOptions() is
// visible in this method; presumably both happen in lines not shown — confirm.
List fieldNames = new ArrayList<>();
List fieldOIs = new ArrayList<>();
return ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs);
protected Options getOptions() {
Options opts = new Options();
opts.addOption("l1", "l1coefficient", true,
"Coefficient for l1 regularizer [default: 0.001]");
opts.addOption("l2", "l2coefficient", true,
"Coefficient for l2 regularizer [default: 0.0005]");
opts.addOption("iters", "iterations", true,
"The number of iterations for coordinate descent [default: 30]");
opts.addOption("disable_cv", "disable_cvtest", false,
"Whether to disable convergence check [default: enabled]");
opts.addOption("cv_rate", "convergence_rate", true,
"Threshold to determine convergence [default: 0.005]");
return opts;
protected CommandLine processOptions(@Nonnull ObjectInspector[] argOIs)
throws UDFArgumentException {
CommandLine cl = null;
double l1 = 0.001d;
double l2 = 0.0005d;
int numIterations = 30;
boolean conversionCheck = true;
double cv_rate = 0.005d;
if (argOIs.length >= 6) {
String rawArgs = HiveUtils.getConstString(argOIs[5]);
cl = parseOptions(rawArgs);
l1 = Primitives.parseDouble(cl.getOptionValue("l1"), l1);
if (l1 < 0.d) {
throw new UDFArgumentException("Argument `double l1` must be non-negative: " + l1);
l2 = Primitives.parseDouble(cl.getOptionValue("l2"), l2);
if (l2 < 0.d) {
throw new UDFArgumentException("Argument `double l2` must be non-negative: " + l2);
numIterations = Primitives.parseInt(cl.getOptionValue("iters"), numIterations);
if (numIterations <= 0) {
throw new UDFArgumentException(
"Argument `int iters` must be greater than 0: " + numIterations);
conversionCheck = !cl.hasOption("disable_cvtest");
cv_rate = Primitives.parseDouble(cl.getOptionValue("cv_rate"), cv_rate);
if (cv_rate <= 0) {
throw new UDFArgumentException(
"Argument `double cv_rate` must be greater than 0.0: " + cv_rate);
this.l1 = l1;
this.l2 = l2;
this.numIterations = numIterations;
this._cvState = new ConversionState(conversionCheck, cv_rate);
return cl;
/**
 * Consumes one input row and runs one training step for the item pair (i, j).
 *
 * @param args row values: args[0] item i, args[1] map r_i, args[2] kNN map of i,
 *        args[3] item j, args[4] map r_j
 * @throws HiveException declared by the method; recordTrainingInput() may throw it
 */
public void process(@Nonnull Object[] args) throws HiveException {
if (_weightMatrix == null) {// initialize variables
this._weightMatrix = new DoKFloatMatrix();
// the data matrix is only needed when later iterations replay the input
if (numIterations >= 2) {
this._dataMatrix = new DoKFloatMatrix();
this._nnzKNNi = new MutableInt();
final int itemI = PrimitiveObjectInspectorUtils.getInt(args[0], itemIOI);
// refresh the cache only when the row belongs to a different item i than the previous one
if (itemI != _previousItemId || _ri == null) {
// cache Ri and kNNi
// the previous _ri / _kNNi instances are handed in as destination maps
this._ri =
int2floatMap(itemI, riOI.getMap(args[1]), riKeyOI, riValueOI, _dataMatrix, _ri);
this._kNNi = kNNentries(args[2], knnItemsOI, knnItemsKeyOI, knnItemsValueOI,
knnItemsValueKeyOI, knnItemsValueValueOI, _kNNi, _nnzKNNi);
final int numKNNItems = _nnzKNNi.getValue();
// keep the kNN entries of item i for the later iterations
if (numIterations >= 2 && numKNNItems >= 1) {
recordTrainingInput(itemI, _kNNi, numKNNItems);
this._previousItemId = itemI;
int itemJ = PrimitiveObjectInspectorUtils.getInt(args[3], itemJOI);
// r_j is also written into _dataMatrix (when it is non-null) by int2floatMap
Int2FloatMap rj =
int2floatMap(itemJ, rjOI.getMap(args[4]), rjKeyOI, rjValueOI, _dataMatrix);
train(itemI, _ri, _kNNi, itemJ, rj);
/**
 * Records the kNN entries of item i for the iterative training passes.
 *
 * @param itemI the item whose kNN entries are recorded
 * @param knnItems kNN entries of item i, iterated as key -> (key -> float) maps
 * @param numKNNItems total number of entries in the inner maps of knnItems
 * @throws HiveException declared by the method; writeBuffer() may throw it, and the visible
 *         throws use its subclass-style UDFArgumentException for temp-file failures
 */
private void recordTrainingInput(final int itemI,
@Nonnull final Int2ObjectMap knnItems, final int numKNNItems)
throws HiveException {
ByteBuffer buf = this._inputBuf;
NioStatefulSegment dst = this._fileIO;
// first call: create the temporary file, the 8MB direct buffer and the segment lazily
if (buf == null) {
// invoke only at task node (initialize is also invoked in compilation)
final File file;
try {
file = File.createTempFile("hivemall_slim", ".sgmt"); // to save KNN data
if (!file.canWrite()) {
throw new UDFArgumentException(
"Cannot write a temporary file: " + file.getAbsolutePath());
} catch (IOException ioe) {
throw new UDFArgumentException(ioe);
this._inputBuf = buf = ByteBuffer.allocateDirect(8 * 1024 * 1024); // 8MB
this._fileIO = dst = new NioStatefulSegment(file, false);
// record size: two ints, then two ints per outer entry, then an (int, float) pair per inner
// entry; this matches the order in which replayTrain() reads itemI, knnSize, user, ruSize and
// the (itemK, ruk) pairs
int recordBytes = SizeOf.INT + SizeOf.INT + SizeOf.INT * 2 * knnItems.size()
+ (SizeOf.INT + SizeOf.FLOAT) * numKNNItems;
int requiredBytes = SizeOf.INT + recordBytes; // need to allocate space for "recordBytes" itself
int remain = buf.remaining();
// not enough room left in the buffer: hand the buffer to writeBuffer() first
if (remain < requiredBytes) {
writeBuffer(buf, dst);
// NOTE(review): the statements that put the record into buf are not visible in this view;
// only the iteration skeleton remains.
for (Int2ObjectMap.Entry e1 : Fastutil.fastIterable(knnItems)) {
int user = e1.getIntKey();
Int2FloatMap ru = e1.getValue();
for (Int2FloatMap.Entry e2 : Fastutil.fastIterable(ru)) {
/**
 * Writes the content of srcBuf to dst.
 *
 * NOTE(review): the body of the try block is not visible in this view, so the exact
 * flip/write/clear sequence cannot be confirmed here.
 *
 * @param srcBuf buffer holding the recorded training input
 * @param dst segment backed by the temporary file
 * @throws HiveException wrapping any IOException raised inside the try block
 */
private static void writeBuffer(@Nonnull final ByteBuffer srcBuf,
@Nonnull final NioStatefulSegment dst) throws HiveException {
try {
} catch (IOException e) {
throw new HiveException("Exception causes while writing a buffer to file", e);
/**
 * Updates the weight W[itemI, itemJ] from the ratings supplied with the current input row.
 *
 * @param itemI item i, row index into the weight matrix
 * @param ri ratings of item i, looked up by the user keys of rj
 * @param kNNi kNN entries of item i, passed through to predict()
 * @param itemJ item j, column index into the weight matrix
 * @param rj ratings of item j keyed by user
 */
private void train(final int itemI, @Nonnull final Int2FloatMap ri,
@Nonnull final Int2ObjectMap kNNi, final int itemJ,
@Nonnull final Int2FloatMap rj) {
final FloatMatrix W = _weightMatrix;
final int N = rj.size();
// NOTE(review): the body of this branch is not visible; presumably it returns early, since N is
// used as a divisor below — confirm.
if (N == 0) {
double gradSum = 0.d;
double rateSum = 0.d;
double lossSum = 0.d;
// accumulate over every user that has a rating for item j
for (Int2FloatMap.Entry e : Fastutil.fastIterable(rj)) {
int user = e.getIntKey();
double ruj = e.getFloatValue();
double rui = ri.get(user); // ri.getOrDefault(user, 0.f);
// error of the prediction made with item j excluded
double eui = rui - predict(user, itemI, kNNi, itemJ, W);
gradSum += ruj * eui;
rateSum += ruj * ruj;
lossSum += eui * eui;
// average the accumulated sums over the N users
gradSum /= N;
rateSum /= N;
double wij = W.get(itemI, itemJ, 0.d);
// mean squared error plus the l2 and l1 penalty terms of the current weight
// NOTE(review): no use of `loss` is visible in this view.
double loss = lossSum / N + 0.5d * l2 * wij * wij + l1 * wij;
W.set(itemI, itemJ, getUpdateTerm(gradSum, rateSum, l1, l2));
/**
 * Updates the weight W[itemI, itemJ] using the ratings stored in the data matrix instead of
 * maps taken from an input row.
 *
 * @param itemI item i, row index into the weight matrix and the data matrix
 * @param knnItems kNN entries of item i, passed through to predict()
 * @param itemJ item j, whose data-matrix row supplies the ratings
 */
private void train(final int itemI, @Nonnull final Int2ObjectMap knnItems,
final int itemJ) {
final FloatMatrix A = _dataMatrix;
final FloatMatrix W = _weightMatrix;
final int N = A.numColumns(itemJ);
// NOTE(review): the body of this branch is not visible; presumably it returns early, since N is
// used as a divisor below — confirm.
if (N == 0) {
// mutable holders so the anonymous procedure below can accumulate into them
final MutableDouble mutableGradSum = new MutableDouble(0.d);
final MutableDouble mutableRateSum = new MutableDouble(0.d);
final MutableDouble mutableLossSum = new MutableDouble(0.d);
// visit every non-zero entry of row itemJ of the data matrix
A.eachNonZeroInRow(itemJ, new VectorProcedure() {
public void apply(int user, double ruj) {
double rui = A.get(itemI, user, 0.d);
// error of the prediction made with item j excluded
double eui = rui - predict(user, itemI, knnItems, itemJ, W);
mutableGradSum.addValue(ruj * eui);
mutableRateSum.addValue(ruj * ruj);
mutableLossSum.addValue(eui * eui);
// average the accumulated sums over N
double gradSum = mutableGradSum.getValue() / N;
double rateSum = mutableRateSum.getValue() / N;
double wij = W.get(itemI, itemJ, 0.d);
// mean squared error plus the l2 and l1 penalty terms of the current weight
// NOTE(review): no use of `loss` is visible in this view.
double loss = mutableLossSum.getValue() / N + 0.5 * l2 * wij * wij + l1 * wij;
W.set(itemI, itemJ, getUpdateTerm(gradSum, rateSum, l1, l2));
/**
 * Computes a prediction for (user, itemI) as the sum of rating * weight over the user's kNN
 * entries.
 *
 * @param user key used to look up the user's entries in knnItems
 * @param itemI row index into weightMatrix
 * @param knnItems map from user to that user's (item -> rating) map
 * @param excludeIndex item id that is compared against each entry's key
 * @param weightMatrix weights, read with a default of 0 for missing cells
 * @return the accumulated prediction, or 0 when knnItems has no entry for the user
 */
private static double predict(final int user, final int itemI,
@Nonnull final Int2ObjectMap knnItems, final int excludeIndex,
@Nonnull final FloatMatrix weightMatrix) {
final Int2FloatMap kNNu = knnItems.get(user);
if (kNNu == null) {
return 0.d;
double pred = 0.d;
for (Int2FloatMap.Entry e : Fastutil.fastIterable(kNNu)) {
final int itemK = e.getIntKey();
// NOTE(review): the body of this branch is not visible; going by the parameter name it
// presumably skips the excluded item — confirm.
if (itemK == excludeIndex) {
float ruk = e.getFloatValue();
pred += ruk * weightMatrix.get(itemI, itemK, 0.d);
return pred;
private static double getUpdateTerm(final double gradSum, final double rateSum, final double l1,
final double l2) {
double update = 0.d;
if (Math.abs(gradSum) > l1) {
if (gradSum > 0.d) {
update = (gradSum - l1) / (rateSum + l2);
} else {
update = (gradSum + l1) / (rateSum + l2);
// non-negative constraints
if (update < 0.d) {
update = 0.d;
return update;
/**
 * Called when the input is exhausted; drops the reference to the weight matrix.
 *
 * NOTE(review): no call to finalizeTraining() or forwardModel() is visible in this view;
 * presumably they run before the matrix is released — confirm.
 *
 * @throws HiveException declared by the method; no throwing statement is visible here
 */
public void close() throws HiveException {
this._weightMatrix = null;
/**
 * Finishes training and releases the per-item caches and the data matrix.
 *
 * @throws HiveException declared by the method; no throwing statement is visible here
 */
void finalizeTraining() throws HiveException {
// NOTE(review): the body of this branch is not visible; presumably it starts the iterative
// training when more than one iteration was requested — confirm.
if (numIterations > 1) {
this._ri = null;
this._kNNi = null;
this._dataMatrix = null;
/**
 * Runs the training iterations after the first pass by replaying the recorded kNN entries,
 * either straight from the in-memory buffer or from the temporary file.
 *
 * NOTE(review): several statements of this method (logging calls, buffer flips/rewinds, reads,
 * loop exits) are not visible in this view; the comments below describe only what is shown.
 *
 * @throws HiveException wrapping any Throwable raised during the iterations, or an IOException
 *         raised while flushing, reading or closing the temporary file
 */
private void runIterativeTraining() throws HiveException {
final ByteBuffer buf = this._inputBuf;
final NioStatefulSegment dst = this._fileIO;
assert (buf != null);
assert (dst != null);
final Reporter reporter = getReporter();
// counter is optional: null when no reporter is available
final Counters.Counter iterCounter = (reporter == null) ? null
: reporter.getCounter("hivemall.recommend.slim$Counter", "iteration");
try {
if (dst.getPosition() == 0L) {// run iterations w/o temporary file
if (buf.position() == 0) {
return; // no training example
// NOTE(review): the loop starts at 2 and stops before numIterations, i.e. it runs
// numIterations - 2 times; confirm whether `<=` was intended.
for (int iter = 2; iter < numIterations; iter++) {;
setCounterValue(iterCounter, iter);
// each record in the buffer starts with its length in bytes
while (buf.remaining() > 0) {
int recordBytes = buf.getInt();
assert (recordBytes > 0) : recordBytes;
// NOTE(review): the body of the convergence branch is not visible.
if (_cvState.isConverged(_observedTrainingExamples)) {
}"Performed " + _cvState.getCurrentIteration() + " iterations of "
+ NumberUtils.formatNumber(_observedTrainingExamples)
+ " training examples on memory (thus "
+ NumberUtils.formatNumber(
_observedTrainingExamples * _cvState.getCurrentIteration())
+ " training updates in total) ");
} else { // read training examples in the temporary file and invoke train for each example
// write KNNi in buffer to a temporary file
if (buf.remaining() > 0) {
writeBuffer(buf, dst);
// NOTE(review): the flush call inside this try block is not visible.
try {
} catch (IOException e) {
throw new HiveException(
"Failed to flush a file: " + dst.getFile().getAbsolutePath(), e);
if (logger.isInfoEnabled()) {
File tmpFile = dst.getFile();
"Wrote KNN entries of axis items to a temporary file for iterative training: "
+ tmpFile.getAbsolutePath() + " ("
+ FileUtils.prettyFileSize(tmpFile) + ")");
// run iterations
// NOTE(review): same loop bounds as the in-memory branch above; confirm whether `<=` was
// intended.
for (int iter = 2; iter < numIterations; iter++) {;
setCounterValue(iterCounter, iter);
while (true) {
// load a KNNi to a buffer in the temporary file
final int bytesRead;
try {
// NOTE(review): the read expression on the next line is not visible.
bytesRead =;
} catch (IOException e) {
throw new HiveException(
"Failed to read a file: " + dst.getFile().getAbsolutePath(), e);
if (bytesRead == 0) { // reached file EOF
assert (bytesRead > 0) : bytesRead;
// reads training examples from a buffer
int remain = buf.remaining();
// a chunk must contain at least the length prefix of one record
if (remain < SizeOf.INT) {
throw new HiveException("Illegal file format was detected");
while (remain >= SizeOf.INT) {
int pos = buf.position();
int recordBytes = buf.getInt();
remain -= SizeOf.INT;
// NOTE(review): the body of this branch is not visible; `pos` is presumably used to step
// back to the start of a record that is only partly in the buffer — confirm.
if (remain < recordBytes) {
remain -= recordBytes;
// NOTE(review): the body of the convergence branch is not visible.
if (_cvState.isConverged(_observedTrainingExamples)) {
}"Performed " + _cvState.getCurrentIteration() + " iterations of "
+ NumberUtils.formatNumber(_observedTrainingExamples)
+ " training examples on memory and KNNi data on secondary storage (thus "
+ NumberUtils.formatNumber(
_observedTrainingExamples * _cvState.getCurrentIteration())
+ " training updates in total) ");
} catch (Throwable e) {
throw new HiveException("Exception caused in the iterative training", e);
} finally {
// delete the temporary file and release resources
// NOTE(review): the close call inside this try block is not visible.
try {
} catch (IOException e) {
throw new HiveException(
"Failed to close a file: " + dst.getFile().getAbsolutePath(), e);
this._inputBuf = null;
this._fileIO = null;
/**
 * Reads one recorded kNN record from the buffer, rebuilds the user -> (item -> rating) map
 * and calls train(itemI, knnItems, itemJ) for each item in pairItems.
 *
 * @param buf buffer positioned at the start of a record body (after its length prefix)
 */
private void replayTrain(@Nonnull final ByteBuffer buf) {
// record layout as read here: itemI, number of users, then per user: user id, entry count
// and that many (int item, float rating) pairs
final int itemI = buf.getInt();
final int knnSize = buf.getInt();
final Int2ObjectMap knnItems = new Int2ObjectOpenHashMap<>(1024);
final IntSet pairItems = new IntOpenHashSet();
for (int i = 0; i < knnSize; i++) {
int user = buf.getInt();
int ruSize = buf.getInt();
Int2FloatMap ru = new Int2FloatOpenHashMap(ruSize);
for (int j = 0; j < ruSize; j++) {
int itemK = buf.getInt();
float ruk = buf.getFloat();
ru.put(itemK, ruk);
knnItems.put(user, ru);
// NOTE(review): nothing is visibly added to pairItems, so as shown this loop never runs;
// presumably the itemK values read above are collected into it — confirm.
for (int itemJ : pairItems) {
train(itemI, knnItems, itemJ);
/**
 * Walks the non-zero cells of the weight matrix to emit the model.
 *
 * NOTE(review): the statements that fill f0/f1/f2, call forward() and record a caught exception
 * are not visible in this view.
 *
 * @throws HiveException the exception held in `catched`, if one is present after the walk
 */
private void forwardModel() throws HiveException {
// writable holders are allocated once and placed in the row array
final IntWritable f0 = new IntWritable(); // i
final IntWritable f1 = new IntWritable(); // nn
final FloatWritable f2 = new FloatWritable(); // w
final Object[] forwardObj = new Object[] {f0, f1, f2};
// holder for an exception raised inside the callback, rethrown after the walk
final MutableObject catched = new MutableObject<>();
_weightMatrix.eachNonZeroCell(new VectorProcedure() {
public void apply(int i, int j, float value) {
// NOTE(review): the body of this branch is not visible; presumably zero weights are skipped.
if (value == 0.f) {
try {
} catch (HiveException e) {
HiveException ex = catched.get();
if (ex != null) {
throw ex;
}"Forwarded SLIM's weights matrix");
/**
 * Converts the Hive nested-map argument holding the kNN ratings into an
 * int -> (int -> float) map.
 *
 * @param kNNiObj the raw Hive object of the nested map
 * @param knnItemsOI inspector of the outer map
 * @param knnItemsKeyOI inspector of the outer map's keys
 * @param knnItemsValueOI inspector of the outer map's values (the inner maps)
 * @param knnItemsValueKeyOI inspector of the inner maps' keys
 * @param knnItemsValueValueOI inspector of the inner maps' values
 * @param knnItems destination map to use, or null to allocate a new one
 * @param nnzKNNi holder for the number of inner entries
 * @return the destination map
 */
private static Int2ObjectMap kNNentries(@Nonnull final Object kNNiObj,
@Nonnull final MapObjectInspector knnItemsOI,
@Nonnull final PrimitiveObjectInspector knnItemsKeyOI,
@Nonnull final MapObjectInspector knnItemsValueOI,
@Nonnull final PrimitiveObjectInspector knnItemsValueKeyOI,
@Nonnull final PrimitiveObjectInspector knnItemsValueValueOI,
@Nullable Int2ObjectMap knnItems, @Nonnull final MutableInt nnzKNNi) {
if (knnItems == null) {
knnItems = new Int2ObjectOpenHashMap<>(1024);
// NOTE(review): the body of the else branch is not visible; presumably a reused map is
// cleared before refilling — confirm.
} else {
// running total of the inner-map sizes
int numElementOfKNNItems = 0;
for (Map.Entry, ?> entry : knnItemsOI.getMap(kNNiObj).entrySet()) {
int user = PrimitiveObjectInspectorUtils.getInt(entry.getKey(), knnItemsKeyOI);
Int2FloatMap ru = int2floatMap(knnItemsValueOI.getMap(entry.getValue()),
knnItemsValueKeyOI, knnItemsValueValueOI);
knnItems.put(user, ru);
numElementOfKNNItems += ru.size();
// NOTE(review): no statement storing numElementOfKNNItems into nnzKNNi is visible here.
return knnItems;
/**
 * Converts a Hive map into a new Int2FloatMap, decoding keys as int and values as float.
 *
 * @param map the Hive map to convert
 * @param keyOI inspector used to read each key as an int
 * @param valueOI inspector used to read each value as a float
 * @return a new map sized from map.size()
 */
private static Int2FloatMap int2floatMap(@Nonnull final Map, ?> map,
@Nonnull final PrimitiveObjectInspector keyOI,
@Nonnull final PrimitiveObjectInspector valueOI) {
final Int2FloatMap result = new Int2FloatOpenHashMap(map.size());
for (Map.Entry, ?> entry : map.entrySet()) {
float v = PrimitiveObjectInspectorUtils.getFloat(entry.getValue(), valueOI);
// NOTE(review): the body of this branch is not visible; presumably zero values are skipped.
if (v == 0.f) {
int k = PrimitiveObjectInspectorUtils.getInt(entry.getKey(), keyOI);
result.put(k, v);
return result;
private static Int2FloatMap int2floatMap(final int item, @Nonnull final Map, ?> map,
@Nonnull final PrimitiveObjectInspector keyOI,
@Nonnull final PrimitiveObjectInspector valueOI,
@Nullable final FloatMatrix dataMatrix) {
return int2floatMap(item, map, keyOI, valueOI, dataMatrix, null);
/**
 * Converts a Hive rating map into an Int2FloatMap and, when a data matrix is given, also
 * writes each entry into row `item` of that matrix.
 *
 * @param item item id, used as the row index when writing into dataMatrix
 * @param map the Hive map to convert
 * @param keyOI inspector used to read each key as an int
 * @param valueOI inspector used to read each value as a float
 * @param dataMatrix matrix that also receives the entries, or null to skip that
 * @param dst destination map to use, or null to allocate a new one sized from map.size()
 * @return the destination map
 */
private static Int2FloatMap int2floatMap(final int item, @Nonnull final Map, ?> map,
@Nonnull final PrimitiveObjectInspector keyOI,
@Nonnull final PrimitiveObjectInspector valueOI, @Nullable final FloatMatrix dataMatrix,
@Nullable Int2FloatMap dst) {
if (dst == null) {
dst = new Int2FloatOpenHashMap(map.size());
// NOTE(review): the body of the else branch is not visible; presumably a reused map is
// cleared before refilling — confirm.
} else {
for (Map.Entry, ?> entry : map.entrySet()) {
float rating = PrimitiveObjectInspectorUtils.getFloat(entry.getValue(), valueOI);
// NOTE(review): the body of this branch is not visible; presumably zero ratings are skipped.
if (rating == 0.f) {
int user = PrimitiveObjectInspectorUtils.getInt(entry.getKey(), keyOI);
dst.put(user, rating);
if (dataMatrix != null) {
dataMatrix.set(item, user, rating);
return dst;
