// hivemall.recommend.SlimUDTF Maven / Gradle / Ivy
// The newest version!
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package hivemall.recommend;
import hivemall.UDTFWithOptions;
import hivemall.annotations.VisibleForTesting;
import hivemall.common.ConversionState;
import matrix4j.matrix.FloatMatrix;
import matrix4j.matrix.sparse.floats.DoKFloatMatrix;
import matrix4j.vector.VectorProcedure;
import hivemall.utils.collections.Fastutil;
import hivemall.utils.hadoop.HiveUtils;
import hivemall.utils.io.FileUtils;
import hivemall.utils.io.NioStatefulSegment;
import hivemall.utils.lang.NumberUtils;
import hivemall.utils.lang.Primitives;
import hivemall.utils.lang.SizeOf;
import hivemall.utils.lang.mutable.MutableDouble;
import hivemall.utils.lang.mutable.MutableInt;
import hivemall.utils.lang.mutable.MutableObject;
import it.unimi.dsi.fastutil.ints.Int2FloatMap;
import it.unimi.dsi.fastutil.ints.Int2FloatOpenHashMap;
import it.unimi.dsi.fastutil.ints.Int2ObjectMap;
import it.unimi.dsi.fastutil.ints.Int2ObjectOpenHashMap;
import it.unimi.dsi.fastutil.ints.IntOpenHashSet;
import it.unimi.dsi.fastutil.ints.IntSet;
import java.io.File;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import org.apache.commons.cli.CommandLine;
import org.apache.commons.cli.Options;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.serde2.objectinspector.MapObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
import org.apache.hadoop.io.FloatWritable;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.mapred.Counters;
import org.apache.hadoop.mapred.Reporter;
/**
 * Sparse Linear Methods (SLIM) for Top-N Recommender Systems.
 *
 * <pre>
 * Xia Ning and George Karypis, SLIM: Sparse Linear Methods for Top-N Recommender Systems, Proc. ICDM, 2011.
 * </pre>
 *
 * <p>
 * The UDTF consumes one row per (item i, item j) pair and, on {@link #close()}, forwards every
 * non-zero cell of the learned item-item weight matrix.
 */
@Description(name = "train_slim",
        value = "_FUNC_( int i, map<int, double> r_i, map<int, map<int, double>> topKRatesOfI, int j, map<int, double> r_j [, constant string options]) "
                + "- Returns row index, column index and non-zero weight value of prediction model")
