org.apache.sysml.runtime.instructions.spark.TernarySPInstruction Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of systemml Show documentation
Show all versions of systemml Show documentation
Declarative Machine Learning
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.sysml.runtime.instructions.spark;
import java.util.ArrayList;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.PairFlatMapFunction;
import org.apache.spark.api.java.function.PairFunction;
import scala.Tuple2;
import org.apache.sysml.lops.Ternary;
import org.apache.sysml.parser.Expression.ValueType;
import org.apache.sysml.runtime.DMLRuntimeException;
import org.apache.sysml.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysml.runtime.controlprogram.context.SparkExecutionContext;
import org.apache.sysml.runtime.functionobjects.CTable;
import org.apache.sysml.runtime.instructions.Instruction;
import org.apache.sysml.runtime.instructions.InstructionUtils;
import org.apache.sysml.runtime.instructions.cp.CPOperand;
import org.apache.sysml.runtime.instructions.spark.utils.RDDAggregateUtils;
import org.apache.sysml.runtime.instructions.spark.utils.SparkUtils;
import org.apache.sysml.runtime.matrix.MatrixCharacteristics;
import org.apache.sysml.runtime.matrix.data.CTableMap;
import org.apache.sysml.runtime.matrix.data.MatrixBlock;
import org.apache.sysml.runtime.matrix.data.MatrixCell;
import org.apache.sysml.runtime.matrix.data.MatrixIndexes;
import org.apache.sysml.runtime.matrix.data.OperationsOnMatrixValues;
import org.apache.sysml.runtime.matrix.data.Pair;
import org.apache.sysml.runtime.matrix.mapred.IndexedMatrixValue;
import org.apache.sysml.runtime.matrix.operators.Operator;
import org.apache.sysml.runtime.matrix.operators.SimpleOperator;
import org.apache.sysml.runtime.util.LongLongDoubleHashMap.LLDoubleEntry;
import org.apache.sysml.runtime.util.UtilFunctions;
public class TernarySPInstruction extends ComputationSPInstruction
{
private String _outDim1;
private String _outDim2;
private boolean _dim1Literal;
private boolean _dim2Literal;
private boolean _isExpand;
private boolean _ignoreZeros;
public TernarySPInstruction(Operator op, CPOperand in1, CPOperand in2, CPOperand in3, CPOperand out,
String outputDim1, boolean dim1Literal,String outputDim2, boolean dim2Literal,
boolean isExpand, boolean ignoreZeros, String opcode, String istr )
{
super(op, in1, in2, in3, out, opcode, istr);
_outDim1 = outputDim1;
_dim1Literal = dim1Literal;
_outDim2 = outputDim2;
_dim2Literal = dim2Literal;
_isExpand = isExpand;
_ignoreZeros = ignoreZeros;
}
public static TernarySPInstruction parseInstruction(String inst)
throws DMLRuntimeException
{
String[] parts = InstructionUtils.getInstructionPartsWithValueType(inst);
InstructionUtils.checkNumFields ( parts, 7 );
String opcode = parts[0];
//handle opcode
if ( !(opcode.equalsIgnoreCase("ctable") || opcode.equalsIgnoreCase("ctableexpand")) ) {
throw new DMLRuntimeException("Unexpected opcode in TertiarySPInstruction: " + inst);
}
boolean isExpand = opcode.equalsIgnoreCase("ctableexpand");
//handle operands
CPOperand in1 = new CPOperand(parts[1]);
CPOperand in2 = new CPOperand(parts[2]);
CPOperand in3 = new CPOperand(parts[3]);
//handle known dimension information
String[] dim1Fields = parts[4].split(Instruction.LITERAL_PREFIX);
String[] dim2Fields = parts[5].split(Instruction.LITERAL_PREFIX);
CPOperand out = new CPOperand(parts[6]);
boolean ignoreZeros = Boolean.parseBoolean(parts[7]);
// ctable does not require any operator, so we simply pass-in a dummy operator with null functionobject
return new TernarySPInstruction(new SimpleOperator(null), in1, in2, in3, out, dim1Fields[0], Boolean.parseBoolean(dim1Fields[1]), dim2Fields[0], Boolean.parseBoolean(dim2Fields[1]), isExpand, ignoreZeros, opcode, inst);
}
@Override
public void processInstruction(ExecutionContext ec)
throws DMLRuntimeException
{
SparkExecutionContext sec = (SparkExecutionContext)ec;
//get input rdd handle
JavaPairRDD in1 = sec.getBinaryBlockRDDHandleForVariable( input1.getName() );
JavaPairRDD in2 = null;
JavaPairRDD in3 = null;
double scalar_input2 = -1, scalar_input3 = -1;
Ternary.OperationTypes ctableOp = Ternary.findCtableOperationByInputDataTypes(
input1.getDataType(), input2.getDataType(), input3.getDataType());
ctableOp = _isExpand ? Ternary.OperationTypes.CTABLE_EXPAND_SCALAR_WEIGHT : ctableOp;
MatrixCharacteristics mc1 = sec.getMatrixCharacteristics(input1.getName());
MatrixCharacteristics mcOut = sec.getMatrixCharacteristics(output.getName());
// First get the block sizes and then set them as -1 to allow for binary cell reblock
int brlen = mc1.getRowsPerBlock();
int bclen = mc1.getColsPerBlock();
JavaPairRDD> inputMBs = null;
JavaPairRDD ctables = null;
JavaPairRDD bincellsNoFilter = null;
boolean setLineage2 = false;
boolean setLineage3 = false;
switch(ctableOp) {
case CTABLE_TRANSFORM: //(VECTOR)
// F=ctable(A,B,W)
in2 = sec.getBinaryBlockRDDHandleForVariable( input2.getName() );
in3 = sec.getBinaryBlockRDDHandleForVariable( input3.getName() );
setLineage2 = true;
setLineage3 = true;
inputMBs = in1.cogroup(in2).cogroup(in3)
.mapToPair(new MapThreeMBIterableIntoAL());
ctables = inputMBs.mapToPair(new PerformCTableMapSideOperation(ctableOp, scalar_input2,
scalar_input3, this.instString, (SimpleOperator)_optr, _ignoreZeros));
break;
case CTABLE_EXPAND_SCALAR_WEIGHT: //(VECTOR)
// F = ctable(seq,A) or F = ctable(seq,B,1)
scalar_input3 = sec.getScalarInput(input3.getName(), input3.getValueType(), input3.isLiteral()).getDoubleValue();
if(scalar_input3 == 1) {
in2 = sec.getBinaryBlockRDDHandleForVariable( input2.getName() );
setLineage2 = true;
bincellsNoFilter = in2.flatMapToPair(new ExpandScalarCtableOperation(brlen));
break;
}
case CTABLE_TRANSFORM_SCALAR_WEIGHT: //(VECTOR/MATRIX)
// F = ctable(A,B) or F = ctable(A,B,1)
in2 = sec.getBinaryBlockRDDHandleForVariable( input2.getName() );
setLineage2 = true;
scalar_input3 = sec.getScalarInput(input3.getName(), input3.getValueType(), input3.isLiteral()).getDoubleValue();
inputMBs = in1.cogroup(in2).mapToPair(new MapTwoMBIterableIntoAL());
ctables = inputMBs.mapToPair(new PerformCTableMapSideOperation(ctableOp, scalar_input2,
scalar_input3, this.instString, (SimpleOperator)_optr, _ignoreZeros));
break;
case CTABLE_TRANSFORM_HISTOGRAM: //(VECTOR)
// F=ctable(A,1) or F = ctable(A,1,1)
scalar_input2 = sec.getScalarInput(input2.getName(), input2.getValueType(), input2.isLiteral()).getDoubleValue();
scalar_input3 = sec.getScalarInput(input3.getName(), input3.getValueType(), input3.isLiteral()).getDoubleValue();
inputMBs = in1.mapToPair(new MapMBIntoAL());
ctables = inputMBs.mapToPair(new PerformCTableMapSideOperation(ctableOp, scalar_input2,
scalar_input3, this.instString, (SimpleOperator)_optr, _ignoreZeros));
break;
case CTABLE_TRANSFORM_WEIGHTED_HISTOGRAM: //(VECTOR)
// F=ctable(A,1,W)
in3 = sec.getBinaryBlockRDDHandleForVariable( input3.getName() );
setLineage3 = true;
scalar_input2 = sec.getScalarInput(input2.getName(), input2.getValueType(), input2.isLiteral()).getDoubleValue();
inputMBs = in1.cogroup(in3).mapToPair(new MapTwoMBIterableIntoAL());
ctables = inputMBs.mapToPair(new PerformCTableMapSideOperation(ctableOp, scalar_input2,
scalar_input3, this.instString, (SimpleOperator)_optr, _ignoreZeros));
break;
default:
throw new DMLRuntimeException("Encountered an invalid ctable operation ("+ctableOp+") while executing instruction: " + this.toString());
}
// Now perform aggregation on ctables to get binaryCells
if(bincellsNoFilter == null && ctables != null) {
bincellsNoFilter =
ctables.values()
.flatMapToPair(new ExtractBinaryCellsFromCTable());
bincellsNoFilter = RDDAggregateUtils.sumCellsByKeyStable(bincellsNoFilter);
}
else if(!(bincellsNoFilter != null && ctables == null)) {
throw new DMLRuntimeException("Incorrect ctable operation");
}
// handle known/unknown dimensions
long outputDim1 = (_dim1Literal ? (long) Double.parseDouble(_outDim1) : (sec.getScalarInput(_outDim1, ValueType.DOUBLE, false)).getLongValue());
long outputDim2 = (_dim2Literal ? (long) Double.parseDouble(_outDim2) : (sec.getScalarInput(_outDim2, ValueType.DOUBLE, false)).getLongValue());
MatrixCharacteristics mcBinaryCells = null;
boolean findDimensions = (outputDim1 == -1 && outputDim2 == -1);
if( !findDimensions ) {
if((outputDim1 == -1 && outputDim2 != -1) || (outputDim1 != -1 && outputDim2 == -1))
throw new DMLRuntimeException("Incorrect output dimensions passed to TernarySPInstruction:" + outputDim1 + " " + outputDim2);
else
mcBinaryCells = new MatrixCharacteristics(outputDim1, outputDim2, brlen, bclen);
// filtering according to given dimensions
bincellsNoFilter = bincellsNoFilter
.filter(new FilterCells(mcBinaryCells.getRows(), mcBinaryCells.getCols()));
}
// convert double values to matrix cell
JavaPairRDD binaryCells = bincellsNoFilter
.mapToPair(new ConvertToBinaryCell());
// find dimensions if necessary (w/ cache for reblock)
if( findDimensions ) {
binaryCells = SparkUtils.cacheBinaryCellRDD(binaryCells);
mcBinaryCells = SparkUtils.computeMatrixCharacteristics(binaryCells);
}
//store output rdd handle
sec.setRDDHandleForVariable(output.getName(), binaryCells);
mcOut.set(mcBinaryCells);
// Since we are outputing binary cells, we set block sizes = -1
mcOut.setRowsPerBlock(-1); mcOut.setColsPerBlock(-1);
sec.addLineageRDD(output.getName(), input1.getName());
if(setLineage2)
sec.addLineageRDD(output.getName(), input2.getName());
if(setLineage3)
sec.addLineageRDD(output.getName(), input3.getName());
}
private static class ExpandScalarCtableOperation implements PairFlatMapFunction, MatrixIndexes, Double>
{
private static final long serialVersionUID = -12552669148928288L;
private int _brlen;
public ExpandScalarCtableOperation(int brlen) {
_brlen = brlen;
}
@Override
public Iterable> call(Tuple2 arg0)
throws Exception
{
MatrixIndexes ix = arg0._1();
MatrixBlock mb = arg0._2(); //col-vector
//create an output cell per matrix block row (aligned w/ original source position)
ArrayList> retVal = new ArrayList>();
CTable ctab = CTable.getCTableFnObject();
for( int i=0; i p = ctab.execute(row, v2, 1.0);
//indirect construction over pair to avoid tuple2 dependency in general ctable obj
if( p.getKey().getRowIndex() >= 1 ) //filter rejected entries
retVal.add(new Tuple2(p.getKey(), p.getValue()));
}
return retVal;
}
}
private static class MapTwoMBIterableIntoAL implements PairFunction,Iterable>>, MatrixIndexes, ArrayList> {
private static final long serialVersionUID = 271459913267735850L;
private MatrixBlock extractBlock(Iterable blks, MatrixBlock retVal) throws Exception {
for(MatrixBlock blk1 : blks) {
if(retVal != null) {
throw new Exception("ERROR: More than 1 matrixblock found for one of the inputs at a given index");
}
retVal = blk1;
}
if(retVal == null) {
throw new Exception("ERROR: No matrixblock found for one of the inputs at a given index");
}
return retVal;
}
@Override
public Tuple2> call(
Tuple2, Iterable>> kv)
throws Exception {
MatrixBlock in1 = null; MatrixBlock in2 = null;
in1 = extractBlock(kv._2._1, in1);
in2 = extractBlock(kv._2._2, in2);
// Now return unflatten AL
ArrayList inputs = new ArrayList();
inputs.add(in1); inputs.add(in2);
return new Tuple2>(kv._1, inputs);
}
}
private static class MapThreeMBIterableIntoAL implements PairFunction,Iterable>>,Iterable>>, MatrixIndexes, ArrayList> {
private static final long serialVersionUID = -4873754507037646974L;
private MatrixBlock extractBlock(Iterable blks, MatrixBlock retVal) throws Exception {
for(MatrixBlock blk1 : blks) {
if(retVal != null) {
throw new Exception("ERROR: More than 1 matrixblock found for one of the inputs at a given index");
}
retVal = blk1;
}
if(retVal == null) {
throw new Exception("ERROR: No matrixblock found for one of the inputs at a given index");
}
return retVal;
}
@Override
public Tuple2> call(
Tuple2, Iterable>>, Iterable>> kv)
throws Exception {
MatrixBlock in1 = null; MatrixBlock in2 = null; MatrixBlock in3 = null;
for(Tuple2, Iterable> blks : kv._2._1) {
in1 = extractBlock(blks._1, in1);
in2 = extractBlock(blks._2, in2);
}
in3 = extractBlock(kv._2._2, in3);
// Now return unflatten AL
ArrayList inputs = new ArrayList();
inputs.add(in1); inputs.add(in2); inputs.add(in3);
return new Tuple2>(kv._1, inputs);
}
}
private static class PerformCTableMapSideOperation implements PairFunction>, MatrixIndexes, CTableMap> {
private static final long serialVersionUID = 5348127596473232337L;
Ternary.OperationTypes ctableOp;
double scalar_input2; double scalar_input3;
String instString;
Operator optr;
boolean ignoreZeros;
public PerformCTableMapSideOperation(Ternary.OperationTypes ctableOp, double scalar_input2, double scalar_input3, String instString, Operator optr, boolean ignoreZeros) {
this.ctableOp = ctableOp;
this.scalar_input2 = scalar_input2;
this.scalar_input3 = scalar_input3;
this.instString = instString;
this.optr = optr;
this.ignoreZeros = ignoreZeros;
}
private void expectedALSize(int length, ArrayList al) throws Exception {
if(al.size() != length) {
throw new Exception("Expected arraylist of size:" + length + ", but found " + al.size());
}
}
@Override
public Tuple2 call(
Tuple2> kv) throws Exception {
CTableMap ctableResult = new CTableMap();
MatrixBlock ctableResultBlock = null;
IndexedMatrixValue in1, in2, in3 = null;
in1 = new IndexedMatrixValue(kv._1, kv._2.get(0));
MatrixBlock matBlock1 = kv._2.get(0);
switch( ctableOp )
{
case CTABLE_TRANSFORM: {
in2 = new IndexedMatrixValue(kv._1, kv._2.get(1));
in3 = new IndexedMatrixValue(kv._1, kv._2.get(2));
expectedALSize(3, kv._2);
if(in1==null || in2==null || in3 == null )
break;
else
OperationsOnMatrixValues.performTernary(in1.getIndexes(), in1.getValue(), in2.getIndexes(), in2.getValue(),
in3.getIndexes(), in3.getValue(), ctableResult, ctableResultBlock, optr);
break;
}
case CTABLE_TRANSFORM_SCALAR_WEIGHT:
case CTABLE_EXPAND_SCALAR_WEIGHT:
{
// 3rd input is a scalar
in2 = new IndexedMatrixValue(kv._1, kv._2.get(1));
expectedALSize(2, kv._2);
if(in1==null || in2==null )
break;
else
matBlock1.ternaryOperations((SimpleOperator)optr, kv._2.get(1), scalar_input3, ignoreZeros, ctableResult, ctableResultBlock);
break;
}
case CTABLE_TRANSFORM_HISTOGRAM: {
expectedALSize(1, kv._2);
OperationsOnMatrixValues.performTernary(in1.getIndexes(), in1.getValue(), scalar_input2,
scalar_input3, ctableResult, ctableResultBlock, optr);
break;
}
case CTABLE_TRANSFORM_WEIGHTED_HISTOGRAM: {
// 2nd and 3rd inputs are scalars
expectedALSize(2, kv._2);
in3 = new IndexedMatrixValue(kv._1, kv._2.get(1)); // Note: kv._2.get(1), not kv._2.get(2)
if(in1==null || in3==null)
break;
else
OperationsOnMatrixValues.performTernary(in1.getIndexes(), in1.getValue(), scalar_input2,
in3.getIndexes(), in3.getValue(), ctableResult, ctableResultBlock, optr);
break;
}
default:
throw new DMLRuntimeException("Unrecognized opcode in Tertiary Instruction: " + instString);
}
return new Tuple2(kv._1, ctableResult);
}
}
private static class MapMBIntoAL implements PairFunction, MatrixIndexes, ArrayList> {
private static final long serialVersionUID = 2068398913653350125L;
@Override
public Tuple2> call(
Tuple2 kv) throws Exception {
ArrayList retVal = new ArrayList();
retVal.add(kv._2);
return new Tuple2>(kv._1, retVal);
}
}
private static class ExtractBinaryCellsFromCTable implements PairFlatMapFunction {
private static final long serialVersionUID = -5933677686766674444L;
@SuppressWarnings("deprecation")
@Override
public Iterable> call(CTableMap ctableMap)
throws Exception {
ArrayList> retVal = new ArrayList>();
for(LLDoubleEntry ijv : ctableMap.entrySet()) {
long i = ijv.key1;
long j = ijv.key2;
double v = ijv.value;
// retVal.add(new Tuple2(blockIndexes, cell));
retVal.add(new Tuple2(new MatrixIndexes(i, j), v));
}
return retVal;
}
}
private static class ConvertToBinaryCell implements PairFunction, MatrixIndexes, MatrixCell> {
private static final long serialVersionUID = 7481186480851982800L;
@Override
public Tuple2 call(
Tuple2 kv) throws Exception {
MatrixCell cell = new MatrixCell(kv._2().doubleValue());
return new Tuple2(kv._1(), cell);
}
}
private static class FilterCells implements Function, Boolean> {
private static final long serialVersionUID = 108448577697623247L;
long rlen; long clen;
public FilterCells(long rlen, long clen) {
this.rlen = rlen;
this.clen = clen;
}
@Override
public Boolean call(Tuple2 kv) throws Exception {
if(kv._1.getRowIndex() <= 0 || kv._1.getColumnIndex() <= 0) {
throw new Exception("Incorrect cell values in TernarySPInstruction:" + kv._1);
}
if(kv._1.getRowIndex() <= rlen && kv._1.getColumnIndex() <= clen) {
return true;
}
return false;
}
}
}