All downloads are free. Search and download functionality uses the official Maven repository.

org.apache.sysml.runtime.instructions.spark.ParameterizedBuiltinSPInstruction Maven / Gradle / Ivy

There is a newer version: 1.2.0
Show newest version
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 * 
 *   http://www.apache.org/licenses/LICENSE-2.0
 * 
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package org.apache.sysml.runtime.instructions.spark;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Iterator;

import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.PairFlatMapFunction;
import org.apache.spark.api.java.function.PairFunction;
import org.apache.spark.broadcast.Broadcast;

import scala.Tuple2;

import org.apache.sysml.lops.Lop;
import org.apache.sysml.lops.PartialAggregate.CorrectionLocationType;
import org.apache.sysml.parser.Expression.ValueType;
import org.apache.sysml.parser.Statement;
import org.apache.sysml.runtime.DMLRuntimeException;
import org.apache.sysml.runtime.controlprogram.caching.FrameObject;
import org.apache.sysml.runtime.controlprogram.caching.MatrixObject.UpdateType;
import org.apache.sysml.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysml.runtime.controlprogram.context.SparkExecutionContext;
import org.apache.sysml.runtime.functionobjects.KahanPlus;
import org.apache.sysml.runtime.functionobjects.ParameterizedBuiltin;
import org.apache.sysml.runtime.functionobjects.ValueFunction;
import org.apache.sysml.runtime.instructions.InstructionUtils;
import org.apache.sysml.runtime.instructions.cp.CPOperand;
import org.apache.sysml.runtime.instructions.mr.GroupedAggregateInstruction;
import org.apache.sysml.runtime.instructions.spark.data.PartitionedBroadcast;
import org.apache.sysml.runtime.instructions.spark.functions.ExtractGroup.ExtractGroupBroadcast;
import org.apache.sysml.runtime.instructions.spark.functions.ExtractGroup.ExtractGroupJoin;
import org.apache.sysml.runtime.instructions.spark.functions.ExtractGroupNWeights;
import org.apache.sysml.runtime.instructions.spark.functions.PerformGroupByAggInCombiner;
import org.apache.sysml.runtime.instructions.spark.functions.PerformGroupByAggInReducer;
import org.apache.sysml.runtime.instructions.spark.functions.ReplicateVectorFunction;
import org.apache.sysml.runtime.instructions.spark.utils.FrameRDDConverterUtils;
import org.apache.sysml.runtime.instructions.spark.utils.RDDAggregateUtils;
import org.apache.sysml.runtime.instructions.spark.utils.SparkUtils;
import org.apache.sysml.runtime.matrix.MatrixCharacteristics;
import org.apache.sysml.runtime.matrix.data.FrameBlock;
import org.apache.sysml.runtime.matrix.data.InputInfo;
import org.apache.sysml.runtime.matrix.data.LibMatrixReorg;
import org.apache.sysml.runtime.matrix.data.MatrixBlock;
import org.apache.sysml.runtime.matrix.data.MatrixCell;
import org.apache.sysml.runtime.matrix.data.MatrixIndexes;
import org.apache.sysml.runtime.matrix.data.OperationsOnMatrixValues;
import org.apache.sysml.runtime.matrix.data.WeightedCell;
import org.apache.sysml.runtime.matrix.mapred.IndexedMatrixValue;
import org.apache.sysml.runtime.matrix.operators.AggregateOperator;
import org.apache.sysml.runtime.matrix.operators.CMOperator;
import org.apache.sysml.runtime.matrix.operators.CMOperator.AggregateOperationTypes;
import org.apache.sysml.runtime.matrix.operators.Operator;
import org.apache.sysml.runtime.matrix.operators.SimpleOperator;
import org.apache.sysml.runtime.transform.TfUtils;
import org.apache.sysml.runtime.transform.decode.Decoder;
import org.apache.sysml.runtime.transform.decode.DecoderFactory;
import org.apache.sysml.runtime.transform.encode.Encoder;
import org.apache.sysml.runtime.transform.encode.EncoderFactory;
import org.apache.sysml.runtime.transform.meta.TfMetaUtils;
import org.apache.sysml.runtime.transform.meta.TfOffsetMap;
import org.apache.sysml.runtime.util.DataConverter;
import org.apache.sysml.runtime.util.UtilFunctions;

public class ParameterizedBuiltinSPInstruction extends ComputationSPInstruction {
	// name/value parameters of this instruction, keyed by parameter name
	protected HashMap<String,String> params;
	// removeEmpty-specific attributes: if true, the offset vector is accessed
	// via a partitioned broadcast instead of a replicated join
	private boolean _bRmEmptyBC = false;

	/**
	 * Creates a parameterized builtin instruction; use {@link #parseInstruction(String)}.
	 *
	 * @param op operator (aggregate operator or simple operator wrapping the builtin function)
	 * @param paramsMap name/value parameters of the instruction
	 * @param out output operand
	 * @param opcode instruction opcode
	 * @param istr original instruction string
	 * @param bRmEmptyBC removeEmpty only: use a broadcast for the offset vector
	 */
	private ParameterizedBuiltinSPInstruction(Operator op, HashMap<String,String> paramsMap, CPOperand out,
			String opcode, String istr, boolean bRmEmptyBC) {
		super(op, null, null, out, opcode, istr);
		_sptype = SPINSTRUCTION_TYPE.ParameterizedBuiltin;
		params = paramsMap;
		_bRmEmptyBC = bRmEmptyBC;
	}

	/**
	 * Returns the name/value parameters of this instruction.
	 *
	 * @return the (mutable, shared) parameter map of this instruction
	 */
	public HashMap<String,String> getParams() { 
		return params; 
	}
	