public class SlimUDTF extends UDTFWithOptions {
private static final Log logger = LogFactory.getLog(SlimUDTF.class);

//--------------------------------------------
// input OIs (all assigned in initialize())

/** Inspector for argument 0, the item id {@code i}. */
private PrimitiveObjectInspector itemIOI;
/** Inspector for argument 3, the item id {@code j}. */
private PrimitiveObjectInspector itemJOI;
/** Inspector for argument 1, the rating map of item {@code i}. */
private MapObjectInspector riOI;
/** Inspector for argument 4, the rating map of item {@code j}. */
private MapObjectInspector rjOI;
/** Inspector for argument 2, the nested map holding the kNN ratings of item {@code i}. */
private MapObjectInspector knnItemsOI;
/** Key inspector of {@link #knnItemsOI} (read as the user id). */
private PrimitiveObjectInspector knnItemsKeyOI;
/** Value inspector of {@link #knnItemsOI}; each value is itself a map. */
private MapObjectInspector knnItemsValueOI;
/** Key inspector of the inner kNN map (read as an item id). */
private PrimitiveObjectInspector knnItemsValueKeyOI;
/** Value inspector of the inner kNN map (read as a rating). */
private PrimitiveObjectInspector knnItemsValueValueOI;
/** Key inspector of {@link #riOI}. */
private PrimitiveObjectInspector riKeyOI;
/** Value inspector of {@link #riOI}. */
private PrimitiveObjectInspector riValueOI;
/** Key inspector of {@link #rjOI}. */
private PrimitiveObjectInspector rjKeyOI;
/** Value inspector of {@link #rjOI}. */
private PrimitiveObjectInspector rjValueOI;

//--------------------------------------------
// hyperparameters (assigned in processOptions())

/** Coefficient of the L1 regularizer. */
private double l1;
/** Coefficient of the L2 regularizer. */
private double l2;
/** Requested number of training iterations; values of 2 or more enable replay training. */
private int numIterations;

//--------------------------------------------
// model parameters and else

/** item-item weight matrix */
private transient DoKFloatMatrix _weightMatrix;

//--------------------------------------------
// caching for each item i

/** Item id seen by the previous process() call; used to detect when the cache is stale. */
private int _previousItemId;
/** Cached ratings of the current item i, keyed by user id. */
@Nullable
private transient Int2FloatMap _ri;
/** Cached kNN ratings of the current item i: user id -> (item id -> rating). */
@Nullable
private transient Int2ObjectMap<Int2FloatMap> _kNNi;
/** The number of elements in kNNi */
@Nullable
private transient MutableInt _nnzKNNi;

//--------------------------------------------
// variables for iteration supports

/** item-user matrix holding the input data */
@Nullable
private transient FloatMatrix _dataMatrix;
// used to store KNN data into temporary file for iterative training
private transient NioStatefulSegment _fileIO;
/** Staging buffer for the serialized kNN records; spilled to {@link #_fileIO} when full. */
private transient ByteBuffer _inputBuf;
/** Accumulates the loss and decides convergence between iterations. */
private ConversionState _cvState;
/** Number of rows handed to process() so far. */
private long _observedTrainingExamples;

//--------------------------------------------

public SlimUDTF() {}
/**
 * Validates the argument inspectors, parses the optional option string and declares the
 * output schema.
 *
 * @param argOIs inspectors for {@code (i, r_i, topKRatesOfI, j, r_j [, options])}; a single
 *        constant string is also accepted so that the option parser can handle it (for -help)
 * @return a struct inspector with the columns {@code j} (int), {@code nn} (int) and
 *         {@code w} (float)
 * @throws UDFArgumentException if the number of arguments is neither 5 nor 6, or if an
 *         argument or option is rejected
 */
@Override
public StructObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException {
    final int numArgs = argOIs.length;
    if (numArgs == 1 && HiveUtils.isConstString(argOIs[0])) {// for -help option
        String rawArgs = HiveUtils.getConstString(argOIs[0]);
        parseOptions(rawArgs);
    }
    if (numArgs != 5 && numArgs != 6) {
        throw new UDFArgumentException(
            "_FUNC_ takes 5 or 6 arguments: int i, map<int, double> r_i, map<int, map<int, double>> topKRatesOfI, int j, map<int, double> r_j [, constant string options]: "
                    + Arrays.toString(argOIs));
    }

    // argument 0/1: item i and its ratings
    this.itemIOI = HiveUtils.asIntCompatibleOI(argOIs[0]);
    this.riOI = HiveUtils.asMapOI(argOIs[1]);
    this.riKeyOI = HiveUtils.asIntCompatibleOI((riOI.getMapKeyObjectInspector()));
    this.riValueOI = HiveUtils.asPrimitiveObjectInspector((riOI.getMapValueObjectInspector()));

    // argument 2: nested map of kNN ratings
    this.knnItemsOI = HiveUtils.asMapOI(argOIs[2]);
    this.knnItemsKeyOI = HiveUtils.asIntCompatibleOI(knnItemsOI.getMapKeyObjectInspector());
    this.knnItemsValueOI = HiveUtils.asMapOI(knnItemsOI.getMapValueObjectInspector());
    this.knnItemsValueKeyOI =
            HiveUtils.asIntCompatibleOI(knnItemsValueOI.getMapKeyObjectInspector());
    this.knnItemsValueValueOI =
            HiveUtils.asDoubleCompatibleOI(knnItemsValueOI.getMapValueObjectInspector());

    // argument 3/4: item j and its ratings
    this.itemJOI = HiveUtils.asIntCompatibleOI(argOIs[3]);
    this.rjOI = HiveUtils.asMapOI(argOIs[4]);
    this.rjKeyOI = HiveUtils.asIntCompatibleOI((rjOI.getMapKeyObjectInspector()));
    this.rjValueOI = HiveUtils.asPrimitiveObjectInspector((rjOI.getMapValueObjectInspector()));

    processOptions(argOIs);

    // reset the training state; the matrices are allocated lazily in process()
    this._observedTrainingExamples = 0L;
    this._previousItemId = Integer.MIN_VALUE;
    this._weightMatrix = null;
    this._dataMatrix = null;

    final List<String> fieldNames = new ArrayList<>(3);
    final List<ObjectInspector> fieldOIs = new ArrayList<>(3);
    fieldNames.add("j");
    fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector);
    fieldNames.add("nn");
    fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector);
    fieldNames.add("w");
    fieldOIs.add(PrimitiveObjectInspectorFactory.writableFloatObjectInspector);
    return ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs);
}
/**
 * Declares the options understood in the optional trailing option string.
 *
 * @return the option definitions for l1/l2 regularization, the iteration count and the
 *         convergence check
 */
