/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 * 
 *   http://www.apache.org/licenses/LICENSE-2.0
 * 
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package org.apache.sysml.runtime.instructions.spark;

import org.apache.spark.api.java.JavaPairRDD;

import org.apache.sysml.lops.BinaryM.VectorType;
import org.apache.sysml.parser.Expression.DataType;
import org.apache.sysml.runtime.DMLRuntimeException;
import org.apache.sysml.runtime.DMLUnsupportedOperationException;
import org.apache.sysml.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysml.runtime.controlprogram.context.SparkExecutionContext;
import org.apache.sysml.runtime.instructions.InstructionUtils;
import org.apache.sysml.runtime.instructions.cp.CPOperand;
import org.apache.sysml.runtime.instructions.cp.ScalarObject;
import org.apache.sysml.runtime.instructions.spark.data.PartitionedBroadcastMatrix;
import org.apache.sysml.runtime.instructions.spark.functions.MatrixMatrixBinaryOpFunction;
import org.apache.sysml.runtime.instructions.spark.functions.MatrixScalarUnaryFunction;
import org.apache.sysml.runtime.instructions.spark.functions.MatrixVectorBinaryOpPartitionFunction;
import org.apache.sysml.runtime.instructions.spark.functions.OuterVectorBinaryOpFunction;
import org.apache.sysml.runtime.instructions.spark.functions.ReplicateVectorFunction;
import org.apache.sysml.runtime.matrix.MatrixCharacteristics;
import org.apache.sysml.runtime.matrix.data.MatrixBlock;
import org.apache.sysml.runtime.matrix.data.MatrixIndexes;
import org.apache.sysml.runtime.matrix.operators.BinaryOperator;
import org.apache.sysml.runtime.matrix.operators.Operator;
import org.apache.sysml.runtime.matrix.operators.ScalarOperator;

public abstract class BinarySPInstruction extends ComputationSPInstruction
{
	
	public BinarySPInstruction(Operator op, CPOperand in1, CPOperand in2, CPOperand out, String opcode, String istr ){
		super(op, in1, in2, out, opcode, istr);
	}

	public BinarySPInstruction(Operator op, CPOperand in1, CPOperand in2, CPOperand in3, CPOperand out, String opcode, String istr ){
		super(op, in1, in2, in3, out, opcode, istr);
	}
	
	/**
	 * Parses a binary instruction with two input operands and one output operand.
	 * 
	 * @param instr instruction string
	 * @param in1 first input operand (populated in place)
	 * @param in2 second input operand (populated in place)
	 * @param out output operand (populated in place)
	 * @return the instruction opcode
	 * @throws DMLRuntimeException if the instruction does not have the expected number of fields
	 */
	protected static String parseBinaryInstruction(String instr, CPOperand in1, CPOperand in2, CPOperand out)
		throws DMLRuntimeException
	{	
		String[] parts = InstructionUtils.getInstructionPartsWithValueType(instr);
		InstructionUtils.checkNumFields ( parts, 3 );
		
		String opcode = parts[0];
		in1.split(parts[1]);
		in2.split(parts[2]);
		out.split(parts[3]);
		
		return opcode;
	}
	
	protected static String parseBinaryInstruction(String instr, CPOperand in1, CPOperand in2, CPOperand in3, CPOperand out)
		throws DMLRuntimeException
	{
		String[] parts = InstructionUtils.getInstructionPartsWithValueType(instr);
		InstructionUtils.checkNumFields ( parts, 4 );
		
		String opcode = parts[0];
		in1.split(parts[1]);
		in2.split(parts[2]);
		in3.split(parts[3]);
		out.split(parts[4]);
		
		return opcode;
	}
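	// Illustrative note (operand encoding described here is an assumption based on SystemML's Lop/CPOperand
	// conventions): getInstructionPartsWithValueType strips the execution-type prefix, so parts[0] is the
	// opcode (e.g. "+") and the remaining parts are operand specifications, typically of the form
	// name·DATATYPE·VALUETYPE, which CPOperand.split() parses into the respective input/output operands.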

	/**
	 * Common binary matrix-matrix process instruction.
	 * 
	 * @param ec execution context (expected to be a SparkExecutionContext)
	 * @throws DMLRuntimeException 
	 * @throws DMLUnsupportedOperationException 
	 */
	protected void processMatrixMatrixBinaryInstruction(ExecutionContext ec) 
		throws DMLRuntimeException, DMLUnsupportedOperationException
	{
		SparkExecutionContext sec = (SparkExecutionContext)ec;
		
		//sanity check dimensions
		checkMatrixMatrixBinaryCharacteristics(sec);
		
		// Get input RDDs
		String rddVar1 = input1.getName();
		String rddVar2 = input2.getName();
		JavaPairRDD<MatrixIndexes,MatrixBlock> in1 = sec.getBinaryBlockRDDHandleForVariable( rddVar1 );
		JavaPairRDD<MatrixIndexes,MatrixBlock> in2 = sec.getBinaryBlockRDDHandleForVariable( rddVar2 );
		MatrixCharacteristics mc1 = sec.getMatrixCharacteristics( rddVar1 );
		MatrixCharacteristics mc2 = sec.getMatrixCharacteristics( rddVar2 );
		
		BinaryOperator bop = (BinaryOperator) _optr;
	
		//vector replication if required (mv or outer operations)
		boolean rowvector = (mc2.getRows()==1 && mc1.getRows()>1);
		long numRepLeft = getNumReplicas(mc1, mc2, true);
		long numRepRight = getNumReplicas(mc1, mc2, false);
		if( numRepLeft > 1 )
			in1 = in1.flatMapToPair(new ReplicateVectorFunction(false, numRepLeft ));
		if( numRepRight > 1 )
			in2 = in2.flatMapToPair(new ReplicateVectorFunction(rowvector, numRepRight));
		
		//execute binary operation
		JavaPairRDD<MatrixIndexes,MatrixBlock> out = in1
				.join(in2)
				.mapValues(new MatrixMatrixBinaryOpFunction(bop));
		
		//set output RDD
		updateBinaryOutputMatrixCharacteristics(sec);
		sec.setRDDHandleForVariable(output.getName(), out);
		sec.addLineageRDD(output.getName(), rddVar1);
		sec.addLineageRDD(output.getName(), rddVar2);
	}
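	// Illustrative example (assumed sizes): for a matrix-colvector operation with mc1 = 10000x3000,
	// mc2 = 10000x1, and 1000x1000 blocks, getNumReplicas yields numRepRight = ceil(3000/1000) = 3, so the
	// blocks of the column vector are replicated across the three column-block indexes before the join;
	// the join then pairs blocks with identical MatrixIndexes and MatrixMatrixBinaryOpFunction applies the
	// block-wise binary operator.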
	
	/**
	 * Common binary matrix-broadcast-vector process instruction.
	 * 
	 * @param ec execution context (expected to be a SparkExecutionContext)
	 * @param vtype orientation of the broadcast vector (row or column vector)
	 * @throws DMLRuntimeException
	 * @throws DMLUnsupportedOperationException
	 */
	protected void processMatrixBVectorBinaryInstruction(ExecutionContext ec, VectorType vtype) 
		throws DMLRuntimeException, DMLUnsupportedOperationException
	{
		SparkExecutionContext sec = (SparkExecutionContext)ec;
		
		//sanity check dimensions
		checkMatrixMatrixBinaryCharacteristics(sec);

		//get input RDDs
		String rddVar = input1.getName(); 
		String bcastVar = input2.getName();
		JavaPairRDD<MatrixIndexes,MatrixBlock> in1 = sec.getBinaryBlockRDDHandleForVariable( rddVar );
		PartitionedBroadcastMatrix in2 = sec.getBroadcastForVariable( bcastVar );
		MatrixCharacteristics mc1 = sec.getMatrixCharacteristics(rddVar);
		MatrixCharacteristics mc2 = sec.getMatrixCharacteristics(bcastVar);
		
		BinaryOperator bop = (BinaryOperator) _optr;
		boolean isOuter = (mc1.getRows()>1 && mc1.getCols()==1 && mc2.getRows()==1 && mc2.getCols()>1);
		
		//execute map binary operation
		JavaPairRDD<MatrixIndexes,MatrixBlock> out = null;
		if( isOuter ) {
			out = in1.flatMapToPair(new OuterVectorBinaryOpFunction(bop, in2));
		}
		else { //default
			//note: we use mapPartitions in order to preserve the partitioning information for
			//binary mv operations where the keys are guaranteed not to change; the reason
			//why we cannot use mapValues is the need for broadcast key lookups.
			//alternative: out = in1.mapToPair(new MatrixVectorBinaryOpFunction(bop, in2, vtype));
			out = in1.mapPartitionsToPair(
					new MatrixVectorBinaryOpPartitionFunction(bop, in2, vtype), true);
		}
		
		//set output RDD
		updateBinaryOutputMatrixCharacteristics(sec);
		sec.setRDDHandleForVariable(output.getName(), out);
		sec.addLineageRDD(output.getName(), rddVar);
		sec.addLineageBroadcast(output.getName(), bcastVar);
	}
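	// Illustrative note: in the broadcast variant above, the vector operand is shipped to all executors as a
	// PartitionedBroadcastMatrix, so no shuffle or join is required; in the outer case (column vector op row
	// vector), OuterVectorBinaryOpFunction expands each block of the column-vector RDD into one output block
	// per column block of the broadcast row vector.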
	
	/**
	 * Common binary matrix-scalar process instruction.
	 * 
	 * @param ec execution context (expected to be a SparkExecutionContext)
	 * @throws DMLRuntimeException
	 * @throws DMLUnsupportedOperationException
	 */
	protected void processMatrixScalarBinaryInstruction(ExecutionContext ec) 
		throws DMLRuntimeException, DMLUnsupportedOperationException
	{
		SparkExecutionContext sec = (SparkExecutionContext)ec;
	
		//get input RDD
		String rddVar = (input1.getDataType() == DataType.MATRIX) ? input1.getName() : input2.getName();
		JavaPairRDD<MatrixIndexes,MatrixBlock> in1 = sec.getBinaryBlockRDDHandleForVariable( rddVar );
		
		//get operator and scalar
		CPOperand scalar = ( input1.getDataType() == DataType.MATRIX ) ? input2 : input1;
		ScalarObject constant = (ScalarObject) ec.getScalarInput(scalar.getName(), scalar.getValueType(), scalar.isLiteral());
		ScalarOperator sc_op = (ScalarOperator) _optr;
		sc_op.setConstant(constant.getDoubleValue());
		
		//execute scalar matrix arithmetic instruction
		JavaPairRDD<MatrixIndexes,MatrixBlock> out = in1.mapValues( new MatrixScalarUnaryFunction(sc_op) );
			
		//put output RDD handle into symbol table
		updateUnaryOutputMatrixCharacteristics(sec, rddVar, output.getName());
		sec.setRDDHandleForVariable(output.getName(), out);
		sec.addLineageRDD(output.getName(), rddVar);
	}
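	// Illustrative example: for C = A * 7 (scalar literal assumed), the constant 7 is folded into the
	// ScalarOperator via setConstant, and MatrixScalarUnaryFunction applies the resulting block operation to
	// every MatrixBlock via mapValues, which leaves keys and partitioning untouched.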
	
	
	/**
	 * Updates the output matrix characteristics of a binary matrix-matrix operation,
	 * inferring unknown output dimensions from the inputs.
	 * 
	 * @param sec spark execution context
	 * @param checkCommonDim if true, additionally validates the common dimension (ncol(in1) == nrow(in2))
	 * @throws DMLRuntimeException if the output dimensions cannot be inferred or the inputs are incompatible
	 */
	protected void updateBinaryMMOutputMatrixCharacteristics(SparkExecutionContext sec, boolean checkCommonDim) 
		throws DMLRuntimeException
	{
		MatrixCharacteristics mc1 = sec.getMatrixCharacteristics(input1.getName());
		MatrixCharacteristics mc2 = sec.getMatrixCharacteristics(input2.getName());
		MatrixCharacteristics mcOut = sec.getMatrixCharacteristics(output.getName());
		if(!mcOut.dimsKnown()) { 
			if( !mc1.dimsKnown() || !mc2.dimsKnown() )
				throw new DMLRuntimeException("The output dimensions are not specified and cannot be inferred from inputs.");
			else if(mc1.getRowsPerBlock() != mc2.getRowsPerBlock() || mc1.getColsPerBlock() != mc2.getColsPerBlock())
				throw new DMLRuntimeException("Incompatible block sizes for BinarySPInstruction.");
			else if(checkCommonDim && mc1.getCols() != mc2.getRows())
				throw new DMLRuntimeException("Incompatible dimensions for BinarySPInstruction.");
			else {
				mcOut.set(mc1.getRows(), mc2.getCols(), mc1.getRowsPerBlock(), mc1.getColsPerBlock());
			}
		}	
	}
	
	/**
	 * Updates the output matrix characteristics of a binary append (cbind/rbind) operation,
	 * inferring unknown output dimensions and non-zero counts from the inputs.
	 * 
	 * @param sec spark execution context
	 * @param cbind true for column-wise append (cbind), false for row-wise append (rbind)
	 * @throws DMLRuntimeException if the output dimensions cannot be inferred
	 */
	protected void updateBinaryAppendOutputMatrixCharacteristics(SparkExecutionContext sec, boolean cbind) 
		throws DMLRuntimeException
	{
		MatrixCharacteristics mc1 = sec.getMatrixCharacteristics(input1.getName());
		MatrixCharacteristics mc2 = sec.getMatrixCharacteristics(input2.getName());
		MatrixCharacteristics mcOut = sec.getMatrixCharacteristics(output.getName());
		
		//infer initially unknown dimensions from inputs
		if(!mcOut.dimsKnown()) { 
			if( !mc1.dimsKnown() || !mc2.dimsKnown() )
				throw new DMLRuntimeException("The output dimensions are not specified and cannot be inferred from inputs.");
			
			if( cbind )
				mcOut.set(mc1.getRows(), mc1.getCols()+mc2.getCols(), mc1.getRowsPerBlock(), mc1.getColsPerBlock());
			else //rbind
				mcOut.set(mc1.getRows()+mc2.getRows(), mc1.getCols(), mc1.getRowsPerBlock(), mc1.getColsPerBlock());
		}	
		
		//infer initially unknown nnz from inputs
		if( !mcOut.nnzKnown() && mc1.nnzKnown() && mc2.nnzKnown() ) {
			mcOut.setNonZeros( mc1.getNonZeros() + mc2.getNonZeros() );
		}
	}
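	// Illustrative example (assumed sizes): cbind of a 1000x10 input with 3000 non-zeros and a 1000x5 input
	// with 500 non-zeros yields output characteristics of 1000x15 with 3500 non-zeros (block sizes taken
	// from the first input); for rbind the row counts add up instead, e.g., 1000x10 rbind 500x10 -> 1500x10.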

	/**
	 * Computes the required number of vector replications for matrix-vector
	 * and outer binary operations (1 for plain matrix-matrix operations).
	 * 
	 * @param mc1 characteristics of the left input
	 * @param mc2 characteristics of the right input
	 * @param left true to compute the replication factor for the left input, false for the right input
	 * @return number of required replications
	 */
	protected long getNumReplicas(MatrixCharacteristics mc1, MatrixCharacteristics mc2, boolean left) 
	{
		if( left ) 
		{
			if(mc1.getCols()==1 ) //outer
				return (long) Math.ceil((double)mc2.getCols() / mc2.getColsPerBlock());	
		}
		else
		{
			if(mc2.getRows()==1 && mc1.getRows()>1) //outer, row vector
				return (long) Math.ceil((double)mc1.getRows() / mc1.getRowsPerBlock());	
			else if( mc2.getCols()==1 && mc1.getCols()>1 ) //col vector
				return (long) Math.ceil((double)mc1.getCols() / mc1.getColsPerBlock());			
		}
		
		return 1; //matrix-matrix
	}
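	// Worked example (assumed sizes): with 1000x1000 blocks, an outer operation between a 5000x1 column
	// vector (left) and a 1x3000 row vector (right) yields numRepLeft = ceil(3000/1000) = 3 and
	// numRepRight = ceil(5000/1000) = 5; for a plain matrix-matrix operation both factors are 1.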
	
	/**
	 * Validates the input characteristics of a binary matrix-matrix operation
	 * (known dimensions, compatible shapes, and matching block sizes).
	 * 
	 * @param sec spark execution context
	 * @throws DMLRuntimeException if the dimensions are unknown or incompatible
	 */
	protected void checkMatrixMatrixBinaryCharacteristics(SparkExecutionContext sec) 
		throws DMLRuntimeException 
	{
		MatrixCharacteristics mc1 = sec.getMatrixCharacteristics(input1.getName());
		MatrixCharacteristics mc2 = sec.getMatrixCharacteristics(input2.getName());
		
		//check for unknown input dimensions
		if( !(mc1.dimsKnown() && mc2.dimsKnown()) ){
			throw new DMLRuntimeException("Unknown dimensions for matrix-matrix binary operations: "
					+ "[" + mc1.getRows() + "x" + mc1.getCols()  + " vs " + mc2.getRows() + "x" + mc2.getCols() + "]");
		}
		
		//check for dimension mismatch
		if( (mc1.getRows() != mc2.getRows() ||  mc1.getCols() != mc2.getCols())
			&& !(mc1.getRows() == mc2.getRows() && mc2.getCols()==1 ) //matrix-colvector
			&& !(mc1.getCols() == mc2.getCols() && mc2.getRows()==1 ) //matrix-rowvector
			&& !(mc1.getCols()==1 && mc2.getRows()==1) )     //outer colvector-rowvector 
		{
			throw new DMLRuntimeException("Dimension mismatch for matrix-matrix binary operations: "
					+ "[" + mc1.getRows() + "x" + mc1.getCols()  + " vs " + mc2.getRows() + "x" + mc2.getCols() + "]");
		}	
		
		if(mc1.getRowsPerBlock() != mc2.getRowsPerBlock() ||  mc1.getColsPerBlock() != mc2.getColsPerBlock()) {
			throw new DMLRuntimeException("Blocksize mismatch for matrix-matrix binary operations: "
					+ "[" + mc1.getRowsPerBlock() + "x" + mc1.getColsPerBlock()  + " vs " + mc2.getRowsPerBlock() + "x" + mc2.getColsPerBlock() + "]");
		}	
	}
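	// Illustrative shape combinations accepted by the check above: 1000x200 vs 1000x200 (matrix-matrix),
	// 1000x200 vs 1000x1 (matrix-colvector), 1000x200 vs 1x200 (matrix-rowvector), and 1000x1 vs 1x200
	// (outer); a mismatch such as 1000x200 vs 500x200 fails the dimension check.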
	
	/**
	 * Validates the input characteristics of a binary append (cbind/rbind) operation.
	 * 
	 * @param sec spark execution context
	 * @param cbind true for column-wise append (cbind), false for row-wise append (rbind)
	 * @param checkSingleBlk if true, requires the cbind output to fit into a single column block
	 * @param checkAligned if true, requires the left input to be aligned to block-size boundaries
	 * @throws DMLRuntimeException if any input check fails
	 */
	protected void checkBinaryAppendInputCharacteristics(SparkExecutionContext sec, boolean cbind, boolean checkSingleBlk, boolean checkAligned) 
		throws DMLRuntimeException
	{
		MatrixCharacteristics mc1 = sec.getMatrixCharacteristics(input1.getName());
		MatrixCharacteristics mc2 = sec.getMatrixCharacteristics(input2.getName());
		
		if(!mc1.dimsKnown() || !mc2.dimsKnown()) {
			throw new DMLRuntimeException("The dimensions of the inputs are unknown.");
		}
		else if(cbind && mc1.getRows() != mc2.getRows()) {
			throw new DMLRuntimeException("The number of rows of the inputs must match for an append-cbind instruction.");
		}
		else if(!cbind && mc1.getCols() != mc2.getCols()) {
			throw new DMLRuntimeException("The number of columns of the inputs must match for an append-rbind instruction.");
		}
		else if(mc1.getRowsPerBlock() != mc2.getRowsPerBlock() || mc1.getColsPerBlock() != mc2.getColsPerBlock()) {
			throw new DMLRuntimeException("The block sizes of the input matrices do not match.");
		}
		
		if( checkSingleBlk ) {
			if(mc1.getCols() + mc2.getCols() > mc1.getColsPerBlock())
				throw new DMLRuntimeException("Output must have at most one column block"); 
		}
		
		if( checkAligned ) {
			if( mc1.getCols() % mc1.getColsPerBlock() != 0 )
				throw new DMLRuntimeException("Input matrices are not aligned to block-size boundaries; wrong append variant selected.");
		}
	}
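	// Illustrative example (assumed sizes): with 1000x1000 blocks, cbind of a 1000x600 and a 1000x300 input
	// yields 900 output columns and thus fits into a single column block, so the single-block append check
	// passes; a left input with 1500 columns is not aligned to the block size (1500 % 1000 != 0), so the
	// aligned-append check would reject it.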
}