	/**
	 * Builds the name/value parameter map of an instruction.
	 *
	 * All elements of {@code params} except the first (opcode) and the last
	 * (output operand) are expected to be of the form {@code <name>=<value>},
	 * separated by {@link Lop#NAME_VALUE_SEPARATOR}.
	 *
	 * @param params instruction parts (opcode, parameters, output)
	 * @return map from parameter name to the unparsed parameter value
	 */
	public static HashMap<String, String> constructParameterMap(String[] params) {
		HashMap<String, String> paramMap = new HashMap<>();
		
		// process all elements except first (opcode) and last (output)
		for ( int i=1; i <= params.length-2; i++ ) {
			// split only at the first separator, so that values which themselves
			// contain the separator are kept intact instead of being truncated
			String[] parts = params[i].split(Lop.NAME_VALUE_SEPARATOR, 2);
			paramMap.put(parts[0], parts[1]);
		}
		
		return paramMap;
	}
	
	/**
	 * Parses a Spark parameterized builtin instruction.
	 *
	 * {@code mapgroupedagg} uses positional operands (target, groups, output,
	 * number of groups); all other opcodes (groupedagg, rmempty, rexpand, replace,
	 * transformapply, transformdecode) use name=value parameters followed by the
	 * output operand.
	 *
	 * @param str instruction string
	 * @return the parsed instruction
	 * @throws DMLRuntimeException if the opcode is unknown or mandatory groupedagg
	 *         arguments ("fn", and "order" for centralmoment) are missing
	 */
	public static ParameterizedBuiltinSPInstruction parseInstruction ( String str ) 
		throws DMLRuntimeException 
	{
		String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
		// first part is always the opcode
		String opcode = parts[0];

		if( opcode.equalsIgnoreCase("mapgroupedagg") )
		{
			CPOperand target = new CPOperand( parts[1] ); 
			CPOperand groups = new CPOperand( parts[2] );
			CPOperand out = new CPOperand( parts[3] );

			HashMap<String,String> paramsMap = new HashMap<>();
			paramsMap.put(Statement.GAGG_TARGET, target.getName());
			paramsMap.put(Statement.GAGG_GROUPS, groups.getName());
			paramsMap.put(Statement.GAGG_NUM_GROUPS, parts[4]);
			
			// map-side grouped aggregate is a (Kahan) sum
			Operator op = new AggregateOperator(0, KahanPlus.getKahanPlusFnObject(), true, CorrectionLocationType.LASTCOLUMN);
			
			return new ParameterizedBuiltinSPInstruction(op, paramsMap, out, opcode, str, false);
		}
		else
		{
			// last part is always the output
			CPOperand out = new CPOperand( parts[parts.length-1] ); 

			// process remaining parts and build a hash map
			HashMap<String,String> paramsMap = constructParameterMap(parts);

			if ( opcode.equalsIgnoreCase("groupedagg")) {
				// check for mandatory arguments
				String fnStr = paramsMap.get("fn");
				if ( fnStr == null ) 
					throw new DMLRuntimeException("Function parameter is missing in groupedAggregate.");
				if ( fnStr.equalsIgnoreCase("centralmoment") ) {
					if ( paramsMap.get("order") == null )
						throw new DMLRuntimeException("Mandatory \"order\" must be specified when fn=\"centralmoment\" in groupedAggregate.");
				}
				
				Operator op = GroupedAggregateInstruction.parseGroupedAggOperator(fnStr, paramsMap.get("order"));
				return new ParameterizedBuiltinSPInstruction(op, paramsMap, out, opcode, str, false);
			}
			else if( opcode.equalsIgnoreCase("rmempty") ) 
			{
				// Every part between opcode and output is a name=value pair (a bare
				// token would already have failed in constructParameterMap), so the
				// broadcast flag has to be read from the parameter map rather than by
				// parsing a raw token; a missing entry keeps the join-based default.
				// NOTE(review): key assumed to be "bRmEmptyBC" as emitted by the Spark
				// rmempty lop -- confirm.
				boolean bRmEmptyBC = Boolean.parseBoolean(paramsMap.get("bRmEmptyBC"));
				
				ValueFunction func = ParameterizedBuiltin.getParameterizedBuiltinFnObject(opcode);
				return new ParameterizedBuiltinSPInstruction(new SimpleOperator(func), paramsMap, out, opcode, str, bRmEmptyBC);
			}
			else if(   opcode.equalsIgnoreCase("rexpand") 
					|| opcode.equalsIgnoreCase("replace")
					|| opcode.equalsIgnoreCase("transformapply")
					|| opcode.equalsIgnoreCase("transformdecode")) 
			{
				ValueFunction func = ParameterizedBuiltin.getParameterizedBuiltinFnObject(opcode);
				return new ParameterizedBuiltinSPInstruction(new SimpleOperator(func), paramsMap, out, opcode, str, false);
			}
			else {
				throw new DMLRuntimeException("Unknown opcode (" + opcode + ") for ParameterizedBuiltin Instruction.");
			}
		}
	}
	

	/**
	 * Executes this instruction on Spark by dispatching to the opcode-specific
	 * handler. Each handler binds the output (RDD handle or in-memory block),
	 * maintains lineage, and updates the output matrix characteristics.
	 *
	 * @param ec execution context (must be a {@link SparkExecutionContext})
	 * @throws DMLRuntimeException if execution fails or the opcode is unknown
	 */
	@Override 
	public void processInstruction(ExecutionContext ec) 
		throws DMLRuntimeException 
	{
		SparkExecutionContext sec = (SparkExecutionContext)ec;
		String opcode = getOpcode();
		
		//opcode guaranteed to be a valid opcode (see parsing)
		if( opcode.equalsIgnoreCase("mapgroupedagg") ) {
			processMapGroupedAggregate(sec);
		}
		else if( opcode.equalsIgnoreCase("groupedagg") ) {
			processGroupedAggregate(sec);
		}
		else if( opcode.equalsIgnoreCase("rmempty") ) {
			processRemoveEmpty(sec);
		}
		else if( opcode.equalsIgnoreCase("replace") ) {
			processReplace(sec);
		}
		else if( opcode.equalsIgnoreCase("rexpand") ) {
			processRExpand(sec);
		}
		else if( opcode.equalsIgnoreCase("transformapply") ) {
			processTransformApply(sec);
		}
		else if( opcode.equalsIgnoreCase("transformdecode") ) {
			processTransformDecode(sec);
		}
		else {
			throw new DMLRuntimeException("Unknown parameterized builtin opcode: "+opcode);
		}
	}
	