@Override
protected Options getOptions() {
    return new Options()
        .addOption("l1", "l1coefficient", true,
            "Coefficient for l1 regularizer [default: 0.001]")
        .addOption("l2", "l2coefficient", true,
            "Coefficient for l2 regularizer [default: 0.0005]")
        .addOption("iters", "iterations", true,
            "The number of iterations for coordinate descent [default: 30]")
        .addOption("disable_cv", "disable_cvtest", false,
            "Whether to disable convergence check [default: enabled]")
        .addOption("cv_rate", "convergence_rate", true,
            "Threshold to determine convergence [default: 0.005]");
}
/**
 * Reads the hyperparameters from the sixth argument when it is present, validates them and
 * stores them on this instance. The defaults are kept when no option string is given.
 *
 * @param argOIs the argument inspectors passed to initialize()
 * @return the parsed command line, or {@code null} when there is no option argument
 * @throws UDFArgumentException if a supplied value is out of range
 */
@Override
protected CommandLine processOptions(@Nonnull ObjectInspector[] argOIs)
        throws UDFArgumentException {
    // defaults
    double l1Coef = 0.001d;
    double l2Coef = 0.0005d;
    int iters = 30;
    boolean checkConvergence = true;
    double cvRate = 0.005d;

    CommandLine parsed = null;
    if (argOIs.length >= 6) {
        parsed = parseOptions(HiveUtils.getConstString(argOIs[5]));

        l1Coef = Primitives.parseDouble(parsed.getOptionValue("l1"), l1Coef);
        if (l1Coef < 0.d) {
            throw new UDFArgumentException(
                "Argument `double l1` must be non-negative: " + l1Coef);
        }

        l2Coef = Primitives.parseDouble(parsed.getOptionValue("l2"), l2Coef);
        if (l2Coef < 0.d) {
            throw new UDFArgumentException(
                "Argument `double l2` must be non-negative: " + l2Coef);
        }

        iters = Primitives.parseInt(parsed.getOptionValue("iters"), iters);
        if (iters <= 0) {
            throw new UDFArgumentException(
                "Argument `int iters` must be greater than 0: " + iters);
        }

        checkConvergence = !parsed.hasOption("disable_cvtest");

        cvRate = Primitives.parseDouble(parsed.getOptionValue("cv_rate"), cvRate);
        if (cvRate <= 0) {
            throw new UDFArgumentException(
                "Argument `double cv_rate` must be greater than 0.0: " + cvRate);
        }
    }

    // publish only after every value has passed validation
    this.l1 = l1Coef;
    this.l2 = l2Coef;
    this.numIterations = iters;
    this._cvState = new ConversionState(checkConvergence, cvRate);
    return parsed;
}
/**
 * Consumes one input row {@code (i, r_i, topKRatesOfI, j, r_j)} and performs a single online
 * update of the weight for the pair (i, j).
 *
 * <p>
 * The ratings and kNN entries of item i are re-read only when i differs from the previous
 * row (or nothing is cached yet). When two or more iterations are requested and the kNN map
 * is non-empty, that kNN data is also recorded for later replay.
 *
 * @param args the row values, in the order checked by initialize()
 * @throws HiveException if recording the training input fails
 */
@Override
public void process(@Nonnull Object[] args) throws HiveException {
    // first row: allocate the model state
    if (_weightMatrix == null) {
        this._weightMatrix = new DoKFloatMatrix();
        if (numIterations >= 2) {
            this._dataMatrix = new DoKFloatMatrix();
        }
        this._nnzKNNi = new MutableInt();
    }

    final int axisItem = PrimitiveObjectInspectorUtils.getInt(args[0], itemIOI);
    final boolean cacheIsStale = (_ri == null) || (axisItem != _previousItemId);
    if (cacheIsStale) {
        // cache Ri and kNNi
        this._ri = int2floatMap(axisItem, riOI.getMap(args[1]), riKeyOI, riValueOI,
            _dataMatrix, _ri);
        this._kNNi = kNNentries(args[2], knnItemsOI, knnItemsKeyOI, knnItemsValueOI,
            knnItemsValueKeyOI, knnItemsValueValueOI, _kNNi, _nnzKNNi);

        final int nnz = _nnzKNNi.getValue();
        if (nnz >= 1 && numIterations >= 2) {
            recordTrainingInput(axisItem, _kNNi, nnz);
        }
        this._previousItemId = axisItem;
    }

    final int pairedItem = PrimitiveObjectInspectorUtils.getInt(args[3], itemJOI);
    final Int2FloatMap pairedRatings =
            int2floatMap(pairedItem, rjOI.getMap(args[4]), rjKeyOI, rjValueOI, _dataMatrix);

    train(axisItem, _ri, _kNNi, pairedItem, pairedRatings);
    _observedTrainingExamples++;
}
/**
 * Serializes the kNN entries of item i into the staging buffer so that they can be replayed
 * in later iterations. The buffer is spilled to a temporary file when it cannot hold the
 * next record.
 *
 * <p>
 * Record layout: {@code recordBytes, itemI, #users, (user, #ratings, (item, rating)*)*}.
 *
 * @param itemI the item whose kNN data is recorded
 * @param knnItems user id -> (item id -> rating)
 * @param numKNNItems total number of (item, rating) entries across all users
 * @throws HiveException if the temporary file cannot be created or written, or if the record
 *         is too large to be addressed by a buffer
 */
