// org.apache.sysml.api.MLMatrix Maven / Gradle / Ivy
// Go to download
// Show more of this group Show more artifacts with this name
// Show all versions of systemml Show documentation
// Show all versions of systemml Show documentation
// Declarative Machine Learning
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.sysml.api;
import java.io.IOException;
import java.util.List;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.rdd.RDD;
import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SQLContext;
import org.apache.spark.sql.SQLContext.QueryExecution;
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan;
import org.apache.spark.sql.types.StructType;
import scala.Tuple2;
import org.apache.sysml.parser.DMLTranslator;
import org.apache.sysml.parser.ParseException;
import org.apache.sysml.runtime.DMLRuntimeException;
import org.apache.sysml.runtime.instructions.spark.functions.GetMIMBFromRow;
import org.apache.sysml.runtime.instructions.spark.functions.GetMLBlock;
import org.apache.sysml.runtime.matrix.MatrixCharacteristics;
import org.apache.sysml.runtime.matrix.data.MatrixBlock;
import org.apache.sysml.runtime.matrix.data.MatrixIndexes;
/**
 * Experimental API: might be discontinued in a future release.
 *
 * <p>This class serves four purposes:
 * <ol>
 * <li>It allows SystemML to fit nicely in MLPipeline by reducing the number of reblocks.</li>
 * <li>It allows users to easily read and write matrices without worrying
 * too much about format, metadata and type of the underlying RDDs.</li>
 * <li>It provides a mechanism to convert to and from MLLib's BlockedMatrix format
 * (the conversion method is currently commented out in this class).</li>
 * <li>It provides an off-the-shelf library for distributed blocked matrices and reduces
 * the learning curve for using SystemML.</li>
 * </ol>
 *
 * <p>However, it is important to know that it is easy to abuse this off-the-shelf library and
 * think of it as a replacement for writing DML, which it is not. It does not provide any
 * optimization between calls. A simple example of an optimization that is conveniently skipped
 * is: (t(m) %*% m)).
 *
 * <p>Also, note that this library is not thread-safe. The operator precedence is not exactly the
 * same as in DML (the precedence is enforced by the Scala compiler), so please use appropriate
 * brackets to enforce precedence.
 *
 * <p>Example (Scala):
 * <pre>{@code
 * import org.apache.sysml.api.{MLContext, MLMatrix}
 * val ml = new MLContext(sc)
 * val mat1 = ml.read(sqlContext, "V_small.csv", "csv")
 * val mat2 = ml.read(sqlContext, "W_small.mtx", "binary")
 * val result = mat1.transpose() %*% mat2
 * result.write("Result_small.mtx", "text")
 * }</pre>
 */
public class MLMatrix extends DataFrame {
    private static final long serialVersionUID = -7005940673916671165L;

    // NOTE(review): the logger is registered under DMLScript's name rather than MLMatrix's;
    // confirm this is intentional before changing it (log configuration may depend on it).
    protected static final Log LOG = LogFactory.getLog(DMLScript.class.getName());

    // Only the (DataFrame, MatrixCharacteristics, MLContext) constructor sets mc; the other two
    // constructors leave it null, and the dimension/block-size accessors dereference it unchecked.
    protected MatrixCharacteristics mc = null;
    protected MLContext ml = null;

    /**
     * Creates an MLMatrix over the given logical plan. Matrix characteristics are left unset.
     *
     * @param sqlContext SQL context handed to the DataFrame super-constructor
     * @param logicalPlan logical plan handed to the DataFrame super-constructor
     * @param ml context used to execute the DML scripts behind every matrix operation
     */
    protected MLMatrix(SQLContext sqlContext, LogicalPlan logicalPlan, MLContext ml) {
        super(sqlContext, logicalPlan);
        this.ml = ml;
    }