	/**
	 * mapgroupedagg: grouped sum of the target blocks against broadcast groups.
	 * Produces a single in-memory block if all groups and target columns fit into
	 * one block, otherwise a blocked RDD aggregated by key.
	 */
	private void processMapGroupedAggregate(SparkExecutionContext sec) 
		throws DMLRuntimeException 
	{
		//get input rdd handle
		String targetVar = params.get(Statement.GAGG_TARGET);
		String groupsVar = params.get(Statement.GAGG_GROUPS);
		JavaPairRDD<MatrixIndexes,MatrixBlock> target = sec.getBinaryBlockRDDHandleForVariable(targetVar);
		PartitionedBroadcast<MatrixBlock> groups = sec.getBroadcastForVariable(groupsVar);
		MatrixCharacteristics mc1 = sec.getMatrixCharacteristics( targetVar );
		MatrixCharacteristics mcOut = sec.getMatrixCharacteristics(output.getName());
		CPOperand ngrpOp = new CPOperand(params.get(Statement.GAGG_NUM_GROUPS));
		int ngroups = (int)sec.getScalarInput(ngrpOp.getName(), ngrpOp.getValueType(), ngrpOp.isLiteral()).getLongValue();
		
		//single-block aggregation
		if( ngroups <= mc1.getRowsPerBlock() && mc1.getCols() <= mc1.getColsPerBlock() ) {
			//execute map grouped aggregate
			JavaRDD<MatrixBlock> out = target.map(new RDDMapGroupedAggFunction2(groups, _optr, ngroups));
			MatrixBlock out2 = RDDAggregateUtils.sumStable(out);
			
			//put output block into symbol table (no lineage because single block)
			//this also includes implicit maintenance of matrix characteristics
			sec.setMatrixOutput(output.getName(), out2, getExtendedOpcode());
		}
		//multi-block aggregation
		else {
			//execute map grouped aggregate
			JavaPairRDD<MatrixIndexes,MatrixBlock> out = 
					target.flatMapToPair(new RDDMapGroupedAggFunction(groups, _optr, 
							ngroups, mc1.getRowsPerBlock(), mc1.getColsPerBlock()));
			
			out = RDDAggregateUtils.sumByKeyStable(out, false);
			
			//updated characteristics and handle outputs
			mcOut.set(ngroups, mc1.getCols(), mc1.getRowsPerBlock(), mc1.getColsPerBlock(), -1);
			sec.setRDDHandleForVariable(output.getName(), out);
			sec.addLineageRDD( output.getName(), targetVar );
			sec.addLineageBroadcast( output.getName(), groupsVar );
		}
	}
	
