All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.apache.sysml.api.MLMatrix Maven / Gradle / Ivy

There is a newer version: 1.2.0
Show newest version
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 * 
 *   http://www.apache.org/licenses/LICENSE-2.0
 * 
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
package org.apache.sysml.api;

import java.io.IOException;
import java.util.List;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.rdd.RDD;
import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SQLContext;
import org.apache.spark.sql.SQLContext.QueryExecution;
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan;
import org.apache.spark.sql.types.StructType;

import scala.Tuple2;

import org.apache.sysml.parser.DMLTranslator;
import org.apache.sysml.parser.ParseException;
import org.apache.sysml.runtime.DMLRuntimeException;
import org.apache.sysml.runtime.instructions.spark.functions.GetMIMBFromRow;
import org.apache.sysml.runtime.instructions.spark.functions.GetMLBlock;
import org.apache.sysml.runtime.matrix.MatrixCharacteristics;
import org.apache.sysml.runtime.matrix.data.MatrixBlock;
import org.apache.sysml.runtime.matrix.data.MatrixIndexes;

/**
 * Experimental API: might be discontinued in a future release.
 * 
 * This class serves four purposes:
 * 1. It allows SystemML to fit nicely into an ML pipeline by reducing the number of reblocks.
 * 2. It allows users to easily read and write matrices without worrying 
 * too much about the format, metadata and type of the underlying RDDs.
 * 3. It is intended to provide a mechanism to convert to and from MLlib's BlockMatrix format
 * (not implemented yet).
 * 4. It provides an off-the-shelf library for distributed blocked matrices and reduces the learning curve for using SystemML.
 * However, it is easy to abuse this off-the-shelf library and think of it as a replacement
 * for writing DML, which it is not. It does not provide any optimization between calls. For example,
 * in {@code t(m) %*% m} the transpose and the multiplication run as two separate scripts, so nothing is optimized across them.
 * Also, note that this library is not thread-safe. The operator precedence is not exactly the same as in DML (the precedence is
 * enforced by the Scala compiler), so please use brackets to enforce the intended precedence.
 *
 * <pre>
 * import org.apache.sysml.api.{MLContext, MLMatrix}
 * val ml = new MLContext(sc)
 * val mat1 = ml.read(sqlContext, "V_small.csv", "csv")
 * val mat2 = ml.read(sqlContext, "W_small.mtx", "binary")
 * val result = mat1.transpose() %*% mat2
 * result.write("Result_small.mtx", "text")
 * </pre>
 */
public class MLMatrix extends DataFrame {
	private static final long serialVersionUID = -7005940673916671165L;

	/** Logger for this API; note it is registered under {@link DMLScript}'s name, not this class's. */
	protected static final Log LOG = LogFactory.getLog(DMLScript.class.getName());

	/**
	 * DML statement appended to every generated script. It writes the script variable {@code output}
	 * in binary-block format with {@link DMLTranslator#DMLBlockSize} rows and columns per block; the
	 * result is then read back through {@link MLOutput#getBinaryBlockedRDD(String)}.
	 */
	static final String writeStmt = "write(output, \"tmp\", format=\"binary\", rows_in_block=" + DMLTranslator.DMLBlockSize + ", cols_in_block=" + DMLTranslator.DMLBlockSize + ");";

	/**
	 * Dimensions, block sizes and non-zero count of this matrix. Null unless supplied through
	 * {@link #MLMatrix(DataFrame, MatrixCharacteristics, MLContext)} (or assigned by a subclass).
	 */
	protected MatrixCharacteristics mc = null;

	/** Context used to compile and execute the DML script behind every operation on this matrix. */
	protected MLContext ml = null;

	/**
	 * Creates a matrix backed by a Spark logical plan; {@link #mc} is left null.
	 *
	 * @param sqlContext SQL context that owns the plan
	 * @param logicalPlan logical plan backing this DataFrame
	 * @param ml context used to run DML for matrix operations
	 */
	protected MLMatrix(SQLContext sqlContext, LogicalPlan logicalPlan, MLContext ml) {
		super(sqlContext, logicalPlan);
		this.ml = ml;
	}

	/**
	 * Creates a matrix backed by a Spark query execution; {@link #mc} is left null.
	 *
	 * @param sqlContext SQL context that owns the query execution
	 * @param queryExecution query execution backing this DataFrame
	 * @param ml context used to run DML for matrix operations
	 */
	protected MLMatrix(SQLContext sqlContext, QueryExecution queryExecution, MLContext ml) {
		super(sqlContext, queryExecution);
		this.ml = ml;
	}