    /**
     * Creates an MLMatrix over the given query execution. Matrix characteristics are left unset.
     *
     * @param sqlContext SQL context handed to the DataFrame super-constructor
     * @param queryExecution query execution handed to the DataFrame super-constructor
     * @param ml context used to execute the DML scripts behind every matrix operation
     */
    protected MLMatrix(SQLContext sqlContext, QueryExecution queryExecution, MLContext ml) {
        super(sqlContext, queryExecution);
        this.ml = ml;
    }

    // Only used internally to set a new MLMatrix after one of the matrix operations.
    // Not to be used externally.
    protected MLMatrix(DataFrame df, MatrixCharacteristics mc, MLContext ml) throws DMLRuntimeException {
        super(df.sqlContext(), df.logicalPlan());
        this.mc = mc;
        this.ml = ml;
    }

    // DML statement appended to every generated script; it writes the variable "output" as a
    // binary-block matrix. The block size is read from DMLTranslator once, when this class is
    // initialized, so later changes to DMLTranslator.DMLBlockSize are not picked up here.
    static String writeStmt = "write(output, \"tmp\", format=\"binary\", rows_in_block=" + DMLTranslator.DMLBlockSize + ", cols_in_block=" + DMLTranslator.DMLBlockSize + ");";

    // ------------------------------------------------------------------------------------------------
    // /**
    // * Experimental unstable API: Converts our blocked matrix format to MLLib's format
    // * @return
    // */
    // public BlockMatrix toBlockedMatrix() {
    // JavaPairRDD<MatrixIndexes, MatrixBlock> blocks = getRDDLazily(this);
    // RDD<Tuple2<Tuple2<Object, Object>, Matrix>> mllibBlocks = blocks.mapToPair(new GetMLLibBlocks(mc.getRows(), mc.getCols(), mc.getRowsPerBlock(), mc.getColsPerBlock())).rdd();
    // return new BlockMatrix(mllibBlocks, mc.getRowsPerBlock(), mc.getColsPerBlock(), mc.getRows(), mc.getCols());
    // }
    // ------------------------------------------------------------------------------------------------

    /**
     * Wraps a binary-block RDD into an MLMatrix backed by a DataFrame that uses the default
     * binary-block schema.
     *
     * @param ml context the resulting matrix will use for subsequent operations
     * @param sqlContext SQL context used to create the backing DataFrame
     * @param blocks matrix blocks keyed by their block indexes
     * @param mc characteristics recorded on the resulting matrix
     * @return the new MLMatrix
     * @throws DMLRuntimeException propagated from the MLMatrix constructor
     */
    static MLMatrix createMLMatrix(MLContext ml, SQLContext sqlContext, JavaPairRDD<MatrixIndexes, MatrixBlock> blocks, MatrixCharacteristics mc) throws DMLRuntimeException {
        RDD<Row> rows = blocks.map(new GetMLBlock()).rdd();
        StructType schema = MLBlock.getDefaultSchemaForBinaryBlock();
        return new MLMatrix(sqlContext.createDataFrame(rows.toJavaRDD(), schema), mc, ml);
    }

    /**
     * Convenience method to write this MLMatrix to a file.
     *
     * @param filePath destination path, inserted verbatim into the generated DML script
     * @param format output format name, inserted verbatim into the generated DML script
     * @throws IOException propagated from script execution
     * @throws DMLException propagated from script execution
     * @throws ParseException propagated from script execution
     */
    public void write(String filePath, String format) throws IOException, DMLException, ParseException {
        ml.reset();
        ml.registerInput("left", this);
        // NOTE(review): filePath and format are spliced into the script unescaped; a value that
        // contains a double quote will break (or alter) the generated DML.
        ml.executeScript("left = read(\"\"); output=left; write(output, \"" + filePath + "\", format=\"" + format + "\");");
    }

    /**
     * Evaluates a scalar-valued DML builtin on this matrix by running a small script and
     * collecting its 1x1 result. Only {@code nrow} and {@code ncol} are supported.
     *
     * @param fn name of the DML builtin
     * @return the value found in cell (0, 0) of the single output block
     * @throws DMLRuntimeException if {@code fn} is unsupported, or the script does not yield
     *         exactly one output block
     */
    private double getScalarBuiltinFunctionResult(String fn) throws IOException, DMLException, ParseException {
        if (!"nrow".equals(fn) && !"ncol".equals(fn)) {
            throw new DMLRuntimeException("The function " + fn + " is not yet supported in MLMatrix");
        }
        ml.reset();
        ml.registerInput("left", getRDDLazily(this), mc.getRows(), mc.getCols(), mc.getRowsPerBlock(), mc.getColsPerBlock(), mc.getNonZeros());
        ml.registerOutput("output");
        String script = "left = read(\"\");"
                + "val = " + fn + "(left); "
                + "output = matrix(val, rows=1, cols=1); "
                + writeStmt;
        MLOutput out = ml.executeScript(script);
        List<Tuple2<MatrixIndexes, MatrixBlock>> result = out.getBinaryBlockedRDD("output").collect();
        if (result == null || result.size() != 1) {
            throw new DMLRuntimeException("Error while computing the function: " + fn);
        }
        return result.get(0)._2.getValue(0, 0);
    }

    /**
     * Gets or computes the number of rows.
     *
     * @return the row count from the matrix characteristics when known, otherwise the result of
     *         running DML {@code nrow} on this matrix
     * @throws ParseException propagated from script execution
     * @throws DMLException propagated from script execution
     * @throws IOException propagated from script execution
     */
    public long numRows() throws IOException, DMLException, ParseException {
        if (mc.rowsKnown()) {
            return mc.getRows();
        }
        return (long) getScalarBuiltinFunctionResult("nrow");
    }

    /**
     * Gets or computes the number of columns.
     *
     * @return the column count from the matrix characteristics when known, otherwise the result
     *         of running DML {@code ncol} on this matrix
     * @throws ParseException propagated from script execution
     * @throws DMLException propagated from script execution
     * @throws IOException propagated from script execution
     */
    public long numCols() throws IOException, DMLException, ParseException {
        if (mc.colsKnown()) {
            return mc.getCols();
        }
        return (long) getScalarBuiltinFunctionResult("ncol");
    }

    /** Returns the rows-per-block value recorded in the matrix characteristics. */
    public int rowsPerBlock() {
        return mc.getRowsPerBlock();
    }

    /** Returns the columns-per-block value recorded in the matrix characteristics. */
    public int colsPerBlock() {
        return mc.getColsPerBlock();
    }

    /** Builds the DML script for {@code output = left <op> right} followed by the write statement. */
    private String getScript(String binaryOperator) {
        return "left = read(\"\");"
                + "right = read(\"\");"
                + "output = left " + binaryOperator + " right; "
                + writeStmt;
    }

    /**
     * Builds the DML script for a matrix/scalar operation followed by the write statement.
     *
     * @param binaryOperator DML operator placed between the operands
     * @param scalar scalar operand, rendered with Java's default double-to-string conversion
     * @param isScalarLeft when true the scalar is the left operand, otherwise the right one
     */
    private String getScalarBinaryScript(String binaryOperator, double scalar, boolean isScalarLeft) {
        if (isScalarLeft) {
            return "left = read(\"\");"
                    + "output = " + scalar + " " + binaryOperator + " left ;"
                    + writeStmt;
        }
        return "left = read(\"\");"
                + "output = left " + binaryOperator + " " + scalar + ";"
                + writeStmt;
    }

    /** Maps the rows of the given matrix's underlying RDD to (block index, block) pairs. */
    static JavaPairRDD<MatrixIndexes, MatrixBlock> getRDDLazily(MLMatrix mat) {
        return mat.rdd().toJavaRDD().mapToPair(new GetMIMBFromRow());
    }