	/**
	 * groupedagg: general grouped aggregate of the target by a groups vector
	 * (joined or broadcast), optionally weighted, producing a cell-format output.
	 */
	private void processGroupedAggregate(SparkExecutionContext sec) 
		throws DMLRuntimeException 
	{
		boolean broadcastGroups = Boolean.parseBoolean(params.get("broadcast"));
		
		//get input rdd handle
		String groupsVar = params.get(Statement.GAGG_GROUPS);
		JavaPairRDD<MatrixIndexes,MatrixBlock> target = sec.getBinaryBlockRDDHandleForVariable( params.get(Statement.GAGG_TARGET) );
		JavaPairRDD<MatrixIndexes,MatrixBlock> groups = broadcastGroups ? null : sec.getBinaryBlockRDDHandleForVariable( groupsVar );
		JavaPairRDD<MatrixIndexes,MatrixBlock> weights = null;
		
		MatrixCharacteristics mc1 = sec.getMatrixCharacteristics( params.get(Statement.GAGG_TARGET) );
		MatrixCharacteristics mc2 = sec.getMatrixCharacteristics( groupsVar );
		if(mc1.dimsKnown() && mc2.dimsKnown() && (mc1.getRows() != mc2.getRows() || mc2.getCols() !=1)) {
			throw new DMLRuntimeException("Grouped Aggregate dimension mismatch between target and groups.");
		}
		MatrixCharacteristics mcOut = sec.getMatrixCharacteristics(output.getName());
		
		JavaPairRDD<MatrixIndexes,WeightedCell> groupWeightedCells = null;
		
		// Step 1: First extract groupWeightedCells from group, target and weights
		if ( params.get(Statement.GAGG_WEIGHTS) != null ) {
			weights = sec.getBinaryBlockRDDHandleForVariable( params.get(Statement.GAGG_WEIGHTS) );
			
			MatrixCharacteristics mc3 = sec.getMatrixCharacteristics( params.get(Statement.GAGG_WEIGHTS) );
			if(mc1.dimsKnown() && mc3.dimsKnown() && (mc1.getRows() != mc3.getRows() || mc1.getCols() != mc3.getCols())) {
				throw new DMLRuntimeException("Grouped Aggregate dimension mismatch between target, groups, and weights.");
			}
			
			//NOTE(review): groups is null if broadcast=true; this path assumes that
			//weighted grouped aggregates are never compiled with broadcast groups
			groupWeightedCells = groups.join(target).join(weights)
					.flatMapToPair(new ExtractGroupNWeights());
		}
		else //input vector or matrix
		{
			String ngroupsStr = params.get(Statement.GAGG_NUM_GROUPS);
			long ngroups = (ngroupsStr != null) ? (long) Double.parseDouble(ngroupsStr) : -1;
			
			//execute basic grouped aggregate (extract and preagg)
			if( broadcastGroups ) {
				PartitionedBroadcast<MatrixBlock> pbm = sec.getBroadcastForVariable(groupsVar);
				groupWeightedCells = target
						.flatMapToPair(new ExtractGroupBroadcast(pbm, mc1.getColsPerBlock(), ngroups, _optr));
			}
			else { //general case
				
				//replicate groups if necessary
				if( mc1.getNumColBlocks() > 1 ) {
					groups = groups.flatMapToPair(
						new ReplicateVectorFunction(false, mc1.getNumColBlocks() ));
				}
				
				groupWeightedCells = groups.join(target)
						.flatMapToPair(new ExtractGroupJoin(mc1.getColsPerBlock(), ngroups, _optr));
			}
		}
		
		// Step 2: Make sure we have brlen required while creating the output cells
		if(mc1.getRowsPerBlock() == -1) {
			throw new DMLRuntimeException("The block sizes are not specified for grouped aggregate");
		}
		int brlen = mc1.getRowsPerBlock();
		
		// Step 3: Now perform grouped aggregate operation (either on combiner side or reducer side)
		JavaPairRDD<MatrixIndexes,MatrixCell> out = null;
		if(_optr instanceof CMOperator && ((CMOperator) _optr).isPartialAggregateOperator() 
			|| _optr instanceof AggregateOperator ) {
			out = groupWeightedCells.reduceByKey(new PerformGroupByAggInCombiner(_optr))
					.mapValues(new CreateMatrixCell(brlen, _optr));
		}
		else {
			// Use groupby key because partial aggregation is not supported
			out = groupWeightedCells.groupByKey()
					.mapValues(new PerformGroupByAggInReducer(_optr))
					.mapValues(new CreateMatrixCell(brlen, _optr));
		}
		
		// Step 4: Set output characteristics and rdd handle 
		setOutputCharacteristicsForGroupedAgg(mc1, mcOut, out);
		
		//store output rdd handle
		sec.setRDDHandleForVariable(output.getName(), out);
		sec.addLineageRDD( output.getName(), params.get(Statement.GAGG_TARGET) );
		sec.addLineage( output.getName(), groupsVar, broadcastGroups );
		if ( params.get(Statement.GAGG_WEIGHTS) != null ) {
			sec.addLineageRDD(output.getName(), params.get(Statement.GAGG_WEIGHTS) );
		}
	}
	
	/**
	 * rmempty: removes empty rows or columns given a precomputed offset vector,
	 * which is either joined (replicated) or broadcast.
	 */
	private void processRemoveEmpty(SparkExecutionContext sec) 
		throws DMLRuntimeException 
	{
		String rddInVar = params.get("target");
		String rddOffVar = params.get("offset");
		
		boolean rows = sec.getScalarInput(params.get("margin"), ValueType.STRING, true).getStringValue().equals("rows");
		long maxDim = sec.getScalarInput(params.get("maxdim"), ValueType.DOUBLE, false).getLongValue();
		MatrixCharacteristics mcIn = sec.getMatrixCharacteristics(rddInVar);
		
		if( maxDim > 0 ) //default case
		{
			//get input rdd handle
			JavaPairRDD<MatrixIndexes,MatrixBlock> in = sec.getBinaryBlockRDDHandleForVariable( rddInVar );
			long brlen = mcIn.getRowsPerBlock();
			long bclen = mcIn.getColsPerBlock();
			long numRep = (long)Math.ceil( rows ? (double)mcIn.getCols()/bclen : (double)mcIn.getRows()/brlen);
			
			//execute remove empty rows/cols operation
			JavaPairRDD<MatrixIndexes,MatrixBlock> out;
			if( _bRmEmptyBC ) {
				// Broadcast offset vector
				PartitionedBroadcast<MatrixBlock> broadcastOff = sec.getBroadcastForVariable( rddOffVar );
				out = in
					.flatMapToPair(new RDDRemoveEmptyFunctionInMem(rows, maxDim, brlen, bclen, broadcastOff));
			}
			else {
				// Replicate the offset vector to align with every target block, then join
				JavaPairRDD<MatrixIndexes,MatrixBlock> off = sec.getBinaryBlockRDDHandleForVariable( rddOffVar );
				out = in
					.join( off.flatMapToPair(new ReplicateVectorFunction(!rows,numRep)) )
					.flatMapToPair(new RDDRemoveEmptyFunction(rows, maxDim, brlen, bclen));
			}
			
			out = RDDAggregateUtils.mergeByKey(out, false);
			
			//store output rdd handle
			sec.setRDDHandleForVariable(output.getName(), out);
			sec.addLineageRDD(output.getName(), rddInVar);
			if(!_bRmEmptyBC)
				sec.addLineageRDD(output.getName(), rddOffVar);
			else
				sec.addLineageBroadcast(output.getName(), rddOffVar);
			
			//update output statistics (required for correctness)
			MatrixCharacteristics mcOut = sec.getMatrixCharacteristics(output.getName());
			mcOut.set(rows?maxDim:mcIn.getRows(), rows?mcIn.getCols():maxDim, (int)brlen, (int)bclen, mcIn.getNonZeros());
		}
		else //special case: empty output (ensure valid dims)
		{
			MatrixBlock out = new MatrixBlock(rows?1:(int)mcIn.getRows(), rows?(int)mcIn.getCols():1, true); 
			sec.setMatrixOutput(output.getName(), out, getExtendedOpcode());
		}
	}
	