private void recordTrainingInput(final int itemI,
        @Nonnull final Int2ObjectMap<Int2FloatMap> knnItems, final int numKNNItems)
        throws HiveException {
    ByteBuffer buf = this._inputBuf;
    NioStatefulSegment dst = this._fileIO;
    if (buf == null) {
        // invoke only at task node (initialize is also invoked in compilation)
        final File file;
        try {
            file = File.createTempFile("hivemall_slim", ".sgmt"); // to save KNN data
            file.deleteOnExit();
            if (!file.canWrite()) {
                throw new UDFArgumentException(
                    "Cannot write a temporary file: " + file.getAbsolutePath());
            }
        } catch (IOException ioe) {
            throw new UDFArgumentException(ioe);
        }
        this._inputBuf = buf = ByteBuffer.allocateDirect(8 * 1024 * 1024); // 8MB
        this._fileIO = dst = new NioStatefulSegment(file, false);
    }

    // compute sizes in long so that a very large neighbourhood cannot silently overflow int
    final long recordBytesL = (long) SizeOf.INT + SizeOf.INT
            + (long) SizeOf.INT * 2L * knnItems.size()
            + (long) (SizeOf.INT + SizeOf.FLOAT) * numKNNItems;
    final long requiredBytesL = SizeOf.INT + recordBytesL; // need to allocate space for "recordBytes" itself
    if (requiredBytesL > Integer.MAX_VALUE) {
        throw new HiveException("KNN record of item " + itemI + " is too large to buffer: "
                + requiredBytesL + " bytes");
    }
    final int recordBytes = (int) recordBytesL;
    final int requiredBytes = (int) requiredBytesL;

    if (buf.remaining() < requiredBytes) {
        writeBuffer(buf, dst); // leaves buf empty
        if (buf.capacity() < requiredBytes) {
            // A single record exceeds the staging buffer. Replace the (now empty) buffer by one
            // that fits instead of failing with BufferOverflowException; the replay phase reads
            // through the same _inputBuf, so it will be able to hold this record as well.
            this._inputBuf = buf = ByteBuffer.allocateDirect(requiredBytes);
        }
    }

    buf.putInt(recordBytes);
    buf.putInt(itemI);
    buf.putInt(knnItems.size());
    for (Int2ObjectMap.Entry<Int2FloatMap> e1 : Fastutil.fastIterable(knnItems)) {
        int user = e1.getIntKey();
        buf.putInt(user);
        Int2FloatMap ru = e1.getValue();
        buf.putInt(ru.size());
        for (Int2FloatMap.Entry e2 : Fastutil.fastIterable(ru)) {
            buf.putInt(e2.getIntKey());
            buf.putFloat(e2.getFloatValue());
        }
    }
}
/**
 * Flushes the filled part of {@code srcBuf} to {@code dst} and empties the buffer for reuse.
 *
 * @param srcBuf buffer in write mode; it is flipped, written and then cleared
 * @param dst the segment to append to
 * @throws HiveException if the underlying write fails
 */