    /**
     * Wraps the variable "output" of an executed script into a new MLMatrix that shares this
     * matrix's SQL context and MLContext.
     */
    private MLMatrix wrapOutput(MLOutput out) throws IOException, DMLException, ParseException {
        return createMLMatrix(ml, this.sqlContext(), out.getBinaryBlockedRDD("output"), out.getMatrixCharacteristics("output"));
    }

    /**
     * Runs {@code this <op> that} through a generated DML script.
     *
     * @param that right-hand operand; must use the same block sizes as this matrix
     * @param op DML binary operator
     * @return the result as a new MLMatrix
     * @throws DMLRuntimeException if {@code that} is null, the block sizes differ, or the
     *         dimensions do not match (columns vs. rows for {@code %*%}, identical dimensions
     *         for every other operator)
     */
    private MLMatrix matrixBinaryOp(MLMatrix that, String op) throws IOException, DMLException, ParseException {
        if (that == null) {
            throw new DMLRuntimeException("The right-hand operand of '" + op + "' must not be null");
        }
        if (mc.getRowsPerBlock() != that.mc.getRowsPerBlock() || mc.getColsPerBlock() != that.mc.getColsPerBlock()) {
            throw new DMLRuntimeException("Incompatible block sizes: brlen:" + mc.getRowsPerBlock() + "!=" + that.mc.getRowsPerBlock() + " || bclen:" + mc.getColsPerBlock() + "!=" + that.mc.getColsPerBlock());
        }
        if ("%*%".equals(op)) {
            if (mc.getCols() != that.mc.getRows()) {
                throw new DMLRuntimeException("Dimensions mismatch:" + mc.getCols() + "!=" + that.mc.getRows());
            }
        }
        else if (mc.getRows() != that.mc.getRows() || mc.getCols() != that.mc.getCols()) {
            throw new DMLRuntimeException("Dimensions mismatch:" + mc.getRows() + "!=" + that.mc.getRows() + " || " + mc.getCols() + "!=" + that.mc.getCols());
        }
        ml.reset();
        ml.registerInput("left", this);
        ml.registerInput("right", that);
        ml.registerOutput("output");
        return wrapOutput(ml.executeScript(getScript(op)));
    }

    /**
     * Runs a matrix/scalar operation through a generated DML script.
     *
     * @param scalar scalar operand
     * @param op DML binary operator
     * @param isScalarLeft when true the scalar is the left operand, otherwise the right one
     * @return the result as a new MLMatrix
     */
    private MLMatrix scalarBinaryOp(Double scalar, String op, boolean isScalarLeft) throws IOException, DMLException, ParseException {
        ml.reset();
        ml.registerInput("left", this);
        ml.registerOutput("output");
        return wrapOutput(ml.executeScript(getScalarBinaryScript(op, scalar, isScalarLeft)));
    }

    // ---------------------------------------------------
    // Simple operator overloading; does not utilize the optimizer.
    // Each method below runs the named DML operator with this matrix on the left and 'that' on
    // the right. The $-prefixed names are the encoded forms of the symbolic Scala operators.

    /** Applies DML {@code >} to this matrix and {@code that}. */
    public MLMatrix $greater(MLMatrix that) throws IOException, DMLException, ParseException {
        return matrixBinaryOp(that, ">");
    }

    /** Applies DML {@code <} to this matrix and {@code that}. */
    public MLMatrix $less(MLMatrix that) throws IOException, DMLException, ParseException {
        return matrixBinaryOp(that, "<");
    }

    /** Applies DML {@code >=} to this matrix and {@code that}. */
    public MLMatrix $greater$eq(MLMatrix that) throws IOException, DMLException, ParseException {
        return matrixBinaryOp(that, ">=");
    }

    /** Applies DML {@code <=} to this matrix and {@code that}. */
    public MLMatrix $less$eq(MLMatrix that) throws IOException, DMLException, ParseException {
        return matrixBinaryOp(that, "<=");
    }

    /** Applies DML {@code ==} to this matrix and {@code that}. */
    public MLMatrix $eq$eq(MLMatrix that) throws IOException, DMLException, ParseException {
        return matrixBinaryOp(that, "==");
    }

    /** Applies DML {@code !=} to this matrix and {@code that}. */
    public MLMatrix $bang$eq(MLMatrix that) throws IOException, DMLException, ParseException {
        return matrixBinaryOp(that, "!=");
    }

    /** Applies DML {@code ^} to this matrix and {@code that}. */
    public MLMatrix $up(MLMatrix that) throws IOException, DMLException, ParseException {
        return matrixBinaryOp(that, "^");
    }

    /** Named alias of {@link #$up(MLMatrix)}: applies DML {@code ^} (despite the name, not a call to a DML exp builtin). */
    public MLMatrix exp(MLMatrix that) throws IOException, DMLException, ParseException {
        return matrixBinaryOp(that, "^");
    }

    /** Applies DML {@code +} to this matrix and {@code that}. */
    public MLMatrix $plus(MLMatrix that) throws IOException, DMLException, ParseException {
        return matrixBinaryOp(that, "+");
    }

    /** Named alias of {@link #$plus(MLMatrix)}. */
    public MLMatrix add(MLMatrix that) throws IOException, DMLException, ParseException {
        return matrixBinaryOp(that, "+");
    }

    /** Applies DML {@code -} to this matrix and {@code that}. */
    public MLMatrix $minus(MLMatrix that) throws IOException, DMLException, ParseException {
        return matrixBinaryOp(that, "-");
    }

    /** Named alias of {@link #$minus(MLMatrix)}. */
    public MLMatrix minus(MLMatrix that) throws IOException, DMLException, ParseException {
        return matrixBinaryOp(that, "-");
    }

    /** Applies the DML multiplication operator (a single asterisk) to this matrix and {@code that}. */
    public MLMatrix $times(MLMatrix that) throws IOException, DMLException, ParseException {
        return matrixBinaryOp(that, "*");
    }

    /** Named alias of {@link #$times(MLMatrix)}. */
    public MLMatrix elementWiseMultiply(MLMatrix that) throws IOException, DMLException, ParseException {
        return matrixBinaryOp(that, "*");
    }

    /** Applies the DML division operator (a single slash) to this matrix and {@code that}. */
    public MLMatrix $div(MLMatrix that) throws IOException, DMLException, ParseException {
        return matrixBinaryOp(that, "/");
    }

    /** Named alias of {@link #$div(MLMatrix)}. */
    public MLMatrix divide(MLMatrix that) throws IOException, DMLException, ParseException {
        return matrixBinaryOp(that, "/");
    }

    /** Applies DML {@code %/%} to this matrix and {@code that}. */
    public MLMatrix $percent$div$percent(MLMatrix that) throws IOException, DMLException, ParseException {
        return matrixBinaryOp(that, "%/%");
    }

    /** Named alias of {@link #$percent$div$percent(MLMatrix)}. */
    public MLMatrix integerDivision(MLMatrix that) throws IOException, DMLException, ParseException {
        return matrixBinaryOp(that, "%/%");
    }

    /** Applies DML {@code %%} to this matrix and {@code that}. */
    public MLMatrix $percent$percent(MLMatrix that) throws IOException, DMLException, ParseException {
        return matrixBinaryOp(that, "%%");
    }

    /** Named alias of {@link #$percent$percent(MLMatrix)}. */
    public MLMatrix modulus(MLMatrix that) throws IOException, DMLException, ParseException {
        return matrixBinaryOp(that, "%%");
    }

    /** Applies DML {@code %*%}; requires this matrix's column count to equal {@code that}'s row count. */
    public MLMatrix $percent$times$percent(MLMatrix that) throws IOException, DMLException, ParseException {
        return matrixBinaryOp(that, "%*%");
    }