	/**
	 * replace: replaces all occurrences of a pattern value by a replacement value.
	 */
	private void processReplace(SparkExecutionContext sec) 
		throws DMLRuntimeException 
	{
		//get input rdd handle
		String rddVar = params.get("target");
		JavaPairRDD<MatrixIndexes,MatrixBlock> in1 = sec.getBinaryBlockRDDHandleForVariable( rddVar );
		MatrixCharacteristics mcIn = sec.getMatrixCharacteristics(rddVar);
		
		//execute replace operation
		double pattern = Double.parseDouble( params.get("pattern") );
		double replacement = Double.parseDouble( params.get("replacement") );
		JavaPairRDD<MatrixIndexes,MatrixBlock> out = 
				in1.mapValues(new RDDReplaceFunction(pattern, replacement));
		
		//store output rdd handle
		sec.setRDDHandleForVariable(output.getName(), out);
		sec.addLineageRDD(output.getName(), rddVar);
		
		//update output statistics (required for correctness); nnz is only
		//preserved if neither pattern nor replacement is zero
		MatrixCharacteristics mcOut = sec.getMatrixCharacteristics(output.getName());
		mcOut.set(mcIn.getRows(), mcIn.getCols(), mcIn.getRowsPerBlock(), mcIn.getColsPerBlock(), (pattern!=0 && replacement!=0)?mcIn.getNonZeros():-1);
	}
	
	/**
	 * rexpand: expands an input vector into a (rows or columns) indicator matrix.
	 */
	private void processRExpand(SparkExecutionContext sec) 
		throws DMLRuntimeException 
	{
		String rddInVar = params.get("target");
		
		//get input rdd handle
		JavaPairRDD<MatrixIndexes,MatrixBlock> in = sec.getBinaryBlockRDDHandleForVariable( rddInVar );
		MatrixCharacteristics mcIn = sec.getMatrixCharacteristics(rddInVar);
		double maxVal = Double.parseDouble( params.get("max") );
		long lmaxVal = UtilFunctions.toLong(maxVal);
		boolean dirRows = params.get("dir").equals("rows");
		boolean cast = Boolean.parseBoolean(params.get("cast"));
		boolean ignore = Boolean.parseBoolean(params.get("ignore"));
		long brlen = mcIn.getRowsPerBlock();
		long bclen = mcIn.getColsPerBlock();
		
		//repartition input vector for higher degree of parallelism 
		//(avoid scenarios where few input partitions create huge outputs)
		MatrixCharacteristics mcTmp = new MatrixCharacteristics(dirRows?lmaxVal:mcIn.getRows(), 
				dirRows?mcIn.getRows():lmaxVal, (int)brlen, (int)bclen, mcIn.getRows());
		int numParts = (int)Math.min(SparkUtils.getNumPreferredPartitions(mcTmp, in), mcIn.getNumBlocks());
		if( numParts > in.getNumPartitions()*2 )
			in = in.repartition(numParts);
		
		//execute rexpand rows/cols operation (no shuffle required because outputs are
		//block-aligned with the input, i.e., one input block generates n output blocks)
		JavaPairRDD<MatrixIndexes,MatrixBlock> out = in
				.flatMapToPair(new RDDRExpandFunction(maxVal, dirRows, cast, ignore, brlen, bclen));
		
		//store output rdd handle
		sec.setRDDHandleForVariable(output.getName(), out);
		sec.addLineageRDD(output.getName(), rddInVar);
		
		//update output statistics (required for correctness)
		MatrixCharacteristics mcOut = sec.getMatrixCharacteristics(output.getName());
		mcOut.set(dirRows?lmaxVal:mcIn.getRows(), dirRows?mcIn.getRows():lmaxVal, (int)brlen, (int)bclen, -1);
	}
	
	/**
	 * transformapply: encodes a frame into a matrix using the given spec and meta
	 * data, optionally shifting block keys to account for omitted rows.
	 */
	@SuppressWarnings("unchecked")
	private void processTransformApply(SparkExecutionContext sec) 
		throws DMLRuntimeException 
	{
		//get input RDD and meta data
		FrameObject fo = sec.getFrameObject(params.get("target"));
		JavaPairRDD<Long,FrameBlock> in = (JavaPairRDD<Long,FrameBlock>)
				sec.getRDDHandleForFrameObject(fo, InputInfo.BinaryBlockInputInfo);
		FrameBlock meta = sec.getFrameInput(params.get("meta"));
		MatrixCharacteristics mcIn = sec.getMatrixCharacteristics(params.get("target"));
		MatrixCharacteristics mcOut = sec.getMatrixCharacteristics(output.getName());
		String[] colnames = !TfMetaUtils.isIDSpec(params.get("spec")) ?
				in.lookup(1L).get(0).getColumnNames() : null; 
		
		//compute omit offset map for block shifts
		TfOffsetMap omap = null;
		if( TfMetaUtils.containsOmitSpec(params.get("spec"), colnames) ) {
			omap = new TfOffsetMap(SparkUtils.toIndexedLong(in.mapToPair(
				new RDDTransformApplyOffsetFunction(params.get("spec"), colnames)).collect()));
		}
			
		//create encoder broadcast (avoiding replication per task) 
		Encoder encoder = EncoderFactory.createEncoder(params.get("spec"), colnames,
				fo.getSchema(), (int)fo.getNumColumns(), meta);
		mcOut.setDimension(mcIn.getRows()-((omap!=null)?omap.getNumRmRows():0), encoder.getNumCols()); 
		Broadcast<Encoder> bmeta = sec.getSparkContext().broadcast(encoder);
		Broadcast<TfOffsetMap> bomap = (omap!=null) ? sec.getSparkContext().broadcast(omap) : null;
		
		//execute transform apply
		JavaPairRDD<Long,FrameBlock> tmp = in
				.mapToPair(new RDDTransformApplyFunction(bmeta, bomap));
		JavaPairRDD<MatrixIndexes,MatrixBlock> out = FrameRDDConverterUtils
				.binaryBlockToMatrixBlock(tmp, mcOut, mcOut);
		
		//set output and maintain lineage/output characteristics
		sec.setRDDHandleForVariable(output.getName(), out);
		sec.addLineageRDD(output.getName(), params.get("target"));
		sec.releaseFrameInput(params.get("meta"));
	}
	
