org.apache.sysml.runtime.instructions.spark.ParameterizedBuiltinSPInstruction Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of systemml Show documentation
Show all versions of systemml Show documentation
Declarative Machine Learning
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.sysml.runtime.instructions.spark;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Iterator;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.PairFlatMapFunction;
import org.apache.spark.api.java.function.PairFunction;
import org.apache.spark.broadcast.Broadcast;
import scala.Tuple2;
import org.apache.sysml.lops.Lop;
import org.apache.sysml.lops.PartialAggregate.CorrectionLocationType;
import org.apache.sysml.parser.Expression.ValueType;
import org.apache.sysml.parser.Statement;
import org.apache.sysml.runtime.DMLRuntimeException;
import org.apache.sysml.runtime.controlprogram.caching.FrameObject;
import org.apache.sysml.runtime.controlprogram.caching.MatrixObject.UpdateType;
import org.apache.sysml.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysml.runtime.controlprogram.context.SparkExecutionContext;
import org.apache.sysml.runtime.functionobjects.KahanPlus;
import org.apache.sysml.runtime.functionobjects.ParameterizedBuiltin;
import org.apache.sysml.runtime.functionobjects.ValueFunction;
import org.apache.sysml.runtime.instructions.InstructionUtils;
import org.apache.sysml.runtime.instructions.cp.CPOperand;
import org.apache.sysml.runtime.instructions.mr.GroupedAggregateInstruction;
import org.apache.sysml.runtime.instructions.spark.data.PartitionedBroadcast;
import org.apache.sysml.runtime.instructions.spark.functions.ExtractGroup.ExtractGroupBroadcast;
import org.apache.sysml.runtime.instructions.spark.functions.ExtractGroup.ExtractGroupJoin;
import org.apache.sysml.runtime.instructions.spark.functions.ExtractGroupNWeights;
import org.apache.sysml.runtime.instructions.spark.functions.PerformGroupByAggInCombiner;
import org.apache.sysml.runtime.instructions.spark.functions.PerformGroupByAggInReducer;
import org.apache.sysml.runtime.instructions.spark.functions.ReplicateVectorFunction;
import org.apache.sysml.runtime.instructions.spark.utils.FrameRDDConverterUtils;
import org.apache.sysml.runtime.instructions.spark.utils.RDDAggregateUtils;
import org.apache.sysml.runtime.instructions.spark.utils.SparkUtils;
import org.apache.sysml.runtime.matrix.MatrixCharacteristics;
import org.apache.sysml.runtime.matrix.data.FrameBlock;
import org.apache.sysml.runtime.matrix.data.InputInfo;
import org.apache.sysml.runtime.matrix.data.LibMatrixReorg;
import org.apache.sysml.runtime.matrix.data.MatrixBlock;
import org.apache.sysml.runtime.matrix.data.MatrixCell;
import org.apache.sysml.runtime.matrix.data.MatrixIndexes;
import org.apache.sysml.runtime.matrix.data.OperationsOnMatrixValues;
import org.apache.sysml.runtime.matrix.data.WeightedCell;
import org.apache.sysml.runtime.matrix.mapred.IndexedMatrixValue;
import org.apache.sysml.runtime.matrix.operators.AggregateOperator;
import org.apache.sysml.runtime.matrix.operators.CMOperator;
import org.apache.sysml.runtime.matrix.operators.CMOperator.AggregateOperationTypes;
import org.apache.sysml.runtime.matrix.operators.Operator;
import org.apache.sysml.runtime.matrix.operators.SimpleOperator;
import org.apache.sysml.runtime.transform.TfUtils;
import org.apache.sysml.runtime.transform.decode.Decoder;
import org.apache.sysml.runtime.transform.decode.DecoderFactory;
import org.apache.sysml.runtime.transform.encode.Encoder;
import org.apache.sysml.runtime.transform.encode.EncoderFactory;
import org.apache.sysml.runtime.transform.meta.TfMetaUtils;
import org.apache.sysml.runtime.transform.meta.TfOffsetMap;
import org.apache.sysml.runtime.util.DataConverter;
import org.apache.sysml.runtime.util.UtilFunctions;
public class ParameterizedBuiltinSPInstruction extends ComputationSPInstruction {
/** Named instruction parameters (parameter name -> parameter value). */
protected HashMap<String, String> params;

// removeEmpty-specific attribute: if true, the offset vector of "rmempty" is
// consumed as a broadcast instead of being joined as an RDD
private boolean _bRmEmptyBC = false;
/**
 * Creates a parameterized builtin Spark instruction. Instances are only
 * created through {@link #parseInstruction(String)}.
 *
 * @param op         operator executed by this instruction
 * @param paramsMap  named instruction parameters (name -> value)
 * @param out        output operand
 * @param opcode     instruction opcode
 * @param istr       instruction string this instance was parsed from
 * @param bRmEmptyBC removeEmpty only: consume the offset vector as a broadcast
 */
private ParameterizedBuiltinSPInstruction(Operator op, HashMap<String, String> paramsMap, CPOperand out,
        String opcode, String istr, boolean bRmEmptyBC) {
    super(op, null, null, out, opcode, istr);
    _sptype = SPINSTRUCTION_TYPE.ParameterizedBuiltin;
    params = paramsMap;
    _bRmEmptyBC = bRmEmptyBC;
}
/**
 * Returns the named parameters of this instruction.
 *
 * @return the internal parameter map (name -> value); not a copy
 */
public HashMap<String, String> getParams() {
    return params;
}
/**
 * Builds the name/value parameter map from tokenized instruction parts.
 * The first element (opcode) and the last element (output operand) are
 * skipped; every element in between must have the form
 * {@code name<separator>value}, where the separator is
 * {@code Lop.NAME_VALUE_SEPARATOR}.
 *
 * @param params tokenized instruction parts
 * @return map from parameter name to parameter value
 * @throws IllegalArgumentException if a parameter does not contain the separator
 */
public static HashMap<String, String> constructParameterMap(String[] params) {
    HashMap<String, String> paramMap = new HashMap<>();

    // process all elements in "params" except first(opcode) and last(output)
    for( int i = 1; i <= params.length - 2; i++ ) {
        // limit 2: split at the first separator only, so a value that itself
        // contains the separator is kept whole and an empty value is preserved
        // instead of being dropped by split's trailing-empty-string removal
        String[] parts = params[i].split(Lop.NAME_VALUE_SEPARATOR, 2);
        if( parts.length < 2 ) {
            throw new IllegalArgumentException("Invalid instruction parameter '" + params[i]
                + "': expected form name" + Lop.NAME_VALUE_SEPARATOR + "value.");
        }
        paramMap.put(parts[0], parts[1]);
    }

    return paramMap;
}
/**
 * Parses an instruction string into a {@code ParameterizedBuiltinSPInstruction}.
 * <p>
 * "mapgroupedagg" uses a positional layout (target, groups, output, number of
 * groups). All other supported opcodes (groupedagg, rmempty, rexpand, replace,
 * transformapply, transformdecode) carry name/value parameters between the
 * opcode and the trailing output operand.
 *
 * @param str instruction string
 * @return the parsed instruction
 * @throws DMLRuntimeException if the opcode is unknown or a mandatory
 *         groupedagg parameter ("fn", or "order" for centralmoment) is missing
 */
public static ParameterizedBuiltinSPInstruction parseInstruction ( String str )
    throws DMLRuntimeException
{
    String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
    // first part is always the opcode
    String opcode = parts[0];

    if( opcode.equalsIgnoreCase("mapgroupedagg") )
    {
        CPOperand target = new CPOperand( parts[1] );
        CPOperand groups = new CPOperand( parts[2] );
        CPOperand out = new CPOperand( parts[3] );

        HashMap<String, String> paramsMap = new HashMap<>();
        paramsMap.put(Statement.GAGG_TARGET, target.getName());
        paramsMap.put(Statement.GAGG_GROUPS, groups.getName());
        paramsMap.put(Statement.GAGG_NUM_GROUPS, parts[4]);

        Operator op = new AggregateOperator(0, KahanPlus.getKahanPlusFnObject(), true, CorrectionLocationType.LASTCOLUMN);
        return new ParameterizedBuiltinSPInstruction(op, paramsMap, out, opcode, str, false);
    }

    // last part is always the output
    CPOperand out = new CPOperand( parts[parts.length-1] );
    // process remaining parts and build a hash map
    HashMap<String, String> paramsMap = constructParameterMap(parts);

    if( opcode.equalsIgnoreCase("groupedagg") )
    {
        // check for mandatory arguments
        String fnStr = paramsMap.get("fn");
        if( fnStr == null )
            throw new DMLRuntimeException("Function parameter is missing in groupedAggregate.");
        if( fnStr.equalsIgnoreCase("centralmoment") && paramsMap.get("order") == null )
            throw new DMLRuntimeException("Mandatory \"order\" must be specified when fn=\"centralmoment\" in groupedAggregate.");

        Operator op = GroupedAggregateInstruction.parseGroupedAggOperator(fnStr, paramsMap.get("order"));
        return new ParameterizedBuiltinSPInstruction(op, paramsMap, out, opcode, str, false);
    }

    if( opcode.equalsIgnoreCase("rmempty") )
    {
        // NOTE(review): the broadcast flag is read positionally from parts[5] and only
        // when more than 6 parts exist; the instruction layout is defined by the
        // generating lop, which is not visible here - confirm before changing
        boolean bRmEmptyBC = false;
        if( parts.length > 6 )
            bRmEmptyBC = Boolean.parseBoolean(parts[5]);

        ValueFunction func = ParameterizedBuiltin.getParameterizedBuiltinFnObject(opcode);
        return new ParameterizedBuiltinSPInstruction(new SimpleOperator(func), paramsMap, out, opcode, str, bRmEmptyBC);
    }

    if( opcode.equalsIgnoreCase("rexpand")
        || opcode.equalsIgnoreCase("replace")
        || opcode.equalsIgnoreCase("transformapply")
        || opcode.equalsIgnoreCase("transformdecode") )
    {
        ValueFunction func = ParameterizedBuiltin.getParameterizedBuiltinFnObject(opcode);
        return new ParameterizedBuiltinSPInstruction(new SimpleOperator(func), paramsMap, out, opcode, str, false);
    }

    throw new DMLRuntimeException("Unknown opcode (" + opcode + ") for ParameterizedBuiltin Instruction.");
}
/**
 * Executes this instruction on Spark by dispatching on the opcode to one
 * private handler per operation (mapgroupedagg, groupedagg, rmempty, replace,
 * rexpand, transformapply, transformdecode).
 *
 * @param ec execution context; must be a {@link SparkExecutionContext}
 * @throws DMLRuntimeException on dimension mismatches, unknown block sizes,
 *         or an unknown opcode
 */
@Override
public void processInstruction(ExecutionContext ec)
    throws DMLRuntimeException
{
    SparkExecutionContext sec = (SparkExecutionContext)ec;
    String opcode = getOpcode();

    //opcode guaranteed to be a valid opcode (see parsing)
    if( opcode.equalsIgnoreCase("mapgroupedagg") )
        processMapGroupedAgg(sec);
    else if( opcode.equalsIgnoreCase("groupedagg") )
        processGroupedAgg(sec);
    else if( opcode.equalsIgnoreCase("rmempty") )
        processRmEmpty(sec);
    else if( opcode.equalsIgnoreCase("replace") )
        processReplace(sec);
    else if( opcode.equalsIgnoreCase("rexpand") )
        processRExpand(sec);
    else if( opcode.equalsIgnoreCase("transformapply") )
        processTransformApply(sec);
    else if( opcode.equalsIgnoreCase("transformdecode") )
        processTransformDecode(sec);
    else
        throw new DMLRuntimeException("Unknown parameterized builtin opcode: "+opcode);
}

/** Map-side grouped aggregate with broadcast groups (single-block or multi-block output). */
private void processMapGroupedAgg(SparkExecutionContext sec)
    throws DMLRuntimeException
{
    //get input rdd handle and groups broadcast
    String targetVar = params.get(Statement.GAGG_TARGET);
    String groupsVar = params.get(Statement.GAGG_GROUPS);
    JavaPairRDD<MatrixIndexes,MatrixBlock> target = sec.getBinaryBlockRDDHandleForVariable(targetVar);
    PartitionedBroadcast<MatrixBlock> groups = sec.getBroadcastForVariable(groupsVar);
    MatrixCharacteristics mc1 = sec.getMatrixCharacteristics( targetVar );
    MatrixCharacteristics mcOut = sec.getMatrixCharacteristics(output.getName());
    CPOperand ngrpOp = new CPOperand(params.get(Statement.GAGG_NUM_GROUPS));
    int ngroups = (int)sec.getScalarInput(ngrpOp.getName(), ngrpOp.getValueType(), ngrpOp.isLiteral()).getLongValue();

    //single-block aggregation
    if( ngroups <= mc1.getRowsPerBlock() && mc1.getCols() <= mc1.getColsPerBlock() ) {
        //execute map grouped aggregate
        JavaRDD<MatrixBlock> out = target.map(new RDDMapGroupedAggFunction2(groups, _optr, ngroups));
        MatrixBlock out2 = RDDAggregateUtils.sumStable(out);

        //put output block into symbol table (no lineage because single block)
        //this also includes implicit maintenance of matrix characteristics
        sec.setMatrixOutput(output.getName(), out2, getExtendedOpcode());
    }
    //multi-block aggregation
    else {
        //execute map grouped aggregate
        JavaPairRDD<MatrixIndexes,MatrixBlock> out =
            target.flatMapToPair(new RDDMapGroupedAggFunction(groups, _optr,
                ngroups, mc1.getRowsPerBlock(), mc1.getColsPerBlock()));
        out = RDDAggregateUtils.sumByKeyStable(out, false);

        //updated characteristics and handle outputs
        mcOut.set(ngroups, mc1.getCols(), mc1.getRowsPerBlock(), mc1.getColsPerBlock(), -1);
        sec.setRDDHandleForVariable(output.getName(), out);
        sec.addLineageRDD( output.getName(), targetVar );
        sec.addLineageBroadcast( output.getName(), groupsVar );
    }
}

/** General grouped aggregate (optionally weighted, groups as RDD or broadcast) with cell output. */
private void processGroupedAgg(SparkExecutionContext sec)
    throws DMLRuntimeException
{
    boolean broadcastGroups = Boolean.parseBoolean(params.get("broadcast"));

    //get input rdd handles
    String targetVar = params.get(Statement.GAGG_TARGET);
    String groupsVar = params.get(Statement.GAGG_GROUPS);
    String weightsVar = params.get(Statement.GAGG_WEIGHTS);
    JavaPairRDD<MatrixIndexes,MatrixBlock> target = sec.getBinaryBlockRDDHandleForVariable( targetVar );
    JavaPairRDD<MatrixIndexes,MatrixBlock> groups = broadcastGroups ? null : sec.getBinaryBlockRDDHandleForVariable( groupsVar );

    MatrixCharacteristics mc1 = sec.getMatrixCharacteristics( targetVar );
    MatrixCharacteristics mc2 = sec.getMatrixCharacteristics( groupsVar );
    if(mc1.dimsKnown() && mc2.dimsKnown() && (mc1.getRows() != mc2.getRows() || mc2.getCols() !=1)) {
        throw new DMLRuntimeException("Grouped Aggregate dimension mismatch between target and groups.");
    }
    MatrixCharacteristics mcOut = sec.getMatrixCharacteristics(output.getName());

    JavaPairRDD<MatrixIndexes, WeightedCell> groupWeightedCells = null;

    // Step 1: First extract groupWeightedCells from group, target and weights
    if ( weightsVar != null ) {
        JavaPairRDD<MatrixIndexes,MatrixBlock> weights = sec.getBinaryBlockRDDHandleForVariable( weightsVar );
        MatrixCharacteristics mc3 = sec.getMatrixCharacteristics( weightsVar );
        if(mc1.dimsKnown() && mc3.dimsKnown() && (mc1.getRows() != mc3.getRows() || mc1.getCols() != mc3.getCols())) {
            throw new DMLRuntimeException("Grouped Aggregate dimension mismatch between target, groups, and weights.");
        }

        // NOTE(review): 'groups' is null when broadcast=true, so weights combined with
        // broadcast groups would fail here; presumably never generated - confirm
        groupWeightedCells = groups.join(target).join(weights)
            .flatMapToPair(new ExtractGroupNWeights());
    }
    else //input vector or matrix
    {
        String ngroupsStr = params.get(Statement.GAGG_NUM_GROUPS);
        long ngroups = (ngroupsStr != null) ? (long) Double.parseDouble(ngroupsStr) : -1;

        //execute basic grouped aggregate (extract and preagg)
        if( broadcastGroups ) {
            PartitionedBroadcast<MatrixBlock> pbm = sec.getBroadcastForVariable(groupsVar);
            groupWeightedCells = target
                .flatMapToPair(new ExtractGroupBroadcast(pbm, mc1.getColsPerBlock(), ngroups, _optr));
        }
        else { //general case
            //replicate groups if necessary
            if( mc1.getNumColBlocks() > 1 ) {
                groups = groups.flatMapToPair(
                    new ReplicateVectorFunction(false, mc1.getNumColBlocks() ));
            }

            groupWeightedCells = groups.join(target)
                .flatMapToPair(new ExtractGroupJoin(mc1.getColsPerBlock(), ngroups, _optr));
        }
    }

    // Step 2: Make sure we have brlen required while creating the output cells
    if(mc1.getRowsPerBlock() == -1) {
        throw new DMLRuntimeException("The block sizes are not specified for grouped aggregate");
    }
    int brlen = mc1.getRowsPerBlock();

    // Step 3: Now perform grouped aggregate operation (either on combiner side or reducer side)
    JavaPairRDD<MatrixIndexes, MatrixCell> out = null;
    if(_optr instanceof CMOperator && ((CMOperator) _optr).isPartialAggregateOperator()
        || _optr instanceof AggregateOperator ) {
        out = groupWeightedCells.reduceByKey(new PerformGroupByAggInCombiner(_optr))
            .mapValues(new CreateMatrixCell(brlen, _optr));
    }
    else {
        // Use groupby key because partial aggregation is not supported
        out = groupWeightedCells.groupByKey()
            .mapValues(new PerformGroupByAggInReducer(_optr))
            .mapValues(new CreateMatrixCell(brlen, _optr));
    }

    // Step 4: Set output characteristics and rdd handle
    setOutputCharacteristicsForGroupedAgg(mc1, mcOut, out);

    //store output rdd handle
    sec.setRDDHandleForVariable(output.getName(), out);
    sec.addLineageRDD( output.getName(), targetVar );
    sec.addLineage( output.getName(), groupsVar, broadcastGroups );
    if ( weightsVar != null ) {
        sec.addLineageRDD(output.getName(), weightsVar );
    }
}

/** Remove-empty rows/columns, using either a joined or a broadcast offset vector. */
private void processRmEmpty(SparkExecutionContext sec)
    throws DMLRuntimeException
{
    String rddInVar = params.get("target");
    String rddOffVar = params.get("offset");

    boolean rows = sec.getScalarInput(params.get("margin"), ValueType.STRING, true).getStringValue().equals("rows");
    long maxDim = sec.getScalarInput(params.get("maxdim"), ValueType.DOUBLE, false).getLongValue();
    MatrixCharacteristics mcIn = sec.getMatrixCharacteristics(rddInVar);

    if( maxDim > 0 ) //default case
    {
        //get input rdd handle
        JavaPairRDD<MatrixIndexes,MatrixBlock> in = sec.getBinaryBlockRDDHandleForVariable( rddInVar );
        long brlen = mcIn.getRowsPerBlock();
        long bclen = mcIn.getColsPerBlock();
        //number of replicas of the offset vector needed along the non-removed dimension
        long numRep = (long)Math.ceil( rows ? (double)mcIn.getCols()/bclen : (double)mcIn.getRows()/brlen);

        //execute remove empty rows/cols operation
        JavaPairRDD<MatrixIndexes,MatrixBlock> out;
        if(_bRmEmptyBC){
            // Broadcast offset vector
            PartitionedBroadcast<MatrixBlock> broadcastOff = sec.getBroadcastForVariable( rddOffVar );
            out = in
                .flatMapToPair(new RDDRemoveEmptyFunctionInMem(rows, maxDim, brlen, bclen, broadcastOff));
        }
        else {
            JavaPairRDD<MatrixIndexes,MatrixBlock> off = sec.getBinaryBlockRDDHandleForVariable( rddOffVar );
            out = in
                .join( off.flatMapToPair(new ReplicateVectorFunction(!rows,numRep)) )
                .flatMapToPair(new RDDRemoveEmptyFunction(rows, maxDim, brlen, bclen));
        }
        out = RDDAggregateUtils.mergeByKey(out, false);

        //store output rdd handle
        sec.setRDDHandleForVariable(output.getName(), out);
        sec.addLineageRDD(output.getName(), rddInVar);
        if(!_bRmEmptyBC)
            sec.addLineageRDD(output.getName(), rddOffVar);
        else
            sec.addLineageBroadcast(output.getName(), rddOffVar);

        //update output statistics (required for correctness)
        MatrixCharacteristics mcOut = sec.getMatrixCharacteristics(output.getName());
        mcOut.set(rows?maxDim:mcIn.getRows(), rows?mcIn.getCols():maxDim, (int)brlen, (int)bclen, mcIn.getNonZeros());
    }
    else //special case: empty output (ensure valid dims)
    {
        MatrixBlock out = new MatrixBlock(rows?1:(int)mcIn.getRows(), rows?(int)mcIn.getCols():1, true);
        sec.setMatrixOutput(output.getName(), out, getExtendedOpcode());
    }
}

/** Block-wise replacement of a pattern value by a replacement value. */
private void processReplace(SparkExecutionContext sec)
    throws DMLRuntimeException
{
    //get input rdd handle
    String rddVar = params.get("target");
    JavaPairRDD<MatrixIndexes,MatrixBlock> in1 = sec.getBinaryBlockRDDHandleForVariable( rddVar );
    MatrixCharacteristics mcIn = sec.getMatrixCharacteristics(rddVar);

    //execute replace operation
    double pattern = Double.parseDouble( params.get("pattern") );
    double replacement = Double.parseDouble( params.get("replacement") );
    JavaPairRDD<MatrixIndexes,MatrixBlock> out =
        in1.mapValues(new RDDReplaceFunction(pattern, replacement));

    //store output rdd handle
    sec.setRDDHandleForVariable(output.getName(), out);
    sec.addLineageRDD(output.getName(), rddVar);

    //update output statistics (required for correctness); the number of non-zeros
    //is only carried over if neither pattern nor replacement is zero
    MatrixCharacteristics mcOut = sec.getMatrixCharacteristics(output.getName());
    mcOut.set(mcIn.getRows(), mcIn.getCols(), mcIn.getRowsPerBlock(), mcIn.getColsPerBlock(), (pattern!=0 && replacement!=0)?mcIn.getNonZeros():-1);
}

/** Expansion of an input vector into rows or columns up to the given max value. */
private void processRExpand(SparkExecutionContext sec)
    throws DMLRuntimeException
{
    String rddInVar = params.get("target");

    //get input rdd handle
    JavaPairRDD<MatrixIndexes,MatrixBlock> in = sec.getBinaryBlockRDDHandleForVariable( rddInVar );
    MatrixCharacteristics mcIn = sec.getMatrixCharacteristics(rddInVar);
    double maxVal = Double.parseDouble( params.get("max") );
    long lmaxVal = UtilFunctions.toLong(maxVal);
    boolean dirRows = params.get("dir").equals("rows");
    boolean cast = Boolean.parseBoolean(params.get("cast"));
    boolean ignore = Boolean.parseBoolean(params.get("ignore"));
    long brlen = mcIn.getRowsPerBlock();
    long bclen = mcIn.getColsPerBlock();

    //repartition input vector for higher degree of parallelism
    //(avoid scenarios where few input partitions create huge outputs)
    MatrixCharacteristics mcTmp = new MatrixCharacteristics(dirRows?lmaxVal:mcIn.getRows(),
        dirRows?mcIn.getRows():lmaxVal, (int)brlen, (int)bclen, mcIn.getRows());
    int numParts = (int)Math.min(SparkUtils.getNumPreferredPartitions(mcTmp, in), mcIn.getNumBlocks());
    if( numParts > in.getNumPartitions()*2 )
        in = in.repartition(numParts);

    //execute rexpand rows/cols operation (no shuffle required because outputs are
    //block-aligned with the input, i.e., one input block generates n output blocks)
    JavaPairRDD<MatrixIndexes,MatrixBlock> out = in
        .flatMapToPair(new RDDRExpandFunction(maxVal, dirRows, cast, ignore, brlen, bclen));

    //store output rdd handle
    sec.setRDDHandleForVariable(output.getName(), out);
    sec.addLineageRDD(output.getName(), rddInVar);

    //update output statistics (required for correctness)
    MatrixCharacteristics mcOut = sec.getMatrixCharacteristics(output.getName());
    mcOut.set(dirRows?lmaxVal:mcIn.getRows(), dirRows?mcIn.getRows():lmaxVal, (int)brlen, (int)bclen, -1);
}

/** Applies a transform encoder (built from spec and meta frame) to an input frame, producing a matrix. */
@SuppressWarnings("unchecked") //cast of the untyped frame RDD handle
private void processTransformApply(SparkExecutionContext sec)
    throws DMLRuntimeException
{
    //get input RDD and meta data
    FrameObject fo = sec.getFrameObject(params.get("target"));
    JavaPairRDD<Long,FrameBlock> in = (JavaPairRDD<Long,FrameBlock>)
        sec.getRDDHandleForFrameObject(fo, InputInfo.BinaryBlockInputInfo);
    FrameBlock meta = sec.getFrameInput(params.get("meta"));
    MatrixCharacteristics mcIn = sec.getMatrixCharacteristics(params.get("target"));
    MatrixCharacteristics mcOut = sec.getMatrixCharacteristics(output.getName());
    //column names are only needed for name-based specs; they are read from the block with key 1
    String[] colnames = !TfMetaUtils.isIDSpec(params.get("spec")) ?
        in.lookup(1L).get(0).getColumnNames() : null;

    //compute omit offset map for block shifts
    TfOffsetMap omap = null;
    if( TfMetaUtils.containsOmitSpec(params.get("spec"), colnames) ) {
        omap = new TfOffsetMap(SparkUtils.toIndexedLong(in.mapToPair(
            new RDDTransformApplyOffsetFunction(params.get("spec"), colnames)).collect()));
    }

    //create encoder broadcast (avoiding replication per task)
    Encoder encoder = EncoderFactory.createEncoder(params.get("spec"), colnames,
        fo.getSchema(), (int)fo.getNumColumns(), meta);
    mcOut.setDimension(mcIn.getRows()-((omap!=null)?omap.getNumRmRows():0), encoder.getNumCols());
    Broadcast<Encoder> bmeta = sec.getSparkContext().broadcast(encoder);
    Broadcast<TfOffsetMap> bomap = (omap!=null) ? sec.getSparkContext().broadcast(omap) : null;

    //execute transform apply
    JavaPairRDD<Long,FrameBlock> tmp = in
        .mapToPair(new RDDTransformApplyFunction(bmeta, bomap));
    JavaPairRDD<MatrixIndexes,MatrixBlock> out = FrameRDDConverterUtils
        .binaryBlockToMatrixBlock(tmp, mcOut, mcOut);

    //set output and maintain lineage/output characteristics
    sec.setRDDHandleForVariable(output.getName(), out);
    sec.addLineageRDD(output.getName(), params.get("target"));
    sec.releaseFrameInput(params.get("meta"));
}

/** Decodes an encoded matrix back into a frame using a decoder built from spec and meta frame. */
private void processTransformDecode(SparkExecutionContext sec)
    throws DMLRuntimeException
{
    //get input RDD and meta data
    JavaPairRDD<MatrixIndexes,MatrixBlock> in = sec.getBinaryBlockRDDHandleForVariable(params.get("target"));
    MatrixCharacteristics mc = sec.getMatrixCharacteristics(params.get("target"));
    FrameBlock meta = sec.getFrameInput(params.get("meta"));
    String[] colnames = meta.getColumnNames();

    //reblock if necessary (clen > bclen): only inputs spanning multiple column
    //blocks need to be expanded and merged into full-width row blocks; the guard
    //previously compared against the number of column blocks, which triggered the
    //expand and merge shuffle even for inputs with a single column block
    if( mc.getCols() > mc.getColsPerBlock() ) {
        in = in.mapToPair(new RDDTransformDecodeExpandFunction(
            (int)mc.getCols(), mc.getColsPerBlock()));
        in = RDDAggregateUtils.mergeByKey(in, false);
    }

    //construct decoder and decode individual matrix blocks
    Decoder decoder = DecoderFactory.createDecoder(params.get("spec"), colnames, null, meta);
    JavaPairRDD<Long,FrameBlock> out = in.mapToPair(
        new RDDTransformDecodeFunction(decoder, mc.getRowsPerBlock()));

    //set output and maintain lineage/output characteristics
    sec.setRDDHandleForVariable(output.getName(), out);
    sec.addLineageRDD(output.getName(), params.get("target"));
    sec.releaseFrameInput(params.get("meta"));
    sec.getMatrixCharacteristics(output.getName()).set(mc.getRows(),
        meta.getNumColumns(), mc.getRowsPerBlock(), mc.getColsPerBlock(), -1);
    sec.getFrameObject(output.getName()).setSchema(decoder.getSchema());
}
/**
 * Spark function that applies {@code MatrixBlock.replaceOperations} with a
 * fixed pattern and replacement value to each matrix block.
 */
public static class RDDReplaceFunction implements Function<MatrixBlock, MatrixBlock>
{
    private static final long serialVersionUID = 6576713401901671659L;