    /** Named alias of {@link #$percent$times$percent(MLMatrix)}. */
    public MLMatrix multiply(MLMatrix that) throws IOException, DMLException, ParseException {
        return matrixBinaryOp(that, "%*%");
    }

    /**
     * Runs DML {@code t(left)} on this matrix.
     *
     * @return the script output as a new MLMatrix
     */
    public MLMatrix transpose() throws IOException, DMLException, ParseException {
        ml.reset();
        ml.registerInput("left", this);
        ml.registerOutput("output");
        String script = "left = read(\"\");"
                + "output = t(left); "
                + writeStmt;
        return wrapOutput(ml.executeScript(script));
    }

    // TODO: For 'scalar op matrix' operations: Do implicit conversions
    // Each method below runs the named DML operator with this matrix on the left and the scalar
    // on the right.

    /** Applies DML {@code +} to this matrix and the scalar. */
    public MLMatrix $plus(Double scalar) throws IOException, DMLException, ParseException {
        return scalarBinaryOp(scalar, "+", false);
    }

    /** Named alias of {@link #$plus(Double)}. */
    public MLMatrix add(Double scalar) throws IOException, DMLException, ParseException {
        return scalarBinaryOp(scalar, "+", false);
    }

    /** Applies DML {@code -} to this matrix and the scalar. */
    public MLMatrix $minus(Double scalar) throws IOException, DMLException, ParseException {
        return scalarBinaryOp(scalar, "-", false);
    }

    /** Named alias of {@link #$minus(Double)}. */
    public MLMatrix minus(Double scalar) throws IOException, DMLException, ParseException {
        return scalarBinaryOp(scalar, "-", false);
    }

    /** Applies the DML multiplication operator (a single asterisk) to this matrix and the scalar. */
    public MLMatrix $times(Double scalar) throws IOException, DMLException, ParseException {
        return scalarBinaryOp(scalar, "*", false);
    }

    /** Named alias of {@link #$times(Double)}. */
    public MLMatrix elementWiseMultiply(Double scalar) throws IOException, DMLException, ParseException {
        return scalarBinaryOp(scalar, "*", false);
    }

    /** Applies the DML division operator (a single slash) to this matrix and the scalar. */
    public MLMatrix $div(Double scalar) throws IOException, DMLException, ParseException {
        return scalarBinaryOp(scalar, "/", false);
    }

    /** Named alias of {@link #$div(Double)}. */
    public MLMatrix divide(Double scalar) throws IOException, DMLException, ParseException {
        return scalarBinaryOp(scalar, "/", false);
    }

    /** Applies DML {@code >} to this matrix and the scalar. */
    public MLMatrix $greater(Double scalar) throws IOException, DMLException, ParseException {
        return scalarBinaryOp(scalar, ">", false);
    }

    /** Applies DML {@code <} to this matrix and the scalar. */
    public MLMatrix $less(Double scalar) throws IOException, DMLException, ParseException {
        return scalarBinaryOp(scalar, "<", false);
    }

    /** Applies DML {@code >=} to this matrix and the scalar. */
    public MLMatrix $greater$eq(Double scalar) throws IOException, DMLException, ParseException {
        return scalarBinaryOp(scalar, ">=", false);
    }

    /** Applies DML {@code <=} to this matrix and the scalar. */
    public MLMatrix $less$eq(Double scalar) throws IOException, DMLException, ParseException {
        return scalarBinaryOp(scalar, "<=", false);
    }

    /** Applies DML {@code ==} to this matrix and the scalar. */
    public MLMatrix $eq$eq(Double scalar) throws IOException, DMLException, ParseException {
        return scalarBinaryOp(scalar, "==", false);
    }

    /** Applies DML {@code !=} to this matrix and the scalar. */
    public MLMatrix $bang$eq(Double scalar) throws IOException, DMLException, ParseException {
        return scalarBinaryOp(scalar, "!=", false);
    }
}