	/**
	 * Internal only: wraps the DataFrame produced by a matrix operation. Not to be used externally.
	 *
	 * @param df DataFrame of matrix blocks (as built by {@link #createMLMatrix})
	 * @param mc characteristics describing {@code df}
	 * @param ml context used to run DML for subsequent operations
	 * @throws DMLRuntimeException declared for callers; not thrown by the body itself
	 */
	protected MLMatrix(DataFrame df, MatrixCharacteristics mc, MLContext ml) throws DMLRuntimeException {
		super(df.sqlContext(), df.logicalPlan());
		this.mc = mc;
		this.ml = ml;
	}

	// ------------------------------------------------------------------------------------------------
	// TODO: conversion to MLlib's BlockMatrix (purpose 3 in the class documentation) is not implemented yet.
	// ------------------------------------------------------------------------------------------------

	/**
	 * Wraps binary-blocked matrix data in a new MLMatrix.
	 *
	 * @param ml context the new matrix uses for subsequent operations
	 * @param sqlContext SQL context used to build the backing DataFrame
	 * @param blocks matrix blocks keyed by their block index
	 * @param mc characteristics describing {@code blocks}
	 * @return a new matrix whose rows are the blocks converted by {@link GetMLBlock}
	 * @throws DMLRuntimeException propagated from the wrapping constructor
	 */
	static MLMatrix createMLMatrix(MLContext ml, SQLContext sqlContext, JavaPairRDD<MatrixIndexes, MatrixBlock> blocks, MatrixCharacteristics mc) throws DMLRuntimeException {
		RDD<Row> rows = blocks.map(new GetMLBlock()).rdd();
		StructType schema = MLBlock.getDefaultSchemaForBinaryBlock();
		return new MLMatrix(sqlContext.createDataFrame(rows.toJavaRDD(), schema), mc, ml);
	}

	/**
	 * Writes this matrix to {@code filePath} in the given format by running a small DML script.
	 *
	 * @param filePath destination path; it is embedded in a DML string literal, so it must be
	 *                 non-null and must not contain a double quote
	 * @param format DML output format name (e.g. "text" or "binary"); same restrictions as {@code filePath}
	 * @throws DMLRuntimeException if {@code filePath} or {@code format} cannot be embedded safely
	 */
	public void write(String filePath, String format) throws IOException, DMLException, ParseException {
		requireDmlStringLiteral("filePath", filePath);
		requireDmlStringLiteral("format", format);
		ml.reset();
		ml.registerInput("left", this);
		ml.executeScript("left = read(\"\"); output=left; write(output, \"" + filePath + "\", format=\"" + format + "\");");
	}

	/**
	 * Rejects values that cannot be placed verbatim between double quotes in a generated DML script.
	 * A null would be spliced in as the text "null". A '"' would terminate the literal early and let
	 * the remainder of the value be parsed as DML.
	 */
	private static void requireDmlStringLiteral(String name, String value) throws DMLRuntimeException {
		if(value == null) {
			throw new DMLRuntimeException("MLMatrix: " + name + " must not be null");
		}
		if(value.indexOf('"') >= 0) {
			throw new DMLRuntimeException("MLMatrix: " + name + " must not contain '\"': " + value);
		}
	}

	/**
	 * Evaluates a scalar DML builtin ({@code nrow} or {@code ncol}) on this matrix.
	 *
	 * @param fn name of the builtin; only "nrow" and "ncol" are supported
	 * @return the builtin's result
	 * @throws DMLRuntimeException if {@code fn} is unsupported or the script yields no single result block
	 */
	private double getScalarBuiltinFunctionResult(String fn) throws IOException, DMLException, ParseException {
		if(!"nrow".equals(fn) && !"ncol".equals(fn)) {
			throw new DMLRuntimeException("The function " + fn + " is not yet supported in MLMatrix");
		}
		ml.reset();
		ml.registerInput("left", getRDDLazily(this), mc.getRows(), mc.getCols(), mc.getRowsPerBlock(), mc.getColsPerBlock(), mc.getNonZeros());
		ml.registerOutput("output");
		// The scalar is wrapped in a 1x1 matrix so it can be returned through the binary-block writer.
		String script = "left = read(\"\");"
				+ "val = " + fn + "(left); "
				+ "output = matrix(val, rows=1, cols=1); "
				+ writeStmt;
		MLOutput out = ml.executeScript(script);
		List<Tuple2<MatrixIndexes, MatrixBlock>> result = out.getBinaryBlockedRDD("output").collect();
		if(result == null || result.size() != 1) {
			throw new DMLRuntimeException("Error while computing the function: " + fn);
		}
		return result.get(0)._2.getValue(0, 0);
	}

