org.apache.sysml.runtime.instructions.spark.QuaternarySPInstruction Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of systemml Show documentation
Show all versions of systemml Show documentation
Declarative Machine Learning
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.sysml.runtime.instructions.spark;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.LinkedList;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.PairFlatMapFunction;
import org.apache.spark.api.java.function.PairFunction;
import scala.Tuple2;
import org.apache.sysml.lops.WeightedCrossEntropy;
import org.apache.sysml.lops.WeightedDivMM;
import org.apache.sysml.lops.WeightedDivMM.WDivMMType;
import org.apache.sysml.lops.WeightedSigmoid;
import org.apache.sysml.lops.WeightedSquaredLoss;
import org.apache.sysml.lops.WeightedSquaredLossR;
import org.apache.sysml.lops.WeightedSigmoid.WSigmoidType;
import org.apache.sysml.lops.WeightedSquaredLoss.WeightsType;
import org.apache.sysml.lops.WeightedCrossEntropy.WCeMMType;
import org.apache.sysml.lops.WeightedUnaryMM;
import org.apache.sysml.lops.WeightedUnaryMM.WUMMType;
import org.apache.sysml.lops.WeightedUnaryMMR;
import org.apache.sysml.runtime.DMLRuntimeException;
import org.apache.sysml.runtime.DMLUnsupportedOperationException;
import org.apache.sysml.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysml.runtime.controlprogram.context.SparkExecutionContext;
import org.apache.sysml.runtime.instructions.InstructionUtils;
import org.apache.sysml.runtime.instructions.cp.CPOperand;
import org.apache.sysml.runtime.instructions.cp.DoubleObject;
import org.apache.sysml.runtime.instructions.spark.data.LazyIterableIterator;
import org.apache.sysml.runtime.instructions.spark.data.PartitionedBroadcastMatrix;
import org.apache.sysml.runtime.instructions.spark.utils.RDDAggregateUtils;
import org.apache.sysml.runtime.matrix.MatrixCharacteristics;
import org.apache.sysml.runtime.matrix.data.MatrixBlock;
import org.apache.sysml.runtime.matrix.data.MatrixIndexes;
import org.apache.sysml.runtime.matrix.operators.Operator;
import org.apache.sysml.runtime.matrix.operators.QuaternaryOperator;
/**
*
*/
public class QuaternarySPInstruction extends ComputationSPInstruction
{
private CPOperand _input4 = null;
private boolean _cacheU = false;
private boolean _cacheV = false;
public QuaternarySPInstruction(Operator op, CPOperand in1, CPOperand in2, CPOperand in3, CPOperand in4, CPOperand out, boolean cacheU, boolean cacheV, String opcode, String str)
{
super(op, in1, in2, in3, out, opcode, str);
_sptype = SPINSTRUCTION_TYPE.Quaternary;
_input4 = in4;
_cacheU = cacheU;
_cacheV = cacheV;
}
/**
*
* @param str
* @return
* @throws DMLRuntimeException
*/
public static QuaternarySPInstruction parseInstruction( String str )
throws DMLRuntimeException
{
String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
String opcode = parts[0];
//validity check
if( !InstructionUtils.isDistQuaternaryOpcode(opcode) ) {
throw new DMLRuntimeException("Quaternary.parseInstruction():: Unknown opcode " + opcode);
}
//instruction parsing
if( WeightedSquaredLoss.OPCODE.equalsIgnoreCase(opcode) //wsloss
|| WeightedSquaredLossR.OPCODE.equalsIgnoreCase(opcode) )
{
boolean isRed = WeightedSquaredLossR.OPCODE.equalsIgnoreCase(opcode);
//check number of fields (4 inputs, output, type)
if( isRed )
InstructionUtils.checkNumFields ( parts, 8 );
else
InstructionUtils.checkNumFields ( parts, 6 );
CPOperand in1 = new CPOperand(parts[1]);
CPOperand in2 = new CPOperand(parts[2]);
CPOperand in3 = new CPOperand(parts[3]);
CPOperand in4 = new CPOperand(parts[4]);
CPOperand out = new CPOperand(parts[5]);
WeightsType wtype = WeightsType.valueOf(parts[6]);
//in mappers always through distcache, in reducers through distcache/shuffle
boolean cacheU = isRed ? Boolean.parseBoolean(parts[7]) : true;
boolean cacheV = isRed ? Boolean.parseBoolean(parts[8]) : true;
return new QuaternarySPInstruction(new QuaternaryOperator(wtype), in1, in2, in3, in4, out, cacheU, cacheV, opcode, str);
}
else if( WeightedUnaryMM.OPCODE.equalsIgnoreCase(opcode) //wumm
|| WeightedUnaryMMR.OPCODE.equalsIgnoreCase(opcode) )
{
boolean isRed = WeightedUnaryMMR.OPCODE.equalsIgnoreCase(opcode);
//check number of fields (4 inputs, output, type)
if( isRed )
InstructionUtils.checkNumFields ( parts, 8 );
else
InstructionUtils.checkNumFields ( parts, 6 );
String uopcode = parts[1];
CPOperand in1 = new CPOperand(parts[2]);
CPOperand in2 = new CPOperand(parts[3]);
CPOperand in3 = new CPOperand(parts[4]);
CPOperand out = new CPOperand(parts[5]);
WUMMType wtype = WUMMType.valueOf(parts[6]);
//in mappers always through distcache, in reducers through distcache/shuffle
boolean cacheU = isRed ? Boolean.parseBoolean(parts[7]) : true;
boolean cacheV = isRed ? Boolean.parseBoolean(parts[8]) : true;
return new QuaternarySPInstruction(new QuaternaryOperator(wtype, uopcode), in1, in2, in3, null, out, cacheU, cacheV, opcode, str);
}
else //map/redwsigmoid, map/redwdivmm, map/redwcemm
{
boolean isRed = opcode.startsWith("red");
//check number of fields (3 inputs, output, type)
if( isRed )
InstructionUtils.checkNumFields( parts, 7 );
else
InstructionUtils.checkNumFields( parts, 5 );
CPOperand in1 = new CPOperand(parts[1]);
CPOperand in2 = new CPOperand(parts[2]);
CPOperand in3 = new CPOperand(parts[3]);
CPOperand out = new CPOperand(parts[4]);
//in mappers always through distcache, in reducers through distcache/shuffle
boolean cacheU = isRed ? Boolean.parseBoolean(parts[6]) : true;
boolean cacheV = isRed ? Boolean.parseBoolean(parts[7]) : true;
if( opcode.endsWith("wsigmoid") )
return new QuaternarySPInstruction(new QuaternaryOperator(WSigmoidType.valueOf(parts[5])), in1, in2, in3, null, out, cacheU, cacheV, opcode, str);
else if( opcode.endsWith("wdivmm") )
return new QuaternarySPInstruction(new QuaternaryOperator(WDivMMType.valueOf(parts[5])), in1, in2, in3, null, out, cacheU, cacheV, opcode, str);
else if( opcode.endsWith("wcemm") )
return new QuaternarySPInstruction(new QuaternaryOperator(WCeMMType.valueOf(parts[5])), in1, in2, in3, null, out, cacheU, cacheV, opcode, str);
}
return null;
}
@Override
public void processInstruction(ExecutionContext ec)
throws DMLRuntimeException, DMLUnsupportedOperationException
{
SparkExecutionContext sec = (SparkExecutionContext)ec;
QuaternaryOperator qop = (QuaternaryOperator) _optr;
//tracking of rdds and broadcasts (for lineage maintenance)
ArrayList rddVars = new ArrayList();
ArrayList bcVars = new ArrayList();
JavaPairRDD in = sec.getBinaryBlockRDDHandleForVariable( input1.getName() );
JavaPairRDD out = null;
MatrixCharacteristics inMc = sec.getMatrixCharacteristics( input1.getName() );
long rlen = inMc.getRows();
long clen = inMc.getCols();
int brlen = inMc.getRowsPerBlock();
int bclen = inMc.getColsPerBlock();
//map-side only operation (one rdd input, two broadcasts)
if( WeightedSquaredLoss.OPCODE.equalsIgnoreCase(getOpcode())
|| WeightedSigmoid.OPCODE.equalsIgnoreCase(getOpcode())
|| WeightedDivMM.OPCODE.equalsIgnoreCase(getOpcode())
|| WeightedCrossEntropy.OPCODE.equalsIgnoreCase(getOpcode())
|| WeightedUnaryMM.OPCODE.equalsIgnoreCase(getOpcode()))
{
PartitionedBroadcastMatrix bc1 = sec.getBroadcastForVariable( input2.getName() );
PartitionedBroadcastMatrix bc2 = sec.getBroadcastForVariable( input3.getName() );
//partitioning-preserving mappartitions (key access required for broadcast loopkup)
boolean noKeyChange = (qop.wtype3 == null || qop.wtype3.isBasic()); //only wsdivmm changes keys
out = in.mapPartitionsToPair(new RDDQuaternaryFunction1(qop, bc1, bc2), noKeyChange);
rddVars.add( input1.getName() );
bcVars.add( input2.getName() );
bcVars.add( input3.getName() );
}
//reduce-side operation (two/three/four rdd inputs, zero/one/two broadcasts)
else
{
PartitionedBroadcastMatrix bc1 = _cacheU ? sec.getBroadcastForVariable( input2.getName() ) : null;
PartitionedBroadcastMatrix bc2 = _cacheV ? sec.getBroadcastForVariable( input3.getName() ) : null;
JavaPairRDD inU = (!_cacheU) ? sec.getBinaryBlockRDDHandleForVariable( input2.getName() ) : null;
JavaPairRDD inV = (!_cacheV) ? sec.getBinaryBlockRDDHandleForVariable( input3.getName() ) : null;
JavaPairRDD inW = (qop.wtype1!=null && qop.wtype1.hasFourInputs()) ?
sec.getBinaryBlockRDDHandleForVariable( _input4.getName() ) : null;
//preparation of transposed and replicated U
if( inU != null )
inU = inU.flatMapToPair(new ReplicateBlocksFunction(clen, bclen, true));
//preparation of transposed and replicated V
if( inV != null )
inV = inV.mapToPair(new TransposeFactorIndexesFunction())
.flatMapToPair(new ReplicateBlocksFunction(rlen, brlen, false));
//functions calls w/ two rdd inputs
if( inU != null && inV == null && inW == null )
out = in.join(inU)
.mapToPair(new RDDQuaternaryFunction2(qop, bc1, bc2));
else if( inU == null && inV != null && inW == null )
out = in.join(inV)
.mapToPair(new RDDQuaternaryFunction2(qop, bc1, bc2));
else if( inU == null && inV == null && inW != null )
out = in.join(inW)
.mapToPair(new RDDQuaternaryFunction2(qop, bc1, bc2));
//function calls w/ three rdd inputs
else if( inU != null && inV != null && inW == null )
out = in.join(inU).join(inV)
.mapToPair(new RDDQuaternaryFunction3(qop, bc1, bc2));
else if( inU != null && inV == null && inW != null )
out = in.join(inU).join(inW)
.mapToPair(new RDDQuaternaryFunction3(qop, bc1, bc2));
else if( inU == null && inV != null && inW != null )
out = in.join(inV).join(inW)
.mapToPair(new RDDQuaternaryFunction3(qop, bc1, bc2));
//function call w/ four rdd inputs
else
out = in.join(inU).join(inV).join(inW)
.mapValues(new RDDQuaternaryFunction4(qop));
//keep variable names for lineage maintenance
if( inU == null ) bcVars.add(input2.getName()); else rddVars.add(input2.getName());
if( inV == null ) bcVars.add(input3.getName()); else rddVars.add(input3.getName());
if( inW != null ) rddVars.add(_input4.getName());
}
//output handling, incl aggregation
if( qop.wtype1 != null || qop.wtype4 != null ) //map/redwsloss, map/redwcemm
{
//full aggregate and cast to scalar
MatrixBlock tmp = RDDAggregateUtils.sumStable(out);
DoubleObject ret = new DoubleObject(tmp.getValue(0, 0));
sec.setVariable(output.getName(), ret);
}
else //map/redwsigmoid, map/redwdivmm, map/redwumm
{
//aggregation if required (map/redwdivmm)
if( qop.wtype3 != null && !qop.wtype3.isBasic() )
out = RDDAggregateUtils.sumByKeyStable( out );
//put output RDD handle into symbol table
sec.setRDDHandleForVariable(output.getName(), out);
//maintain lineage information for output rdd
for( String rddVar : rddVars )
sec.addLineageRDD(output.getName(), rddVar);
for( String bcVar : bcVars )
sec.addLineageBroadcast(output.getName(), bcVar);
//update matrix characteristics
updateOutputMatrixCharacteristics(sec, qop);
}
}
/**
*
* @param sec
* @param qop
* @throws DMLRuntimeException
*/
private void updateOutputMatrixCharacteristics(SparkExecutionContext sec, QuaternaryOperator qop)
throws DMLRuntimeException
{
MatrixCharacteristics mcIn1 = sec.getMatrixCharacteristics(input1.getName());
MatrixCharacteristics mcIn2 = sec.getMatrixCharacteristics(input2.getName());
MatrixCharacteristics mcIn3 = sec.getMatrixCharacteristics(input3.getName());
MatrixCharacteristics mcOut = sec.getMatrixCharacteristics(output.getName());
if( qop.wtype2 != null || qop.wtype5 != null ) {
//output size determined by main input
mcOut.set(mcIn1.getRows(), mcIn1.getCols(), mcIn1.getRowsPerBlock(), mcIn1.getColsPerBlock());
}
else if(qop.wtype3 != null ) { //wdivmm
long rank = qop.wtype3.isLeft() ? mcIn3.getCols() : mcIn2.getCols();
MatrixCharacteristics mcTmp = qop.wtype3.computeOutputCharacteristics(mcIn1.getRows(), mcIn1.getCols(), rank);
mcOut.set(mcTmp.getRows(), mcTmp.getCols(), mcIn1.getRowsPerBlock(), mcIn1.getColsPerBlock());
}
}
/**
*
*/
private abstract static class RDDQuaternaryBaseFunction implements Serializable
{
private static final long serialVersionUID = -3175397651350954930L;
protected QuaternaryOperator _qop = null;
protected PartitionedBroadcastMatrix _pmU = null;
protected PartitionedBroadcastMatrix _pmV = null;
public RDDQuaternaryBaseFunction( QuaternaryOperator qop, PartitionedBroadcastMatrix bcU, PartitionedBroadcastMatrix bcV ) {
_qop = qop;
_pmU = bcU;
_pmV = bcV;
}
protected MatrixIndexes createOutputIndexes(MatrixIndexes in)
{
if( _qop.wtype3 != null && !_qop.wtype3.isBasic() ){ //key change
boolean left = _qop.wtype3.isLeft();
return new MatrixIndexes(left?in.getColumnIndex():in.getRowIndex(), 1);
}
return in;
}
}
/**
*
*/
private static class RDDQuaternaryFunction1 extends RDDQuaternaryBaseFunction //one rdd input
implements PairFlatMapFunction>, MatrixIndexes, MatrixBlock>
{
private static final long serialVersionUID = -8209188316939435099L;
public RDDQuaternaryFunction1( QuaternaryOperator qop, PartitionedBroadcastMatrix bcU, PartitionedBroadcastMatrix bcV )
throws DMLRuntimeException, DMLUnsupportedOperationException
{
super(qop, bcU, bcV);
}
@Override
public Iterable> call(Iterator> arg)
throws Exception
{
return new RDDQuaternaryPartitionIterator(arg);
}
private class RDDQuaternaryPartitionIterator extends LazyIterableIterator>
{
public RDDQuaternaryPartitionIterator(Iterator> in) {
super(in);
}
@Override
protected Tuple2 computeNext(Tuple2 arg)
throws Exception
{
MatrixIndexes ixIn = arg._1();
MatrixBlock blkIn = arg._2();
MatrixBlock blkOut = new MatrixBlock();
MatrixBlock mbU = _pmU.getMatrixBlock((int)ixIn.getRowIndex(), 1);
MatrixBlock mbV = _pmV.getMatrixBlock((int)ixIn.getColumnIndex(), 1);
//execute core operation
blkIn.quaternaryOperations(_qop, mbU, mbV, null, blkOut);
//create return tuple
MatrixIndexes ixOut = createOutputIndexes(ixIn);
return new Tuple2(ixOut, blkOut);
}
}
}
/**
*
*/
private static class RDDQuaternaryFunction2 extends RDDQuaternaryBaseFunction //two rdd input
implements PairFunction>, MatrixIndexes, MatrixBlock>
{
private static final long serialVersionUID = 7493974462943080693L;
public RDDQuaternaryFunction2( QuaternaryOperator qop, PartitionedBroadcastMatrix bcU, PartitionedBroadcastMatrix bcV )
throws DMLRuntimeException, DMLUnsupportedOperationException
{
super(qop, bcU, bcV);
}
@Override
public Tuple2 call(Tuple2> arg0)
throws Exception
{
MatrixIndexes ixIn = arg0._1();
MatrixBlock blkIn1 = arg0._2()._1();
MatrixBlock blkIn2 = arg0._2()._2();
MatrixBlock blkOut = new MatrixBlock();
MatrixBlock mbU = (_pmU!=null)?_pmU.getMatrixBlock((int)ixIn.getRowIndex(), 1) : blkIn2;
MatrixBlock mbV = (_pmV!=null)?_pmV.getMatrixBlock((int)ixIn.getColumnIndex(), 1) : blkIn2;
MatrixBlock mbW = (_qop.wtype1!=null && _qop.wtype1.hasFourInputs()) ? blkIn2 : null;
//execute core operation
blkIn1.quaternaryOperations(_qop, mbU, mbV, mbW, blkOut);
//create return tuple
MatrixIndexes ixOut = createOutputIndexes(ixIn);
return new Tuple2(ixOut, blkOut);
}
}
/**
*
*/
private static class RDDQuaternaryFunction3 extends RDDQuaternaryBaseFunction //three rdd input
implements PairFunction,MatrixBlock>>, MatrixIndexes, MatrixBlock>
{
private static final long serialVersionUID = -2294086455843773095L;
public RDDQuaternaryFunction3( QuaternaryOperator qop, PartitionedBroadcastMatrix bcU, PartitionedBroadcastMatrix bcV )
throws DMLRuntimeException, DMLUnsupportedOperationException
{
super(qop, bcU, bcV);
}
@Override
public Tuple2 call(Tuple2, MatrixBlock>> arg0)
throws Exception
{
MatrixIndexes ixIn = arg0._1();
MatrixBlock blkIn1 = arg0._2()._1()._1();
MatrixBlock blkIn2 = arg0._2()._1()._2();
MatrixBlock blkIn3 = arg0._2()._2();
MatrixBlock blkOut = new MatrixBlock();
MatrixBlock mbU = (_pmU!=null)?_pmU.getMatrixBlock((int)ixIn.getRowIndex(), 1) : blkIn2;
MatrixBlock mbV = (_pmV!=null)?_pmV.getMatrixBlock((int)ixIn.getColumnIndex(), 1) :
(_pmU!=null)? blkIn2 : blkIn3;
MatrixBlock mbW = (_qop.wtype1!=null && _qop.wtype1.hasFourInputs())? blkIn3 : null;
//execute core operation
blkIn1.quaternaryOperations(_qop, mbU, mbV, mbW, blkOut);
//create return tuple
MatrixIndexes ixOut = createOutputIndexes(ixIn);
return new Tuple2(ixOut, blkOut);
}
}
/**
* Note: never called for wsigmoid/wdivmm (only wsloss)
*/
private static class RDDQuaternaryFunction4 extends RDDQuaternaryBaseFunction //four rdd input
implements Function,MatrixBlock>,MatrixBlock>,MatrixBlock>
{
private static final long serialVersionUID = 7328911771600289250L;
public RDDQuaternaryFunction4( QuaternaryOperator qop )
throws DMLRuntimeException, DMLUnsupportedOperationException
{
super(qop, null, null);
}
@Override
public MatrixBlock call(Tuple2, MatrixBlock>, MatrixBlock> arg0)
throws Exception
{
MatrixBlock blkIn1 = arg0._1()._1()._1();
MatrixBlock mbU = arg0._1()._1()._2();
MatrixBlock mbV = arg0._1()._2();
MatrixBlock mbW = arg0._2();
MatrixBlock blkOut = new MatrixBlock();
//execute core operation
blkIn1.quaternaryOperations(_qop, mbU, mbV, mbW, blkOut);
//create return tuple
return blkOut;
}
}
private static class TransposeFactorIndexesFunction implements PairFunction, MatrixIndexes, MatrixBlock>
{
private static final long serialVersionUID = -2571724736131823708L;
@Override
public Tuple2 call( Tuple2 arg0 )
throws Exception
{
MatrixIndexes ixIn = arg0._1();
MatrixBlock blkIn = arg0._2();
//swap the matrix indexes
MatrixIndexes ixOut = new MatrixIndexes(ixIn.getColumnIndex(), ixIn.getRowIndex());
MatrixBlock blkOut = new MatrixBlock(blkIn);
//output new tuple
return new Tuple2(ixOut,blkOut);
}
}
private static class ReplicateBlocksFunction implements PairFlatMapFunction, MatrixIndexes, MatrixBlock>
{
private static final long serialVersionUID = -1184696764516975609L;
private long _len = -1;
private long _blen = -1;
private boolean _left = false;
public ReplicateBlocksFunction(long len, long blen, boolean left)
{
_len = len;
_blen = blen;
_left = left;
}
@Override
public Iterable> call( Tuple2 arg0 )
throws Exception
{
LinkedList> ret = new LinkedList>();
MatrixIndexes ixIn = arg0._1();
MatrixBlock blkIn = arg0._2();
long numBlocks = (long) Math.ceil((double)_len/_blen);
if( _left ) //LHS MATRIX
{
//replicate wrt # column blocks in RHS
long i = ixIn.getRowIndex();
for( long j=1; j<=numBlocks; j++ ) {
MatrixIndexes tmpix = new MatrixIndexes(i, j);
MatrixBlock tmpblk = new MatrixBlock(blkIn);
ret.add( new Tuple2(tmpix, tmpblk) );
}
}
else // RHS MATRIX
{
//replicate wrt # row blocks in LHS
long j = ixIn.getColumnIndex();
for( long i=1; i<=numBlocks; i++ ) {
MatrixIndexes tmpix = new MatrixIndexes(i, j);
MatrixBlock tmpblk = new MatrixBlock(blkIn);
ret.add( new Tuple2(tmpix, tmpblk) );
}
}
//output list of new tuples
return ret;
}
}
}