    private final double _pattern;
    private final double _replacement;

    /**
     * @param pattern     value to be replaced
     * @param replacement value to insert instead
     */
    public RDDReplaceFunction(double pattern, double replacement)
    {
        _pattern = pattern;
        _replacement = replacement;
    }

    @Override
    public MatrixBlock call(MatrixBlock arg0)
        throws Exception
    {
        //result is written into a fresh block; the input block is passed as receiver only
        return (MatrixBlock) arg0.replaceOperations(new MatrixBlock(), _pattern, _replacement);
    }
}
public static class RDDRemoveEmptyFunction implements PairFlatMapFunction>,MatrixIndexes,MatrixBlock>
{
private static final long serialVersionUID = 4906304771183325289L;
private boolean _rmRows;
private long _len;
private long _brlen;
private long _bclen;
public RDDRemoveEmptyFunction(boolean rmRows, long len, long brlen, long bclen)
{
_rmRows = rmRows;
_len = len;
_brlen = brlen;
_bclen = bclen;
}
@Override
public Iterator> call(Tuple2> arg0)
throws Exception
{
//prepare inputs (for internal api compatibility)
IndexedMatrixValue data = SparkUtils.toIndexedMatrixBlock(arg0._1(),arg0._2()._1());
IndexedMatrixValue offsets = SparkUtils.toIndexedMatrixBlock(arg0._1(),arg0._2()._2());
//execute remove empty operations
ArrayList out = new ArrayList<>();
LibMatrixReorg.rmempty(data, offsets, _rmRows, _len, _brlen, _bclen, out);
//prepare and return outputs
return SparkUtils.fromIndexedMatrixBlock(out).iterator();
}
}
public static class RDDRemoveEmptyFunctionInMem implements PairFlatMapFunction,MatrixIndexes,MatrixBlock>
{
private static final long serialVersionUID = 4906304771183325289L;
private boolean _rmRows;
private long _len;
private long _brlen;
private long _bclen;
private PartitionedBroadcast _off = null;
public RDDRemoveEmptyFunctionInMem(boolean rmRows, long len, long brlen, long bclen, PartitionedBroadcast off)
{
_rmRows = rmRows;
_len = len;
_brlen = brlen;
_bclen = bclen;
_off = off;
}
@Override
public Iterator> call(Tuple2 arg0)
throws Exception
{
//prepare inputs (for internal api compatibility)
IndexedMatrixValue data = SparkUtils.toIndexedMatrixBlock(arg0._1(),arg0._2());
//IndexedMatrixValue offsets = SparkUtils.toIndexedMatrixBlock(arg0._1(),arg0._2()._2());
IndexedMatrixValue offsets = null;
if(_rmRows)
offsets = SparkUtils.toIndexedMatrixBlock(arg0._1(), _off.getBlock((int)arg0._1().getRowIndex(), 1));
else
offsets = SparkUtils.toIndexedMatrixBlock(arg0._1(), _off.getBlock(1, (int)arg0._1().getColumnIndex()));
//execute remove empty operations
ArrayList out = new ArrayList<>();
LibMatrixReorg.rmempty(data, offsets, _rmRows, _len, _brlen, _bclen, out);
//prepare and return outputs
return SparkUtils.fromIndexedMatrixBlock(out).iterator();
}
}
public static class RDDRExpandFunction implements PairFlatMapFunction,MatrixIndexes,MatrixBlock>
{
private static final long serialVersionUID = -6153643261956222601L;
private double _maxVal;
private boolean _dirRows;
private boolean _cast;
private boolean _ignore;
private long _brlen;
private long _bclen;
public RDDRExpandFunction(double maxVal, boolean dirRows, boolean cast, boolean ignore, long brlen, long bclen)
{
_maxVal = maxVal;
_dirRows = dirRows;
_cast = cast;
_ignore = ignore;
_brlen = brlen;
_bclen = bclen;
}
@Override
public Iterator> call(Tuple2 arg0)
throws Exception
{
//prepare inputs (for internal api compatibility)
IndexedMatrixValue data = SparkUtils.toIndexedMatrixBlock(arg0._1(),arg0._2());
//execute rexpand operations
ArrayList out = new ArrayList<>();
LibMatrixReorg.rexpand(data, _maxVal, _dirRows, _cast, _ignore, _brlen, _bclen, out);
//prepare and return outputs
return SparkUtils.fromIndexedMatrixBlock(out).iterator();
}
}
public static class RDDMapGroupedAggFunction implements PairFlatMapFunction,MatrixIndexes,MatrixBlock>
{
private static final long serialVersionUID = 6795402640178679851L;
private PartitionedBroadcast _pbm = null;
private Operator _op = null;
private int _ngroups = -1;
private int _brlen = -1;
private int _bclen = -1;
public RDDMapGroupedAggFunction(PartitionedBroadcast pbm, Operator op, int ngroups, int brlen, int bclen)
{
_pbm = pbm;
_op = op;
_ngroups = ngroups;
_brlen = brlen;
_bclen = bclen;
}
@Override
public Iterator> call(Tuple2 arg0)
throws Exception
{
//get all inputs
MatrixIndexes ix = arg0._1();
MatrixBlock target = arg0._2();
MatrixBlock groups = _pbm.getBlock((int)ix.getRowIndex(), 1);
//execute map grouped aggregate operations
IndexedMatrixValue in1 = SparkUtils.toIndexedMatrixBlock(ix, target);
ArrayList outlist = new ArrayList<>();
OperationsOnMatrixValues.performMapGroupedAggregate(_op, in1, groups, _ngroups, _brlen, _bclen, outlist);
//output all result blocks
return SparkUtils.fromIndexedMatrixBlock(outlist).iterator();
}
}
/**
* Similar to RDDMapGroupedAggFunction but single output block.
*/
public static class RDDMapGroupedAggFunction2 implements Function,MatrixBlock>
{
private static final long serialVersionUID = -6820599604299797661L;
private PartitionedBroadcast _pbm = null;
private Operator _op = null;
private int _ngroups = -1;
public RDDMapGroupedAggFunction2(PartitionedBroadcast pbm, Operator op, int ngroups) {
_pbm = pbm;
_op = op;
_ngroups = ngroups;
}
@Override
public MatrixBlock call(Tuple2 arg0)
throws Exception
{
//get all inputs
MatrixIndexes ix = arg0._1();
MatrixBlock target = arg0._2();
MatrixBlock groups = _pbm.getBlock((int)ix.getRowIndex(), 1);
//execute map grouped aggregate operations
return groups.groupedAggOperations(target, null, new MatrixBlock(), _ngroups, _op);
}
}
/**
 * Converts an aggregated {@code WeightedCell} into the final output
 * {@code MatrixCell}, depending on the aggregation operator.
 */
public static class CreateMatrixCell implements Function<WeightedCell, MatrixCell>
{
    private static final long serialVersionUID = -5783727852453040737L;