	/**
	 * Gets the number of rows from the matrix characteristics when known, otherwise computes it with DML {@code nrow}.
	 *
	 * @return the number of rows
	 */
	public long numRows() throws IOException, DMLException, ParseException {
		return mc.rowsKnown() ? mc.getRows() : (long) getScalarBuiltinFunctionResult("nrow");
	}

	/**
	 * Gets the number of columns from the matrix characteristics when known, otherwise computes it with DML {@code ncol}.
	 *
	 * @return the number of columns
	 */
	public long numCols() throws IOException, DMLException, ParseException {
		return mc.colsKnown() ? mc.getCols() : (long) getScalarBuiltinFunctionResult("ncol");
	}

	/** @return the number of rows per block, as recorded in the matrix characteristics */
	public int rowsPerBlock() {
		return mc.getRowsPerBlock();
	}

	/** @return the number of columns per block, as recorded in the matrix characteristics */
	public int colsPerBlock() {
		return mc.getColsPerBlock();
	}

	/** Builds the DML script computing {@code output = left <op> right} and writing it in binary-block format. */
	private String getScript(String binaryOperator) {
		return "left = read(\"\");"
				+ "right = read(\"\");"
				+ "output = left " + binaryOperator + " right; "
				+ writeStmt;
	}

	/**
	 * Builds the DML script combining {@code left} with a scalar literal.
	 *
	 * @param isScalarLeft true for {@code scalar <op> left}, false for {@code left <op> scalar}
	 */
	private String getScalarBinaryScript(String binaryOperator, double scalar, boolean isScalarLeft) {
		String expression = isScalarLeft
				? "output = " + scalar + " " + binaryOperator + " left ;"
				: "output = left " + binaryOperator + " " + scalar + ";";
		return "left = read(\"\");" + expression + writeStmt;
	}

	/**
	 * Lazily converts the rows of {@code mat} back into matrix blocks keyed by block index.
	 *
	 * @param mat matrix whose DataFrame rows are converted with {@link GetMIMBFromRow}
	 * @return the (lazy) pair RDD of blocks
	 */
	static JavaPairRDD<MatrixIndexes, MatrixBlock> getRDDLazily(MLMatrix mat) {
		return mat.rdd().toJavaRDD().mapToPair(new GetMIMBFromRow());
	}

	/** Wraps the {@code output} variable of an executed script as a new MLMatrix sharing this matrix's contexts. */
	private MLMatrix outputToMLMatrix(MLOutput out) throws IOException, DMLException, ParseException {
		return createMLMatrix(ml, sqlContext(), out.getBinaryBlockedRDD("output"), out.getMatrixCharacteristics("output"));
	}

	/**
	 * Runs {@code this <op> that} in DML after validating block sizes and dimensions.
	 * Dimensions are only compared when known on both sides; unknown ones are left to the
	 * DML compiler/runtime (presumably validated there — not checked here).
	 *
	 * @throws DMLRuntimeException on incompatible block sizes or known, mismatching dimensions
	 */
	private MLMatrix matrixBinaryOp(MLMatrix that, String op) throws IOException, DMLException, ParseException {
		if(mc.getRowsPerBlock() != that.mc.getRowsPerBlock() || mc.getColsPerBlock() != that.mc.getColsPerBlock()) {
			throw new DMLRuntimeException("Incompatible block sizes: brlen:" + mc.getRowsPerBlock() + "!=" +  that.mc.getRowsPerBlock() + " || bclen:" + mc.getColsPerBlock() + "!=" + that.mc.getColsPerBlock());
		}

		if("%*%".equals(op)) {
			// Matrix multiplication: inner dimensions must agree.
			if(mc.colsKnown() && that.mc.rowsKnown() && mc.getCols() != that.mc.getRows()) {
				throw new DMLRuntimeException("Dimensions mismatch:" + mc.getCols() + "!=" +  that.mc.getRows());
			}
		}
		else {
			// Element-wise operation: both dimensions must agree.
			boolean rowsMismatch = mc.rowsKnown() && that.mc.rowsKnown() && mc.getRows() != that.mc.getRows();
			boolean colsMismatch = mc.colsKnown() && that.mc.colsKnown() && mc.getCols() != that.mc.getCols();
			if(rowsMismatch || colsMismatch) {
				throw new DMLRuntimeException("Dimensions mismatch:" + mc.getRows() + "!=" +  that.mc.getRows() + " || " + mc.getCols() + "!=" + that.mc.getCols());
			}
		}

		ml.reset();
		ml.registerInput("left", this);
		ml.registerInput("right", that);
		ml.registerOutput("output");
		return outputToMLMatrix(ml.executeScript(getScript(op)));
	}