	/**
	 * transformdecode: decodes a matrix back into a frame using the given spec and
	 * meta data; multi-column-block inputs are first merged into full-width rows.
	 */
	private void processTransformDecode(SparkExecutionContext sec) 
		throws DMLRuntimeException 
	{
		//get input RDD and meta data
		JavaPairRDD<MatrixIndexes,MatrixBlock> in = sec.getBinaryBlockRDDHandleForVariable(params.get("target"));
		MatrixCharacteristics mc = sec.getMatrixCharacteristics(params.get("target"));
		FrameBlock meta = sec.getFrameInput(params.get("meta"));
		String[] colnames = meta.getColumnNames();
		
		//reblock if necessary (clen > bclen), i.e., if the input spans multiple
		//column blocks, such that each decoded block covers all columns
		if( mc.getCols() > mc.getColsPerBlock() ) {
			in = in.mapToPair(new RDDTransformDecodeExpandFunction(
					(int)mc.getCols(), mc.getColsPerBlock()));
			in = RDDAggregateUtils.mergeByKey(in, false);
		}
		
		//construct decoder and decode individual matrix blocks
		Decoder decoder = DecoderFactory.createDecoder(params.get("spec"), colnames, null, meta);
		JavaPairRDD<Long,FrameBlock> out = in.mapToPair(
				new RDDTransformDecodeFunction(decoder, mc.getRowsPerBlock()));
		
		//set output and maintain lineage/output characteristics
		sec.setRDDHandleForVariable(output.getName(), out);
		sec.addLineageRDD(output.getName(), params.get("target"));
		sec.releaseFrameInput(params.get("meta"));
		sec.getMatrixCharacteristics(output.getName()).set(mc.getRows(), 
			meta.getNumColumns(), mc.getRowsPerBlock(), mc.getColsPerBlock(), -1);
		sec.getFrameObject(output.getName()).setSchema(decoder.getSchema());
	}

	/**
	 * Replaces every occurrence of a pattern value by a replacement value in each
	 * matrix block (see {@link MatrixBlock#replaceOperations}).
	 */
	public static class RDDReplaceFunction implements Function<MatrixBlock,MatrixBlock> 
	{
		private static final long serialVersionUID = 6576713401901671659L;
		
		private final double _pattern; 
		private final double _replacement;
		
		public RDDReplaceFunction(double pattern, double replacement) 
		{
			_pattern = pattern;
			_replacement = replacement;
		}
		
		@Override
		public MatrixBlock call(MatrixBlock arg0) 
			throws Exception 
		{
			// replace into a new output block, the input block is left untouched
			return (MatrixBlock) arg0.replaceOperations(new MatrixBlock(), _pattern, _replacement);
		}
	}

	public static class RDDRemoveEmptyFunction implements PairFlatMapFunction>,MatrixIndexes,MatrixBlock> 
	{
		private static final long serialVersionUID = 4906304771183325289L;

		private boolean _rmRows; 
		private long _len;
		private long _brlen;
		private long _bclen;
				
		public RDDRemoveEmptyFunction(boolean rmRows, long len, long brlen, long bclen) 
		{
			_rmRows = rmRows;
			_len = len;
			_brlen = brlen;
			_bclen = bclen;
		}

		@Override
		public Iterator> call(Tuple2> arg0)
			throws Exception 
		{
			//prepare inputs (for internal api compatibility)
			IndexedMatrixValue data = SparkUtils.toIndexedMatrixBlock(arg0._1(),arg0._2()._1());
			IndexedMatrixValue offsets = SparkUtils.toIndexedMatrixBlock(arg0._1(),arg0._2()._2());
			
			//execute remove empty operations
			ArrayList out = new ArrayList<>();
			LibMatrixReorg.rmempty(data, offsets, _rmRows, _len, _brlen, _bclen, out);

			//prepare and return outputs
			return SparkUtils.fromIndexedMatrixBlock(out).iterator();
		}
	}

	public static class RDDRemoveEmptyFunctionInMem implements PairFlatMapFunction,MatrixIndexes,MatrixBlock> 
	{
		private static final long serialVersionUID = 4906304771183325289L;

		private boolean _rmRows; 
		private long _len;
		private long _brlen;
		private long _bclen;
		
		private PartitionedBroadcast _off = null;
				
		public RDDRemoveEmptyFunctionInMem(boolean rmRows, long len, long brlen, long bclen, PartitionedBroadcast off) 
		{
			_rmRows = rmRows;
			_len = len;
			_brlen = brlen;
			_bclen = bclen;
			_off = off;
		}

		@Override
		public Iterator> call(Tuple2 arg0)
			throws Exception 
		{
			//prepare inputs (for internal api compatibility)
			IndexedMatrixValue data = SparkUtils.toIndexedMatrixBlock(arg0._1(),arg0._2());
			//IndexedMatrixValue offsets = SparkUtils.toIndexedMatrixBlock(arg0._1(),arg0._2()._2());
			IndexedMatrixValue offsets = null;
			if(_rmRows)
				offsets = SparkUtils.toIndexedMatrixBlock(arg0._1(), _off.getBlock((int)arg0._1().getRowIndex(), 1));
			else
				offsets = SparkUtils.toIndexedMatrixBlock(arg0._1(), _off.getBlock(1, (int)arg0._1().getColumnIndex()));
			
			//execute remove empty operations
			ArrayList out = new ArrayList<>();
			LibMatrixReorg.rmempty(data, offsets, _rmRows, _len, _brlen, _bclen, out);

			//prepare and return outputs
			return SparkUtils.fromIndexedMatrixBlock(out).iterator();
		}
	}