private static void writeBuffer(@Nonnull final ByteBuffer srcBuf,
        @Nonnull final NioStatefulSegment dst) throws HiveException {
    try {
        srcBuf.flip();
        dst.write(srcBuf);
    } catch (IOException ioe) {
        throw new HiveException("Exception causes while writing a buffer to file", ioe);
    }
    srcBuf.clear();
}
/**
 * Online update of the weight at (itemI, itemJ) from the ratings given in the current row.
 *
 * <p>
 * For every user that rated item j, the residual between the user's rating of item i and the
 * prediction made without item j is accumulated. The averaged sums are turned into the new
 * weight by {@link #getUpdateTerm(double, double, double, double)}, and the regularized loss
 * computed with the weight before the update is reported to the convergence state.
 *
 * @param itemI row index of the weight to update
 * @param ri ratings of item i keyed by user id; the value for an absent user is whatever
 *        {@code ri.get(user)} returns
 * @param kNNi user id -> (item id -> rating) used for prediction
 * @param itemJ column index of the weight to update
 * @param rj ratings of item j keyed by user id; nothing happens when it is empty
 */
private void train(final int itemI, @Nonnull final Int2FloatMap ri,
        @Nonnull final Int2ObjectMap<Int2FloatMap> kNNi, final int itemJ,
        @Nonnull final Int2FloatMap rj) {
    final FloatMatrix W = _weightMatrix;
    final int N = rj.size();
    if (N == 0) {
        return;
    }

    double gradSum = 0.d;
    double rateSum = 0.d;
    double lossSum = 0.d;
    for (Int2FloatMap.Entry e : Fastutil.fastIterable(rj)) {
        int user = e.getIntKey();
        double ruj = e.getFloatValue();
        double rui = ri.get(user); // ri.getOrDefault(user, 0.f);
        double eui = rui - predict(user, itemI, kNNi, itemJ, W);
        gradSum += ruj * eui;
        rateSum += ruj * ruj;
        lossSum += eui * eui;
    }
    gradSum /= N;
    rateSum /= N;

    double wij = W.get(itemI, itemJ, 0.d);
    double loss = lossSum / N + 0.5d * l2 * wij * wij + l1 * wij;
    _cvState.incrLoss(loss);

    W.set(itemI, itemJ, getUpdateTerm(gradSum, rateSum, l1, l2));
}
/**
 * Replay update of the weight at (itemI, itemJ); the counterpart of the online variant that
 * reads the ratings from the item-user data matrix instead of the input row.
 *
 * @param itemI row index of the weight to update
 * @param knnItems user id -> (item id -> rating) used for prediction
 * @param itemJ column index of the weight to update; nothing happens when
 *        {@code _dataMatrix.numColumns(itemJ)} is 0
 */
private void train(final int itemI, @Nonnull final Int2ObjectMap<Int2FloatMap> knnItems,
        final int itemJ) {
    final FloatMatrix A = _dataMatrix;
    final FloatMatrix W = _weightMatrix;
    // presumably the number of stored ratings in row itemJ -- confirm against FloatMatrix
    final int N = A.numColumns(itemJ);
    if (N == 0) {
        return;
    }

    // mutable holders because the sums are updated from inside the anonymous callback
    final MutableDouble mutableGradSum = new MutableDouble(0.d);
    final MutableDouble mutableRateSum = new MutableDouble(0.d);
    final MutableDouble mutableLossSum = new MutableDouble(0.d);
    A.eachNonZeroInRow(itemJ, new VectorProcedure() {
        @Override
        public void apply(int user, double ruj) {
            double rui = A.get(itemI, user, 0.d);
            double eui = rui - predict(user, itemI, knnItems, itemJ, W);
            mutableGradSum.addValue(ruj * eui);
            mutableRateSum.addValue(ruj * ruj);
            mutableLossSum.addValue(eui * eui);
        }
    });

    double gradSum = mutableGradSum.getValue() / N;
    double rateSum = mutableRateSum.getValue() / N;
    double wij = W.get(itemI, itemJ, 0.d);
    double loss = mutableLossSum.getValue() / N + 0.5 * l2 * wij * wij + l1 * wij;
    _cvState.incrLoss(loss);

    W.set(itemI, itemJ, getUpdateTerm(gradSum, rateSum, l1, l2));
}
/**
 * Predicts the rating of {@code user} for {@code itemI} as the weighted sum of the user's
 * ratings of the kNN items, leaving out {@code excludeIndex}.
 *
 * @param user the user id to look up in {@code knnItems}
 * @param itemI row of the weight matrix to read weights from
 * @param knnItems user id -> (item id -> rating)
 * @param excludeIndex item id whose contribution is skipped
 * @param weightMatrix the item-item weights; missing cells are read as 0
 * @return the prediction, or 0 when the user has no entry in {@code knnItems}
 */
private static double predict(final int user, final int itemI,
        @Nonnull final Int2ObjectMap<Int2FloatMap> knnItems, final int excludeIndex,
        @Nonnull final FloatMatrix weightMatrix) {
    final Int2FloatMap kNNu = knnItems.get(user);
    if (kNNu == null) {
        return 0.d;
    }

    double pred = 0.d;
    for (Int2FloatMap.Entry e : Fastutil.fastIterable(kNNu)) {
        final int itemK = e.getIntKey();
        if (itemK == excludeIndex) {
            continue;
        }
        float ruk = e.getFloatValue();
        pred += ruk * weightMatrix.get(itemI, itemK, 0.d);
    }
    return pred;
}
/**
 * Computes the new weight from the averaged gradient and rating sums.
 *
 * <p>
 * The gradient is shrunk towards zero by {@code l1} and divided by {@code rateSum + l2}.
 * A gradient whose magnitude does not exceed {@code l1}, or a negative quotient, yields 0.
 *
 * @param gradSum averaged gradient term
 * @param rateSum averaged squared-rating term
 * @param l1 L1 coefficient (shrinkage threshold)
 * @param l2 L2 coefficient (added to the denominator)
 * @return the updated weight
 */
private static double getUpdateTerm(final double gradSum, final double rateSum, final double l1,
        final double l2) {
    // written as a negation so that a NaN gradient also falls through to 0
    if (!(Math.abs(gradSum) > l1)) {
        return 0.d;
    }
    final double shrunk = (gradSum > 0.d) ? (gradSum - l1) : (gradSum + l1);
    final double candidate = shrunk / (rateSum + l2);
    // non-negative constraints
    return (candidate < 0.d) ? 0.d : candidate;
}
/**
 * Finishes training and forwards the learned weights.
 *
 * <p>
 * When no row was ever processed the weight matrix has not been allocated, so there is
 * nothing to train or forward and the method returns immediately instead of failing with a
 * NullPointerException.
 *
 * @throws HiveException if iterative training or forwarding fails
 */
@Override
public void close() throws HiveException {
    if (_weightMatrix == null) {
        // process() was never invoked for this task
        return;
    }
    finalizeTraining();
    forwardModel();
    this._weightMatrix = null;
}
/**
 * Runs the replay iterations when more than one iteration was requested, dropping the
 * per-item caches beforehand and the data matrix afterwards.
 *
 * @throws HiveException if the iterative training fails
 */
@VisibleForTesting
void finalizeTraining() throws HiveException {
    if (numIterations <= 1) {
        return; // single pass only: nothing to replay
    }

    // the per-item caches are not needed any more
    this._ri = null;
    this._kNNi = null;

    runIterativeTraining();

    this._dataMatrix = null;
}
/**
 * Replays the recorded kNN data for iterations 2..numIterations (inclusive), stopping early
 * once the convergence state reports convergence.
 *
 * <p>
 * When nothing was spilled to the temporary file the records are replayed straight from the
 * staging buffer; otherwise the remaining buffer content is flushed and every iteration
 * streams the file back through the buffer. The temporary file is closed and the buffer
 * references are dropped in all cases.
 *
 * @throws HiveException if reading, flushing or closing the temporary file fails, or if any
 *         other error occurs while training
 */
private void runIterativeTraining() throws HiveException {
    final ByteBuffer buf = this._inputBuf;
    final NioStatefulSegment dst = this._fileIO;
    if (buf == null || dst == null) {
        // recordTrainingInput() was never invoked (no row, or no row with kNN entries),
        // so there is nothing to replay and nothing to release
        return;
    }

    final Reporter reporter = getReporter();
    final Counters.Counter iterCounter = (reporter == null) ? null
            : reporter.getCounter("hivemall.recommend.slim$Counter", "iteration");

    try {
        if (dst.getPosition() == 0L) {// run iterations w/o temporary file
            if (buf.position() == 0) {
                return; // no training example
            }
            buf.flip();

            // iteration 1 was the online pass in process(); run up to and including
            // numIterations so that "-iters n" really performs n passes
            for (int iter = 2; iter <= numIterations; iter++) {
                _cvState.next();
                reportProgress(reporter);
                setCounterValue(iterCounter, iter);

                while (buf.remaining() > 0) {
                    int recordBytes = buf.getInt();
                    assert (recordBytes > 0) : recordBytes;
                    replayTrain(buf);
                }
                buf.rewind();
                if (_cvState.isConverged(_observedTrainingExamples)) {
                    break;
                }
            }
            logger.info("Performed " + _cvState.getCurrentIteration() + " iterations of "
                    + NumberUtils.formatNumber(_observedTrainingExamples)
                    + " training examples on memory (thus "
                    + NumberUtils.formatNumber(
                        _observedTrainingExamples * _cvState.getCurrentIteration())
                    + " training updates in total) ");
        } else { // read training examples in the temporary file and invoke train for each example
            // write KNNi in buffer to a temporary file
            if (buf.remaining() > 0) {
                writeBuffer(buf, dst);
            }
            try {
                dst.flush();
            } catch (IOException e) {
                throw new HiveException(
                    "Failed to flush a file: " + dst.getFile().getAbsolutePath(), e);
            }
            if (logger.isInfoEnabled()) {
                File tmpFile = dst.getFile();
                logger.info(
                    "Wrote KNN entries of axis items to a temporary file for iterative training: "
                            + tmpFile.getAbsolutePath() + " ("
                            + FileUtils.prettyFileSize(tmpFile) + ")");
            }

            // run iterations
            for (int iter = 2; iter <= numIterations; iter++) {
                _cvState.next();
                setCounterValue(iterCounter, iter);

                buf.clear();
                dst.resetPosition();
                while (true) {
                    reportProgress(reporter);
                    // load a KNNi to a buffer in the temporary file
                    final int bytesRead;
                    try {
                        bytesRead = dst.read(buf);
                    } catch (IOException e) {
                        throw new HiveException(
                            "Failed to read a file: " + dst.getFile().getAbsolutePath(), e);
                    }
                    if (bytesRead == 0) { // reached file EOF
                        break;
                    }
                    assert (bytesRead > 0) : bytesRead;

                    // reads training examples from a buffer
                    buf.flip();
                    int remain = buf.remaining();
                    if (remain < SizeOf.INT) {
                        throw new HiveException("Illegal file format was detected");
                    }
                    while (remain >= SizeOf.INT) {
                        int pos = buf.position();
                        int recordBytes = buf.getInt();
                        remain -= SizeOf.INT;
                        if (remain < recordBytes) {
                            // incomplete record: rewind to its header and read more bytes
                            buf.position(pos);
                            break;
                        }
                        replayTrain(buf);
                        remain -= recordBytes;
                    }
                    buf.compact(); // keep the unread tail for the next read
                }
                if (_cvState.isConverged(_observedTrainingExamples)) {
                    break;
                }
            }
            logger.info("Performed " + _cvState.getCurrentIteration() + " iterations of "
                    + NumberUtils.formatNumber(_observedTrainingExamples)
                    + " training examples on memory and KNNi data on secondary storage (thus "
                    + NumberUtils.formatNumber(
                        _observedTrainingExamples * _cvState.getCurrentIteration())
                    + " training updates in total) ");
        }
    } catch (HiveException e) {
        // already carries a specific message; do not wrap it a second time
        throw e;
    } catch (Throwable e) {
        throw new HiveException("Exception caused in the iterative training", e);
    } finally {
        // delete the temporary file and release resources
        try {
            dst.close(true);
        } catch (IOException e) {
            throw new HiveException(
                "Failed to close a file: " + dst.getFile().getAbsolutePath(), e);
        }
        this._inputBuf = null;
        this._fileIO = null;
    }
}
/**
 * Decodes one kNN record from {@code buf} (positioned just after the record-length header)
 * and re-trains the weights between the record's item i and every item that occurs in it.
 *
 * @param buf buffer in read mode; its position is advanced past the record
 */
private void replayTrain(@Nonnull final ByteBuffer buf) {
    final int itemI = buf.getInt();
    final int knnSize = buf.getInt();

    final Int2ObjectMap<Int2FloatMap> knnItems = new Int2ObjectOpenHashMap<>(1024);
    // distinct item ids seen in the record; each becomes an item j to train against
    final IntSet pairItems = new IntOpenHashSet();
    for (int i = 0; i < knnSize; i++) {
        int user = buf.getInt();
        int ruSize = buf.getInt();

        Int2FloatMap ru = new Int2FloatOpenHashMap(ruSize);
        ru.defaultReturnValue(0.f);
        for (int j = 0; j < ruSize; j++) {
            int itemK = buf.getInt();
            pairItems.add(itemK);
            float ruk = buf.getFloat();
            ru.put(itemK, ruk);
        }
        knnItems.put(user, ru);
    }

    for (int itemJ : pairItems) {
        train(itemI, knnItems, itemJ);
    }
}
/**
 * Forwards every non-zero cell of the weight matrix as a row {@code (i, j, value)}.
 *
 * <p>
 * The callback cannot throw a checked exception, so the first HiveException raised by
 * {@code forward} is remembered and re-thrown once the traversal has finished.
 *
 * @throws HiveException the first failure reported while forwarding
 */