	/** Runs {@code this <op> scalar} (or {@code scalar <op> this} when {@code isScalarLeft}) in DML. */
	private MLMatrix scalarBinaryOp(Double scalar, String op, boolean isScalarLeft) throws IOException, DMLException, ParseException {
		ml.reset();
		ml.registerInput("left", this);
		ml.registerOutput("output");
		return outputToMLMatrix(ml.executeScript(getScalarBinaryScript(op, scalar, isScalarLeft)));
	}

	// ---------------------------------------------------
	// Simple operator overloading; each call runs its own script, so the optimizer is not used across calls.

	/** Element-wise {@code this > that}. */
	public MLMatrix $greater(MLMatrix that) throws IOException, DMLException, ParseException {
		return matrixBinaryOp(that, ">");
	}

	/** Element-wise {@code this < that}. */
	public MLMatrix $less(MLMatrix that) throws IOException, DMLException, ParseException {
		return matrixBinaryOp(that, "<");
	}

	/** Element-wise {@code this >= that}. */
	public MLMatrix $greater$eq(MLMatrix that) throws IOException, DMLException, ParseException {
		return matrixBinaryOp(that, ">=");
	}

	/** Element-wise {@code this <= that}. */
	public MLMatrix $less$eq(MLMatrix that) throws IOException, DMLException, ParseException {
		return matrixBinaryOp(that, "<=");
	}

	/** Element-wise {@code this == that}. */
	public MLMatrix $eq$eq(MLMatrix that) throws IOException, DMLException, ParseException {
		return matrixBinaryOp(that, "==");
	}

	/** Element-wise {@code this != that}. */
	public MLMatrix $bang$eq(MLMatrix that) throws IOException, DMLException, ParseException {
		return matrixBinaryOp(that, "!=");
	}

	/** Element-wise power {@code this ^ that}. */
	public MLMatrix $up(MLMatrix that) throws IOException, DMLException, ParseException {
		return matrixBinaryOp(that, "^");
	}

	/** Element-wise power {@code this ^ that}; despite its name this is not the exponential function. */
	public MLMatrix exp(MLMatrix that) throws IOException, DMLException, ParseException {
		return matrixBinaryOp(that, "^");
	}

	/** Element-wise addition {@code this + that}. */
	public MLMatrix $plus(MLMatrix that) throws IOException, DMLException, ParseException {
		return matrixBinaryOp(that, "+");
	}

	/** Element-wise addition {@code this + that}. */
	public MLMatrix add(MLMatrix that) throws IOException, DMLException, ParseException {
		return matrixBinaryOp(that, "+");
	}

	/** Element-wise subtraction {@code this - that}. */
	public MLMatrix $minus(MLMatrix that) throws IOException, DMLException, ParseException {
		return matrixBinaryOp(that, "-");
	}

	/** Element-wise subtraction {@code this - that}. */
	public MLMatrix minus(MLMatrix that) throws IOException, DMLException, ParseException {
		return matrixBinaryOp(that, "-");
	}

	/** Element-wise multiplication {@code this * that}. */
	public MLMatrix $times(MLMatrix that) throws IOException, DMLException, ParseException {
		return matrixBinaryOp(that, "*");
	}

	/** Element-wise multiplication {@code this * that}. */
	public MLMatrix elementWiseMultiply(MLMatrix that) throws IOException, DMLException, ParseException {
		return matrixBinaryOp(that, "*");
	}

	/** Element-wise division {@code this / that}. */
	public MLMatrix $div(MLMatrix that) throws IOException, DMLException, ParseException {
		return matrixBinaryOp(that, "/");
	}

	/** Element-wise division {@code this / that}. */
	public MLMatrix divide(MLMatrix that) throws IOException, DMLException, ParseException {
		return matrixBinaryOp(that, "/");
	}

	/** Element-wise integer division {@code this %/% that}. */
	public MLMatrix $percent$div$percent(MLMatrix that) throws IOException, DMLException, ParseException {
		return matrixBinaryOp(that, "%/%");
	}

