All downloads are free. Search and download functionality uses the official Maven repository.

org.apache.mahout.math.hadoop.DistributedRowMatrix Maven / Gradle / Ivy

/**
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.mahout.math.hadoop;

import java.io.DataInput;
import java.io.DataOutput;
import java.io.IOException;
import java.util.Iterator;

import org.apache.hadoop.conf.Configurable;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.WritableComparable;
import org.apache.hadoop.mapred.JobClient;
import org.apache.hadoop.mapred.JobConf;
import org.apache.hadoop.mapreduce.Job;
import org.apache.mahout.common.Pair;
import org.apache.mahout.common.iterator.sequencefile.PathFilters;
import org.apache.mahout.common.iterator.sequencefile.PathType;
import org.apache.mahout.common.iterator.sequencefile.SequenceFileDirIterator;
import org.apache.mahout.math.CardinalityException;
import org.apache.mahout.math.MatrixSlice;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.VectorIterable;
import org.apache.mahout.math.VectorWritable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import com.google.common.base.Function;
import com.google.common.collect.Iterators;

/**
 * DistributedRowMatrix is a FileSystem-backed VectorIterable in which the vectors live in a
 * SequenceFile, and distributed operations are executed as M/R passes on
 * Hadoop.  The usage is as follows: 

*

*

 *   // the path must already contain an already created SequenceFile!
 *   DistributedRowMatrix m = new DistributedRowMatrix("path/to/vector/sequenceFile", "tmp/path", 10000000, 250000);
 *   m.setConf(new Configuration());
 *   // now if we want to multiply a vector by this matrix, it's dimension must equal the row dimension of this
 *   // matrix.  If we want to timesSquared() a vector by this matrix, its dimension must equal the column dimension
 *   // of the matrix.
 *   Vector v = new DenseVector(250000);
 *   // now the following operation will be done via a M/R pass via Hadoop.
 *   Vector w = m.timesSquared(v);
 * 
* */ public class DistributedRowMatrix implements VectorIterable, Configurable { public static final String KEEP_TEMP_FILES = "DistributedMatrix.keep.temp.files"; private static final Logger log = LoggerFactory.getLogger(DistributedRowMatrix.class); private final Path inputPath; private final Path outputTmpPath; private Configuration conf; private Path rowPath; private Path outputTmpBasePath; private final int numRows; private final int numCols; private boolean keepTempFiles; public DistributedRowMatrix(Path inputPath, Path outputTmpPath, int numRows, int numCols) { this(inputPath, outputTmpPath, numRows, numCols, false); } public DistributedRowMatrix(Path inputPath, Path outputTmpPath, int numRows, int numCols, boolean keepTempFiles) { this.inputPath = inputPath; this.outputTmpPath = outputTmpPath; this.numRows = numRows; this.numCols = numCols; this.keepTempFiles = keepTempFiles; } @Override public Configuration getConf() { return conf; } @Override public void setConf(Configuration conf) { this.conf = conf; try { FileSystem fs = FileSystem.get(inputPath.toUri(), conf); rowPath = fs.makeQualified(inputPath); outputTmpBasePath = fs.makeQualified(outputTmpPath); keepTempFiles = conf.getBoolean(KEEP_TEMP_FILES, false); } catch (IOException ioe) { throw new IllegalStateException(ioe); } } public Path getRowPath() { return rowPath; } public Path getOutputTempPath() { return outputTmpBasePath; } public void setOutputTempPathString(String outPathString) { try { outputTmpBasePath = FileSystem.get(conf).makeQualified(new Path(outPathString)); } catch (IOException ioe) { log.warn("Unable to set outputBasePath to {}, leaving as {}", outPathString, outputTmpBasePath); } } @Override public Iterator iterateAll() { try { Path pathPattern = rowPath; if (FileSystem.get(conf).getFileStatus(rowPath).isDir()) { pathPattern = new Path(rowPath, "*"); } return Iterators.transform( new SequenceFileDirIterator(pathPattern, PathType.GLOB, PathFilters.logsCRCFilter(), null, true, conf), new 
Function,MatrixSlice>() { @Override public MatrixSlice apply(Pair from) { return new MatrixSlice(from.getSecond().get(), from.getFirst().get()); } }); } catch (IOException ioe) { throw new IllegalStateException(ioe); } } @Override public int numSlices() { return numRows(); } @Override public int numRows() { return numRows; } @Override public int numCols() { return numCols; } /** * This implements matrix this.transpose().times(other) * @param other a DistributedRowMatrix * @return a DistributedRowMatrix containing the product */ public DistributedRowMatrix times(DistributedRowMatrix other) throws IOException { return times(other, new Path(outputTmpBasePath.getParent(), "productWith-" + (System.nanoTime() & 0xFF))); } /** * This implements matrix this.transpose().times(other) * @param other a DistributedRowMatrix * @param outPath path to write result to * @return a DistributedRowMatrix containing the product */ public DistributedRowMatrix times(DistributedRowMatrix other, Path outPath) throws IOException { if (numRows != other.numRows()) { throw new CardinalityException(numRows, other.numRows()); } Configuration initialConf = getConf() == null ? new Configuration() : getConf(); Configuration conf = MatrixMultiplicationJob.createMatrixMultiplyJobConf(initialConf, rowPath, other.rowPath, outPath, other.numCols); JobClient.runJob(new JobConf(conf)); DistributedRowMatrix out = new DistributedRowMatrix(outPath, outputTmpPath, numCols, other.numCols()); out.setConf(conf); return out; } public Vector columnMeans() throws IOException { return columnMeans("SequentialAccessSparseVector"); } /** * Returns the column-wise mean of a DistributedRowMatrix * * @param vectorClass * desired class for the column-wise mean vector e.g. 
* RandomAccessSparseVector, DenseVector * @return Vector containing the column-wise mean of this */ public Vector columnMeans(String vectorClass) throws IOException { Path outputVectorTmpPath = new Path(outputTmpBasePath, new Path(Long.toString(System.nanoTime()))); Configuration initialConf = getConf() == null ? new Configuration() : getConf(); String vectorClassFull = "org.apache.mahout.math." + vectorClass; Vector mean = MatrixColumnMeansJob.run(initialConf, rowPath, outputVectorTmpPath, vectorClassFull); if (!keepTempFiles) { FileSystem fs = outputVectorTmpPath.getFileSystem(conf); fs.delete(outputVectorTmpPath, true); } return mean; } public DistributedRowMatrix transpose() throws IOException { Path outputPath = new Path(rowPath.getParent(), "transpose-" + (System.nanoTime() & 0xFF)); Configuration initialConf = getConf() == null ? new Configuration() : getConf(); Job transposeJob = TransposeJob.buildTransposeJob(initialConf, rowPath, outputPath, numRows); try { transposeJob.waitForCompletion(true); } catch (Exception e) { throw new IllegalStateException("transposition failed", e); } DistributedRowMatrix m = new DistributedRowMatrix(outputPath, outputTmpPath, numCols, numRows); m.setConf(this.conf); return m; } @Override public Vector times(Vector v) { try { Configuration initialConf = getConf() == null ? 
new Configuration() : getConf(); Path outputVectorTmpPath = new Path(outputTmpBasePath, new Path(Long.toString(System.nanoTime()))); Job job = TimesSquaredJob.createTimesJob(initialConf, v, numRows, rowPath, outputVectorTmpPath); try { job.waitForCompletion(true); } catch (Exception e) { throw new IllegalStateException("times failed", e); } Vector result = TimesSquaredJob.retrieveTimesSquaredOutputVector(outputVectorTmpPath, conf); if (!keepTempFiles) { FileSystem fs = outputVectorTmpPath.getFileSystem(conf); fs.delete(outputVectorTmpPath, true); } return result; } catch (IOException ioe) { throw new IllegalStateException(ioe); } } @Override public Vector timesSquared(Vector v) { try { Configuration initialConf = getConf() == null ? new Configuration() : getConf(); Path outputVectorTmpPath = new Path(outputTmpBasePath, new Path(Long.toString(System.nanoTime()))); Job job = TimesSquaredJob.createTimesSquaredJob(initialConf, v, rowPath, outputVectorTmpPath); try { job.waitForCompletion(true); } catch (Exception e) { throw new IllegalStateException("timesSquared failed", e); } Vector result = TimesSquaredJob.retrieveTimesSquaredOutputVector(outputVectorTmpPath, conf); if (!keepTempFiles) { FileSystem fs = outputVectorTmpPath.getFileSystem(conf); fs.delete(outputVectorTmpPath, true); } return result; } catch (IOException ioe) { throw new IllegalStateException(ioe); } } @Override public Iterator iterator() { return iterateAll(); } public static class MatrixEntryWritable implements WritableComparable { private int row; private int col; private double val; public int getRow() { return row; } public void setRow(int row) { this.row = row; } public int getCol() { return col; } public void setCol(int col) { this.col = col; } public double getVal() { return val; } public void setVal(double val) { this.val = val; } @Override public int compareTo(MatrixEntryWritable o) { if (row > o.row) { return 1; } else if (row < o.row) { return -1; } else { if (col > o.col) { return 1; } 
else if (col < o.col) { return -1; } else { return 0; } } } @Override public boolean equals(Object o) { if (!(o instanceof MatrixEntryWritable)) { return false; } MatrixEntryWritable other = (MatrixEntryWritable) o; return row == other.row && col == other.col; } @Override public int hashCode() { return row + 31 * col; } @Override public void write(DataOutput out) throws IOException { out.writeInt(row); out.writeInt(col); out.writeDouble(val); } @Override public void readFields(DataInput in) throws IOException { row = in.readInt(); col = in.readInt(); val = in.readDouble(); } @Override public String toString() { return "(" + row + ',' + col + "):" + val; } } }




© 2015 - 2024 Weber Informatics LLC | Privacy Policy