    //NOTE: brlen is stored but not read by call()
    int brlen; Operator op;

    /**
     * @param brlen number of rows per block
     * @param op    aggregation operator that produced the weighted cells
     */
    public CreateMatrixCell(int brlen, Operator op) {
        this.brlen = brlen;
        this.op = op;
    }

    @Override
    public MatrixCell call(WeightedCell kv)
        throws Exception
    {
        double val;
        if( op instanceof CMOperator )
        {
            AggregateOperationTypes agg = ((CMOperator)op).aggOpType;
            switch( agg )
            {
                case COUNT:
                    val = kv.getWeight();
                    break;
                case MEAN:
                    val = kv.getValue();
                    break;
                case CM2:
                case CM3:
                case CM4:
                case VARIANCE:
                    val = kv.getValue()/kv.getWeight();
                    break;
                default:
                    throw new DMLRuntimeException("Invalid aggreagte in CM_CV_Object: " + agg);
            }
        }
        else
        {
            //NOTE(review): there is no guard against a zero weight here; assuming
            //double-valued accessors this yields NaN/Infinity rather than an exception
            val = kv.getValue()/kv.getWeight();
        }

        return new MatrixCell(val);
    }
}
public static class RDDTransformApplyFunction implements PairFunction,Long,FrameBlock>
{
private static final long serialVersionUID = 5759813006068230916L;
private Broadcast _bencoder = null;
private Broadcast _omap = null;
public RDDTransformApplyFunction(Broadcast bencoder, Broadcast omap) {
_bencoder = bencoder;
_omap = omap;
}
@Override
public Tuple2 call(Tuple2 in)
throws Exception
{
long key = in._1();
FrameBlock blk = in._2();
//execute block transform apply
Encoder encoder = _bencoder.getValue();
MatrixBlock tmp = encoder.apply(blk, new MatrixBlock(blk.getNumRows(), blk.getNumColumns(), false));
//remap keys
if( _omap != null ) {
key = _omap.getValue().getOffset(key);
}
//convert to frameblock to reuse frame-matrix reblock
return new Tuple2<>(key,
DataConverter.convertToFrameBlock(tmp));
}
}
public static class RDDTransformApplyOffsetFunction implements PairFunction,Long,Long>
{
private static final long serialVersionUID = 3450977356721057440L;
private int[] _omitColList = null;
public RDDTransformApplyOffsetFunction(String spec, String[] colnames) {
try {
_omitColList = TfMetaUtils.parseJsonIDList(spec, colnames, TfUtils.TXMETHOD_OMIT);
}
catch (DMLRuntimeException e) {
throw new RuntimeException(e);
}
}
@Override
public Tuple2 call(Tuple2 in)
throws Exception
{
long key = in._1();
long rmRows = 0;
FrameBlock blk = in._2();
for( int i=0; i(key, rmRows);
}
}
public static class RDDTransformDecodeFunction implements PairFunction,Long,FrameBlock>
{
private static final long serialVersionUID = -4797324742568170756L;
private Decoder _decoder = null;
private int _brlen = -1;
public RDDTransformDecodeFunction(Decoder decoder, int brlen) {
_decoder = decoder;
_brlen = brlen;
}
@Override
public Tuple2 call(Tuple2 in)
throws Exception
{
long rix = UtilFunctions.computeCellIndex(in._1().getRowIndex(), _brlen, 0);
FrameBlock fbout = _decoder.decode(in._2(), new FrameBlock(_decoder.getSchema()));
fbout.setColumnNames(Arrays.copyOfRange(_decoder.getColnames(), 0, fbout.getNumColumns()));
return new Tuple2<>(rix, fbout);
}
}
public static class RDDTransformDecodeExpandFunction implements PairFunction,MatrixIndexes,MatrixBlock>
{
private static final long serialVersionUID = -8187400248076127598L;
private int _clen = -1;
private int _bclen = -1;
public RDDTransformDecodeExpandFunction(int clen, int bclen) {
_clen = clen;
_bclen = bclen;
}
@Override
public Tuple2 call(Tuple2 in)
throws Exception
{
MatrixIndexes inIx = in._1();
MatrixBlock inBlk = in._2();
//construct expanded block via leftindexing
int cl = (int)UtilFunctions.computeCellIndex(inIx.getColumnIndex(), _bclen, 0)-1;
int cu = (int)UtilFunctions.computeCellIndex(inIx.getColumnIndex(), _bclen, inBlk.getNumColumns()-1)-1;
MatrixBlock out = new MatrixBlock(inBlk.getNumRows(), _clen, false);
out = out.leftIndexingOperations(inBlk, 0, inBlk.getNumRows()-1, cl, cu, null, UpdateType.INPLACE_PINNED);
return new Tuple2<>(new MatrixIndexes(inIx.getRowIndex(), 1), out);
}
}
/**
 * Sets the output characteristics of a grouped aggregate with cell output,
 * unless the output dimensions are already known. If the number of groups
 * was passed as a parameter it determines the number of output rows;
 * otherwise the characteristics are computed from the output RDD.
 *
 * @param mc1   characteristics of the target input
 * @param mcOut characteristics of the output (updated in place)
 * @param out   output cell RDD
 * @throws DMLRuntimeException if neither output nor input dimensions are known
 */
public void setOutputCharacteristicsForGroupedAgg(MatrixCharacteristics mc1, MatrixCharacteristics mcOut, JavaPairRDD<MatrixIndexes, MatrixCell> out)
    throws DMLRuntimeException
{
    if( mcOut.dimsKnown() )
        return; //nothing to do

    if( !mc1.dimsKnown() ) {
        throw new DMLRuntimeException("The output dimensions are not specified for grouped aggregate");
    }

    String ngroupsStr = params.get(Statement.GAGG_NUM_GROUPS);
    if( ngroupsStr != null ) {
        //parsed as long, consistent with the grouped aggregate execution itself;
        //narrowing to int would overflow for very large group counts
        long ngroups = (long) Double.parseDouble(ngroupsStr);
        mcOut.set(ngroups, mc1.getCols(), -1, -1); //grouped aggregate with cell output
    }
    else {
        //NOTE(review): the cached RDD is only used locally to compute the
        //characteristics; the caller keeps working with its own 'out' handle
        JavaPairRDD<MatrixIndexes, MatrixCell> cached = SparkUtils.cacheBinaryCellRDD(out);
        mcOut.set(SparkUtils.computeMatrixCharacteristics(cached));
        mcOut.setBlockSize(-1, -1); //grouped aggregate with cell output
    }
}
}