	/** Element-wise integer division {@code this %/% that}. */
	public MLMatrix integerDivision(MLMatrix that) throws IOException, DMLException, ParseException {
		return matrixBinaryOp(that, "%/%");
	}

	/** Element-wise modulus {@code this %% that}. */
	public MLMatrix $percent$percent(MLMatrix that) throws IOException, DMLException, ParseException {
		return matrixBinaryOp(that, "%%");
	}

	/** Element-wise modulus {@code this %% that}. */
	public MLMatrix modulus(MLMatrix that) throws IOException, DMLException, ParseException {
		return matrixBinaryOp(that, "%%");
	}

	/** Matrix multiplication {@code this %*% that}. */
	public MLMatrix $percent$times$percent(MLMatrix that) throws IOException, DMLException, ParseException {
		return matrixBinaryOp(that, "%*%");
	}

	/** Matrix multiplication {@code this %*% that}. */
	public MLMatrix multiply(MLMatrix that) throws IOException, DMLException, ParseException {
		return matrixBinaryOp(that, "%*%");
	}

	/** Transpose {@code t(this)}. */
	public MLMatrix transpose() throws IOException, DMLException, ParseException {
		ml.reset();
		ml.registerInput("left", this);
		ml.registerOutput("output");
		String script = "left = read(\"\");"
				+ "output = t(left); "
				+ writeStmt;
		return outputToMLMatrix(ml.executeScript(script));
	}

	// TODO: For 'scalar op matrix' operations: Do implicit conversions

	/** Adds {@code scalar} to every cell. */
	public MLMatrix $plus(Double scalar) throws IOException, DMLException, ParseException {
		return scalarBinaryOp(scalar, "+", false);
	}

	/** Adds {@code scalar} to every cell. */
	public MLMatrix add(Double scalar) throws IOException, DMLException, ParseException {
		return scalarBinaryOp(scalar, "+", false);
	}

	/** Subtracts {@code scalar} from every cell. */
	public MLMatrix $minus(Double scalar) throws IOException, DMLException, ParseException {
		return scalarBinaryOp(scalar, "-", false);
	}

	/** Subtracts {@code scalar} from every cell. */
	public MLMatrix minus(Double scalar) throws IOException, DMLException, ParseException {
		return scalarBinaryOp(scalar, "-", false);
	}

	/** Multiplies every cell by {@code scalar}. */
	public MLMatrix $times(Double scalar) throws IOException, DMLException, ParseException {
		return scalarBinaryOp(scalar, "*", false);
	}

	/** Multiplies every cell by {@code scalar}. */
	public MLMatrix elementWiseMultiply(Double scalar) throws IOException, DMLException, ParseException {
		return scalarBinaryOp(scalar, "*", false);
	}

	/** Divides every cell by {@code scalar}. */
	public MLMatrix $div(Double scalar) throws IOException, DMLException, ParseException {
		return scalarBinaryOp(scalar, "/", false);
	}

	/** Divides every cell by {@code scalar}. */
	public MLMatrix divide(Double scalar) throws IOException, DMLException, ParseException {
		return scalarBinaryOp(scalar, "/", false);
	}

	/** Element-wise {@code this > scalar}. */
	public MLMatrix $greater(Double scalar) throws IOException, DMLException, ParseException {
		return scalarBinaryOp(scalar, ">", false);
	}

	/** Element-wise {@code this < scalar}. */
	public MLMatrix $less(Double scalar) throws IOException, DMLException, ParseException {
		return scalarBinaryOp(scalar, "<", false);
	}

	/** Element-wise {@code this >= scalar}. */
	public MLMatrix $greater$eq(Double scalar) throws IOException, DMLException, ParseException {
		return scalarBinaryOp(scalar, ">=", false);
	}

	/** Element-wise {@code this <= scalar}. */
	public MLMatrix $less$eq(Double scalar) throws IOException, DMLException, ParseException {
		return scalarBinaryOp(scalar, "<=", false);
	}

	/** Element-wise {@code this == scalar}. */
	public MLMatrix $eq$eq(Double scalar) throws IOException, DMLException, ParseException {
		return scalarBinaryOp(scalar, "==", false);
	}

	/** Element-wise {@code this != scalar}. */
	public MLMatrix $bang$eq(Double scalar) throws IOException, DMLException, ParseException {
		return scalarBinaryOp(scalar, "!=", false);
	}

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy