org.apache.mahout.math.als.AlternatingLeastSquaresSolver Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of mahout-math Show documentation
Show all versions of mahout-math Show documentation
High performance scientific and technical computing data structures and methods,
mostly based on CERN's
Colt Java API
/**
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.mahout.math.als;
import com.google.common.base.Preconditions;
import com.google.common.collect.Iterables;
import org.apache.mahout.math.DenseMatrix;
import org.apache.mahout.math.Matrix;
import org.apache.mahout.math.QRDecomposition;
import org.apache.mahout.math.Vector;
/**
* See
*
* this paper.
*/
public final class AlternatingLeastSquaresSolver {
private AlternatingLeastSquaresSolver() {}
//TODO make feature vectors a simple array
public static Vector solve(Iterable featureVectors, Vector ratingVector, double lambda, int numFeatures) {
Preconditions.checkNotNull(featureVectors, "Feature Vectors cannot be null");
Preconditions.checkArgument(!Iterables.isEmpty(featureVectors));
Preconditions.checkNotNull(ratingVector, "Rating Vector cannot be null");
Preconditions.checkArgument(ratingVector.getNumNondefaultElements() > 0, "Rating Vector cannot be empty");
Preconditions.checkArgument(Iterables.size(featureVectors) == ratingVector.getNumNondefaultElements());
int nui = ratingVector.getNumNondefaultElements();
Matrix MiIi = createMiIi(featureVectors, numFeatures);
Matrix RiIiMaybeTransposed = createRiIiMaybeTransposed(ratingVector);
/* compute Ai = MiIi * t(MiIi) + lambda * nui * E */
Matrix Ai = miTimesMiTransposePlusLambdaTimesNuiTimesE(MiIi, lambda, nui);
/* compute Vi = MiIi * t(R(i,Ii)) */
Matrix Vi = MiIi.times(RiIiMaybeTransposed);
/* compute Ai * ui = Vi */
return solve(Ai, Vi);
}
private static Vector solve(Matrix Ai, Matrix Vi) {
return new QRDecomposition(Ai).solve(Vi).viewColumn(0);
}
static Matrix addLambdaTimesNuiTimesE(Matrix matrix, double lambda, int nui) {
Preconditions.checkArgument(matrix.numCols() == matrix.numRows(), "Must be a Square Matrix");
double lambdaTimesNui = lambda * nui;
int numCols = matrix.numCols();
for (int n = 0; n < numCols; n++) {
matrix.setQuick(n, n, matrix.getQuick(n, n) + lambdaTimesNui);
}
return matrix;
}
private static Matrix miTimesMiTransposePlusLambdaTimesNuiTimesE(Matrix MiIi, double lambda, int nui) {
double lambdaTimesNui = lambda * nui;
int rows = MiIi.numRows();
double[][] result = new double[rows][rows];
for (int i = 0; i < rows; i++) {
for (int j = i; j < rows; j++) {
double dot = MiIi.viewRow(i).dot(MiIi.viewRow(j));
if (i != j) {
result[i][j] = dot;
result[j][i] = dot;
} else {
result[i][i] = dot + lambdaTimesNui;
}
}
}
return new DenseMatrix(result, true);
}
static Matrix createMiIi(Iterable featureVectors, int numFeatures) {
double[][] MiIi = new double[numFeatures][Iterables.size(featureVectors)];
int n = 0;
for (Vector featureVector : featureVectors) {
for (int m = 0; m < numFeatures; m++) {
MiIi[m][n] = featureVector.getQuick(m);
}
n++;
}
return new DenseMatrix(MiIi, true);
}
static Matrix createRiIiMaybeTransposed(Vector ratingVector) {
Preconditions.checkArgument(ratingVector.isSequentialAccess(), "Ratings should be iterable in Index or Sequential Order");
double[][] RiIiMaybeTransposed = new double[ratingVector.getNumNondefaultElements()][1];
int index = 0;
for (Vector.Element elem : ratingVector.nonZeroes()) {
RiIiMaybeTransposed[index++][0] = elem.get();
}
return new DenseMatrix(RiIiMaybeTransposed, true);
}
}