private void forwardModel() throws HiveException {
    // writables are reused for every forwarded row
    final IntWritable f0 = new IntWritable(); // i
    final IntWritable f1 = new IntWritable(); // nn
    final FloatWritable f2 = new FloatWritable(); // w
    final Object[] forwardObj = new Object[] {f0, f1, f2};

    final MutableObject<HiveException> catched = new MutableObject<>();
    _weightMatrix.eachNonZeroCell(new VectorProcedure() {
        @Override
        public void apply(int i, int j, float value) {
            if (value == 0.f) {
                return;
            }
            f0.set(i);
            f1.set(j);
            f2.set(value);
            try {
                forward(forwardObj);
            } catch (HiveException e) {
                catched.setIfAbsent(e);
            }
        }
    });

    HiveException ex = catched.get();
    if (ex != null) {
        throw ex;
    }
    logger.info("Forwarded SLIM's weights matrix");
}
@Nonnull
private static Int2ObjectMap kNNentries(@Nonnull final Object kNNiObj,
@Nonnull final MapObjectInspector knnItemsOI,
@Nonnull final PrimitiveObjectInspector knnItemsKeyOI,
@Nonnull final MapObjectInspector knnItemsValueOI,
@Nonnull final PrimitiveObjectInspector knnItemsValueKeyOI,
@Nonnull final PrimitiveObjectInspector knnItemsValueValueOI,
@Nullable Int2ObjectMap knnItems, @Nonnull final MutableInt nnzKNNi) {
if (knnItems == null) {
knnItems = new Int2ObjectOpenHashMap<>(1024);
} else {
knnItems.clear();
}
int numElementOfKNNItems = 0;
for (Map.Entry, ?> entry : knnItemsOI.getMap(kNNiObj).entrySet()) {
int user = PrimitiveObjectInspectorUtils.getInt(entry.getKey(), knnItemsKeyOI);
Int2FloatMap ru = int2floatMap(knnItemsValueOI.getMap(entry.getValue()),
knnItemsValueKeyOI, knnItemsValueValueOI);
knnItems.put(user, ru);
numElementOfKNNItems += ru.size();
}
nnzKNNi.setValue(numElementOfKNNItems);
return knnItems;
}
@Nonnull
private static Int2FloatMap int2floatMap(@Nonnull final Map, ?> map,
@Nonnull final PrimitiveObjectInspector keyOI,
@Nonnull final PrimitiveObjectInspector valueOI) {
final Int2FloatMap result = new Int2FloatOpenHashMap(map.size());
result.defaultReturnValue(0.f);
for (Map.Entry, ?> entry : map.entrySet()) {
float v = PrimitiveObjectInspectorUtils.getFloat(entry.getValue(), valueOI);
if (v == 0.f) {
continue;
}
int k = PrimitiveObjectInspectorUtils.getInt(entry.getKey(), keyOI);
result.put(k, v);
}
return result;
}
@Nonnull
private static Int2FloatMap int2floatMap(final int item, @Nonnull final Map, ?> map,
@Nonnull final PrimitiveObjectInspector keyOI,
@Nonnull final PrimitiveObjectInspector valueOI,
@Nullable final FloatMatrix dataMatrix) {
return int2floatMap(item, map, keyOI, valueOI, dataMatrix, null);
}
@Nonnull
private static Int2FloatMap int2floatMap(final int item, @Nonnull final Map, ?> map,
@Nonnull final PrimitiveObjectInspector keyOI,
@Nonnull final PrimitiveObjectInspector valueOI, @Nullable final FloatMatrix dataMatrix,
@Nullable Int2FloatMap dst) {
if (dst == null) {
dst = new Int2FloatOpenHashMap(map.size());
dst.defaultReturnValue(0.f);
} else {
dst.clear();
}
for (Map.Entry, ?> entry : map.entrySet()) {
float rating = PrimitiveObjectInspectorUtils.getFloat(entry.getValue(), valueOI);
if (rating == 0.f) {
continue;
}
int user = PrimitiveObjectInspectorUtils.getInt(entry.getKey(), keyOI);
dst.put(user, rating);
if (dataMatrix != null) {
dataMatrix.set(item, user, rating);
}
}
return dst;
}
}
// © 2015 - 2025 Weber Informatics LLC | Privacy Policy