	public static class RDDRExpandFunction implements PairFlatMapFunction,MatrixIndexes,MatrixBlock> 
	{
		private static final long serialVersionUID = -6153643261956222601L;
		
		private double _maxVal;
		private boolean _dirRows;
		private boolean _cast;
		private boolean _ignore;
		private long _brlen;
		private long _bclen;
		
		public RDDRExpandFunction(double maxVal, boolean dirRows, boolean cast, boolean ignore, long brlen, long bclen) 
		{
			_maxVal = maxVal;
			_dirRows = dirRows;
			_cast = cast;
			_ignore = ignore;
			_brlen = brlen;
			_bclen = bclen;
		}

		@Override
		public Iterator> call(Tuple2 arg0)
			throws Exception 
		{
			//prepare inputs (for internal api compatibility)
			IndexedMatrixValue data = SparkUtils.toIndexedMatrixBlock(arg0._1(),arg0._2());
			
			//execute rexpand operations
			ArrayList out = new ArrayList<>();
			LibMatrixReorg.rexpand(data, _maxVal, _dirRows, _cast, _ignore, _brlen, _bclen, out);
			
			//prepare and return outputs
			return SparkUtils.fromIndexedMatrixBlock(out).iterator();
		}
	}
	
	public static class RDDMapGroupedAggFunction implements PairFlatMapFunction,MatrixIndexes,MatrixBlock> 
	{
		private static final long serialVersionUID = 6795402640178679851L;
		
		private PartitionedBroadcast _pbm = null;
		private Operator _op = null;
		private int _ngroups = -1;
		private int _brlen = -1;
		private int _bclen = -1;
		
		public RDDMapGroupedAggFunction(PartitionedBroadcast pbm, Operator op, int ngroups, int brlen, int bclen) 
		{
			_pbm = pbm;
			_op = op;
			_ngroups = ngroups;
			_brlen = brlen;
			_bclen = bclen;
		}

		@Override
		public Iterator> call(Tuple2 arg0)
			throws Exception 
		{
			//get all inputs
			MatrixIndexes ix = arg0._1();
			MatrixBlock target = arg0._2();		
			MatrixBlock groups = _pbm.getBlock((int)ix.getRowIndex(), 1);
			
			//execute map grouped aggregate operations
			IndexedMatrixValue in1 = SparkUtils.toIndexedMatrixBlock(ix, target);
			ArrayList outlist = new ArrayList<>();
			OperationsOnMatrixValues.performMapGroupedAggregate(_op, in1, groups, _ngroups, _brlen, _bclen, outlist);
			
			//output all result blocks
			return SparkUtils.fromIndexedMatrixBlock(outlist).iterator();
		}
	}

	/**
	 * Similar to RDDMapGroupedAggFunction but single output block.
	 */
	public static class RDDMapGroupedAggFunction2 implements Function,MatrixBlock> 
	{
		private static final long serialVersionUID = -6820599604299797661L;
		
		private PartitionedBroadcast _pbm = null;
		private Operator _op = null;
		private int _ngroups = -1;
		
		public RDDMapGroupedAggFunction2(PartitionedBroadcast pbm, Operator op, int ngroups) {
			_pbm = pbm;
			_op = op;
			_ngroups = ngroups;
		}

		@Override
		public MatrixBlock call(Tuple2 arg0)
			throws Exception 
		{
			//get all inputs
			MatrixIndexes ix = arg0._1();
			MatrixBlock target = arg0._2();		
			MatrixBlock groups = _pbm.getBlock((int)ix.getRowIndex(), 1);
			
			//execute map grouped aggregate operations
			return groups.groupedAggOperations(target, null, new MatrixBlock(), _ngroups, _op);
		}
	}

	/**
	 * Converts an aggregated {@link WeightedCell} of a grouped aggregate into the
	 * final output {@link MatrixCell}, depending on the aggregation operator.
	 */
	public static class CreateMatrixCell implements Function<WeightedCell,MatrixCell> 
	{
		private static final long serialVersionUID = -5783727852453040737L;
		
		int brlen; Operator op;
		public CreateMatrixCell(int brlen, Operator op) {
			this.brlen = brlen;
			this.op = op;
		}

		@Override
		public MatrixCell call(WeightedCell kv) 
			throws Exception 
		{
			double val = -1;
			if(op instanceof CMOperator)
			{
				AggregateOperationTypes agg=((CMOperator)op).aggOpType;
				switch(agg)
				{
				case COUNT:
					val = kv.getWeight();
					break;
				case MEAN:
					val = kv.getValue();
					break;
				case CM2:
				case CM3:
				case CM4:
				case VARIANCE:
					//normalize the aggregated value by the accumulated weight
					val = kv.getValue() / kv.getWeight();
					break;
				default:
					throw new DMLRuntimeException("Invalid aggregate in CM_CV_Object: " + agg);
				}
			}
			else
			{
				//note: no guard against a zero weight (yields Infinity/NaN)
				val = kv.getValue()/kv.getWeight();
			}
			
			return new MatrixCell(val);
		}
	}

	public static class RDDTransformApplyFunction implements PairFunction,Long,FrameBlock> 
	{
		private static final long serialVersionUID = 5759813006068230916L;
		
		private Broadcast _bencoder = null;
		private Broadcast _omap = null;
		
		public RDDTransformApplyFunction(Broadcast bencoder, Broadcast omap) {
			_bencoder = bencoder;
			_omap = omap;
		}

		@Override
		public Tuple2 call(Tuple2 in) 
			throws Exception 
		{
			long key = in._1();
			FrameBlock blk = in._2();
			
			//execute block transform apply
			Encoder encoder = _bencoder.getValue();
			MatrixBlock tmp = encoder.apply(blk, new MatrixBlock(blk.getNumRows(), blk.getNumColumns(), false));
			
			//remap keys
			if( _omap != null ) {
				key = _omap.getValue().getOffset(key);
			}
			
			//convert to frameblock to reuse frame-matrix reblock
			return new Tuple2<>(key, 
					DataConverter.convertToFrameBlock(tmp));
		}
	}

	/**
	 * Per frame block, computes the number of rows to remove for the omit
	 * transformation, returned as (block key, rmRows); collected into a
	 * TfOffsetMap by transformapply.
	 * NOTE(review): this class was damaged when the file was extracted (all text
	 * between '<' and '>' was lost): the type arguments of the implemented
	 * interface and of call(), and the body of the row loop, are missing. The
	 * exact per-row omit check cannot be recovered from here -- restore this
	 * class from the upstream source before compiling.
	 */
	public static class RDDTransformApplyOffsetFunction implements PairFunction,Long,Long> 
	{
		private static final long serialVersionUID = 3450977356721057440L;
		
		// column IDs of the omit specification, resolved against the column names
		private int[] _omitColList = null;
		
		public RDDTransformApplyOffsetFunction(String spec, String[] colnames) {
			try {
				_omitColList = TfMetaUtils.parseJsonIDList(spec, colnames, TfUtils.TXMETHOD_OMIT);
			} 
			catch (DMLRuntimeException e) {
				// rethrown unchecked, since the constructor declares no checked exceptions
				throw new RuntimeException(e);
			}
		}

		@Override
		public Tuple2 call(Tuple2 in) 
			throws Exception 
		{
			long key = in._1();
			long rmRows = 0;
			
			FrameBlock blk = in._2();
			
			// NOTE(review): truncated loop over the block rows (presumably counting
			// rows with a missing value in any omit column) -- see class comment
			for( int i=0; i(key, rmRows);
		}
	}

	public static class RDDTransformDecodeFunction implements PairFunction,Long,FrameBlock> 
	{
		private static final long serialVersionUID = -4797324742568170756L;
		
		private Decoder _decoder = null;
		private int _brlen = -1;
		
		public RDDTransformDecodeFunction(Decoder decoder, int brlen) {
			_decoder = decoder;
			_brlen = brlen;
		}

		@Override
		public Tuple2 call(Tuple2 in) 
			throws Exception 
		{
			long rix = UtilFunctions.computeCellIndex(in._1().getRowIndex(), _brlen, 0);
			FrameBlock fbout = _decoder.decode(in._2(), new FrameBlock(_decoder.getSchema()));
			fbout.setColumnNames(Arrays.copyOfRange(_decoder.getColnames(), 0, fbout.getNumColumns()));
			return new Tuple2<>(rix, fbout);
		}
	}
	
	public static class RDDTransformDecodeExpandFunction implements PairFunction,MatrixIndexes,MatrixBlock> 
	{
		private static final long serialVersionUID = -8187400248076127598L;
		
		private int _clen = -1;
		private int _bclen = -1;
		
		public RDDTransformDecodeExpandFunction(int clen, int bclen) {
			_clen = clen;
			_bclen = bclen;
		}

		@Override
		public Tuple2 call(Tuple2 in) 
			throws Exception 
		{
			MatrixIndexes inIx = in._1();
			MatrixBlock inBlk = in._2();
			
			//construct expanded block via leftindexing
			int cl = (int)UtilFunctions.computeCellIndex(inIx.getColumnIndex(), _bclen, 0)-1;
			int cu = (int)UtilFunctions.computeCellIndex(inIx.getColumnIndex(), _bclen, inBlk.getNumColumns()-1)-1;
			MatrixBlock out = new MatrixBlock(inBlk.getNumRows(), _clen, false);
			out = out.leftIndexingOperations(inBlk, 0, inBlk.getNumRows()-1, cl, cu, null, UpdateType.INPLACE_PINNED);
			
			return new Tuple2<>(new MatrixIndexes(inIx.getRowIndex(), 1), out);
		}
	}

	/**
	 * Sets the output characteristics of a grouped aggregate if they are not yet
	 * known. The output is in cell format, hence block sizes are set to -1.
	 *
	 * @param mc1 characteristics of the target input
	 * @param mcOut characteristics of the output (updated in place)
	 * @param out grouped aggregate output cells (only evaluated if the number of
	 *        groups is not given as a parameter)
	 * @throws DMLRuntimeException if the target dimensions are unknown
	 */
	public void setOutputCharacteristicsForGroupedAgg(MatrixCharacteristics mc1, MatrixCharacteristics mcOut, JavaPairRDD<MatrixIndexes,MatrixCell> out) 
		throws DMLRuntimeException 
	{
		if( mcOut.dimsKnown() )
			return; //nothing to do
		
		if( !mc1.dimsKnown() ) {
			throw new DMLRuntimeException("The output dimensions are not specified for grouped aggregate");
		}
		
		String ngroupsStr = params.get(Statement.GAGG_NUM_GROUPS);
		if ( ngroupsStr != null ) {
			int ngroups = (int) Double.parseDouble(ngroupsStr);
			mcOut.set(ngroups, mc1.getCols(), -1, -1); //grouped aggregate with cell output
		}
		else {
			//compute characteristics from the data; NOTE(review): the cached copy
			//is local to this method, the caller keeps its own (uncached) handle
			JavaPairRDD<MatrixIndexes,MatrixCell> cached = SparkUtils.cacheBinaryCellRDD(out);
			mcOut.set(SparkUtils.computeMatrixCharacteristics(cached));
			mcOut.setBlockSize(-1, -1); //grouped aggregate with cell output
		}
	}
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy