// Source: org.apache.sysml.hops.AggUnaryOp — Apache SystemML ("Declarative Machine Learning").
// (Repository-listing text removed; it was web-page residue, not part of the source file.)
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.sysml.hops;
import org.apache.sysml.api.DMLScript;
import org.apache.sysml.conf.ConfigurationManager;
import org.apache.sysml.hops.AggBinaryOp.SparkAggType;
import org.apache.sysml.hops.Hop.MultiThreadedHop;
import org.apache.sysml.hops.rewrite.HopRewriteUtils;
import org.apache.sysml.lops.Aggregate;
import org.apache.sysml.lops.Aggregate.OperationTypes;
import org.apache.sysml.lops.Binary;
import org.apache.sysml.lops.Group;
import org.apache.sysml.lops.Lop;
import org.apache.sysml.lops.LopsException;
import org.apache.sysml.lops.PartialAggregate;
import org.apache.sysml.lops.PartialAggregate.DirectionTypes;
import org.apache.sysml.lops.TernaryAggregate;
import org.apache.sysml.lops.UAggOuterChain;
import org.apache.sysml.lops.UnaryCP;
import org.apache.sysml.lops.LopProperties.ExecType;
import org.apache.sysml.parser.Expression.DataType;
import org.apache.sysml.parser.Expression.ValueType;
import org.apache.sysml.runtime.controlprogram.context.SparkExecutionContext;
import org.apache.sysml.runtime.matrix.MatrixCharacteristics;
/* Aggregate unary (cell) operation: Sum (aij), col_sum, row_sum
* Properties:
* Symbol: +, min, max, ...
* 1 Operand
*
* Semantic: generate indices, align, aggregate
*/
public class AggUnaryOp extends Hop implements MultiThreadedHop
{
    // Enables compiling unary aggregates without a final aggregation stage when a
    // single block already covers the aggregated dimension (see requiresAggregation).
    private static final boolean ALLOW_UNARYAGG_WO_FINAL_AGG = true;

    private AggOp _op;            // aggregation function (e.g., SUM, MIN, MAX, MEAN, VAR, PROD, MAXINDEX)
    private Direction _direction; // aggregation direction: RowCol (full), Row, or Col
    private int _maxNumThreads = -1; //-1 for unlimited

    private AggUnaryOp() {
        //default constructor for clone
    }

    /**
     * Creates a unary aggregate hop over a single input hop.
     *
     * @param l   hop name/label
     * @param dt  output data type (MATRIX or SCALAR)
     * @param vt  output value type
     * @param o   aggregation operator
     * @param idx aggregation direction
     * @param inp input hop; this hop registers itself as a parent of {@code inp}
     */
    public AggUnaryOp(String l, DataType dt, ValueType vt, AggOp o, Direction idx, Hop inp)
    {
        super(l, dt, vt);
        _op = o;
        _direction = idx;
        getInput().add(0, inp);
        inp.getParent().add(this);
    }

    /** @return the aggregation operator of this hop */
    public AggOp getOp()
    {
        return _op;
    }

    /** @param op new aggregation operator */
    public void setOp(AggOp op)
    {
        _op = op;
    }

    /** @return the aggregation direction (RowCol, Row, or Col) */
    public Direction getDirection()
    {
        return _direction;
    }

    /** @param direction new aggregation direction */
    public void setDirection(Direction direction)
    {
        _direction = direction;
    }

    @Override
    public void setMaxNumThreads( int k ) {
        _maxNumThreads = k;
    }

    @Override
    public int getMaxNumThreads() {
        return _maxNumThreads;
    }

    /**
     * Constructs the runtime lops for this unary aggregate, dispatching on the
     * chosen execution type (CP/GPU, MR, or SPARK). Special-case rewrites
     * (fused ternary aggregate, unary aggregate over outer-vector binary ops)
     * are applied before falling back to a generic {@link PartialAggregate};
     * a cast-to-scalar lop is appended when the output data type is SCALAR.
     *
     * @return the root lop of the constructed lop dag (cached across calls)
     * @throws HopsException if lop construction fails
     * @throws LopsException if lop construction fails
     */
    @Override
    public Lop constructLops()
        throws HopsException, LopsException
    {
        //return already created lops
        if( getLops() != null )
            return getLops();

        try
        {
            ExecType et = optFindExecType();
            Hop input = getInput().get(0);

            if ( et == ExecType.CP )
            {
                Lop agg1 = null;
                if( isTernaryAggregateRewriteApplicable(et) ) {
                    //fused sum(A*B*C)-style operator (see rewrite applicability check)
                    agg1 = constructLopsTernaryAggregateRewrite(et);
                }
                else if( isUnaryAggregateOuterCPRewriteApplicable() )
                {
                    OperationTypes op = HopsAgg2Lops.get(_op);
                    DirectionTypes dir = HopsDirection2Lops.get(_direction);

                    //fused chain over the outer-vector comparison: consumes both
                    //binary-op inputs directly instead of materializing the outer product
                    BinaryOp binput = (BinaryOp)getInput().get(0);
                    agg1 = new UAggOuterChain( binput.getInput().get(0).constructLops(),
                            binput.getInput().get(1).constructLops(), op, dir,
                            HopsOpOp2LopsB.get(binput.getOp()), DataType.MATRIX, getValueType(), ExecType.CP);
                    PartialAggregate.setDimensionsBasedOnDirection(agg1, getDim1(), getDim2(), input.getRowsInBlock(), input.getColsInBlock(), dir);

                    if (getDataType() == DataType.SCALAR) {
                        UnaryCP unary1 = new UnaryCP(agg1, HopsOpOp1LopsUS.get(OpOp1.CAST_AS_SCALAR),
                                getDataType(), getValueType());
                        unary1.getOutputParameters().setDimensions(0, 0, 0, 0, -1);
                        setLineNumbers(unary1);
                        setLops(unary1);
                    }
                }
                else { //general case
                    int k = OptimizerUtils.getConstrainedNumThreads(_maxNumThreads);
                    //opportunistic GPU offload if the accelerator is enabled and the
                    //estimate fits the GPU budget (or offload is forced)
                    if(DMLScript.USE_ACCELERATOR && (DMLScript.FORCE_ACCELERATOR || getMemEstimate() < OptimizerUtils.GPU_MEMORY_BUDGET)) {
                        // Only implemented methods for GPU
                        if ( (_op == AggOp.SUM    && (_direction == Direction.RowCol || _direction == Direction.Row || _direction == Direction.Col))
                            || (_op == AggOp.SUM_SQ && (_direction == Direction.RowCol || _direction == Direction.Row || _direction == Direction.Col))
                            || (_op == AggOp.MAX    && (_direction == Direction.RowCol || _direction == Direction.Row || _direction == Direction.Col))
                            || (_op == AggOp.MIN    && (_direction == Direction.RowCol || _direction == Direction.Row || _direction == Direction.Col))
                            || (_op == AggOp.MEAN   && (_direction == Direction.RowCol || _direction == Direction.Row || _direction == Direction.Col))
                            || (_op == AggOp.VAR    && (_direction == Direction.RowCol || _direction == Direction.Row || _direction == Direction.Col))
                            || (_op == AggOp.PROD   && (_direction == Direction.RowCol))){
                            et = ExecType.GPU;
                            k = 1;
                        }
                    }
                    agg1 = new PartialAggregate(input.constructLops(),
                            HopsAgg2Lops.get(_op), HopsDirection2Lops.get(_direction), getDataType(),getValueType(), et, k);
                }

                setOutputDimensions(agg1);
                setLineNumbers(agg1);
                //NOTE(review): for the outer-chain scalar case above this overwrites
                //the UnaryCP cast lop set earlier — confirm that is intended
                setLops(agg1);

                if (getDataType() == DataType.SCALAR) {
                    agg1.getOutputParameters().setDimensions(1, 1, getRowsInBlock(), getColsInBlock(), getNnz());
                }
            }
            else if( et == ExecType.MR )
            {
                OperationTypes op = HopsAgg2Lops.get(_op);
                DirectionTypes dir = HopsDirection2Lops.get(_direction);

                //unary aggregate operation
                Lop transform1 = null;
                if( isUnaryAggregateOuterRewriteApplicable() )
                {
                    BinaryOp binput = (BinaryOp)getInput().get(0);
                    transform1 = new UAggOuterChain( binput.getInput().get(0).constructLops(),
                            binput.getInput().get(1).constructLops(), op, dir,
                            HopsOpOp2LopsB.get(binput.getOp()), DataType.MATRIX, getValueType(), ExecType.MR);
                    PartialAggregate.setDimensionsBasedOnDirection(transform1, getDim1(), getDim2(), input.getRowsInBlock(), input.getColsInBlock(), dir);
                }
                else //default
                {
                    transform1 = new PartialAggregate(input.constructLops(), op, dir, DataType.MATRIX, getValueType());
                    ((PartialAggregate) transform1).setDimensionsBasedOnDirection(getDim1(), getDim2(), input.getRowsInBlock(), input.getColsInBlock());
                }
                setLineNumbers(transform1);

                //aggregation if required: group partial results by block index, then
                //aggregate with correction handling; outer chains always aggregate
                Lop aggregate = null;
                Group group1 = null;
                Aggregate agg1 = null;
                if( requiresAggregation(input, _direction) || transform1 instanceof UAggOuterChain )
                {
                    group1 = new Group(transform1, Group.OperationTypes.Sort, DataType.MATRIX, getValueType());
                    group1.getOutputParameters().setDimensions(getDim1(), getDim2(), input.getRowsInBlock(), input.getColsInBlock(), getNnz());
                    setLineNumbers(group1);

                    agg1 = new Aggregate(group1, HopsAgg2Lops.get(_op), DataType.MATRIX, getValueType(), et);
                    agg1.getOutputParameters().setDimensions(getDim1(), getDim2(), input.getRowsInBlock(), input.getColsInBlock(), getNnz());
                    agg1.setupCorrectionLocation(PartialAggregate.getCorrectionLocation(op,dir));
                    setLineNumbers(agg1);

                    aggregate = agg1;
                }
                else
                {
                    //single-block aggregation: corrections (e.g., Kahan carry) can be dropped
                    ((PartialAggregate) transform1).setDropCorrection();
                    aggregate = transform1;
                }

                setLops(aggregate);

                //cast if required
                if (getDataType() == DataType.SCALAR) {
                    // Set the dimensions of PartialAggregate LOP based on the
                    // direction in which aggregation is performed
                    PartialAggregate.setDimensionsBasedOnDirection(transform1, input.getDim1(), input.getDim2(),
                            input.getRowsInBlock(), input.getColsInBlock(), dir);

                    if( group1 != null && agg1 != null ) { //if aggregation required
                        group1.getOutputParameters().setDimensions(input.getDim1(), input.getDim2(),
                                input.getRowsInBlock(), input.getColsInBlock(), getNnz());
                        agg1.getOutputParameters().setDimensions(1, 1,
                                input.getRowsInBlock(), input.getColsInBlock(), getNnz());
                    }

                    UnaryCP unary1 = new UnaryCP(
                            aggregate, HopsOpOp1LopsUS.get(OpOp1.CAST_AS_SCALAR),
                            getDataType(), getValueType());
                    unary1.getOutputParameters().setDimensions(0, 0, 0, 0, -1);
                    setLineNumbers(unary1);
                    setLops(unary1);
                }
            }
            else if( et == ExecType.SPARK )
            {
                OperationTypes op = HopsAgg2Lops.get(_op);
                DirectionTypes dir = HopsDirection2Lops.get(_direction);

                //unary aggregate
                if( isTernaryAggregateRewriteApplicable(et) )
                {
                    Lop aggregate = constructLopsTernaryAggregateRewrite(et);
                    setOutputDimensions(aggregate); //0x0 (scalar)
                    setLineNumbers(aggregate);
                    setLops(aggregate);
                }
                else if( isUnaryAggregateOuterSPRewriteApplicable() )
                {
                    BinaryOp binput = (BinaryOp)getInput().get(0);
                    Lop transform1 = new UAggOuterChain( binput.getInput().get(0).constructLops(),
                            binput.getInput().get(1).constructLops(), op, dir,
                            HopsOpOp2LopsB.get(binput.getOp()), DataType.MATRIX, getValueType(), ExecType.SPARK);
                    PartialAggregate.setDimensionsBasedOnDirection(transform1, getDim1(), getDim2(), input.getRowsInBlock(), input.getColsInBlock(), dir);
                    setLineNumbers(transform1);
                    setLops(transform1);

                    if (getDataType() == DataType.SCALAR) {
                        UnaryCP unary1 = new UnaryCP(transform1, HopsOpOp1LopsUS.get(OpOp1.CAST_AS_SCALAR),
                                getDataType(), getValueType());
                        unary1.getOutputParameters().setDimensions(0, 0, 0, 0, -1);
                        setLineNumbers(unary1);
                        setLops(unary1);
                    }
                }
                else //default
                {
                    boolean needAgg = requiresAggregation(input, _direction);
                    SparkAggType aggtype = getSparkUnaryAggregationType(needAgg);

                    PartialAggregate aggregate = new PartialAggregate(input.constructLops(),
                            HopsAgg2Lops.get(_op), HopsDirection2Lops.get(_direction), DataType.MATRIX, getValueType(), aggtype, et);
                    aggregate.setDimensionsBasedOnDirection(getDim1(), getDim2(), input.getRowsInBlock(), input.getColsInBlock());
                    setLineNumbers(aggregate);
                    setLops(aggregate);

                    if (getDataType() == DataType.SCALAR) {
                        UnaryCP unary1 = new UnaryCP(aggregate, HopsOpOp1LopsUS.get(OpOp1.CAST_AS_SCALAR),
                                getDataType(), getValueType());
                        unary1.getOutputParameters().setDimensions(0, 0, 0, 0, -1);
                        setLineNumbers(unary1);
                        setLops(unary1);
                    }
                }
            }
        }
        catch (Exception e) {
            throw new HopsException(this.printErrorLocation() + "In AggUnary Hop, error constructing Lops " , e);
        }

        //add reblock/checkpoint lops if necessary
        constructAndSetLopsDataFlowProperties();

        //return created lops
        return getLops();
    }

    @Override
    public String getOpString() {
        //ua - unary aggregate, for consistency with runtime
        String s = "ua(" +
                HopsAgg2String.get(_op) +
                HopsDirection2String.get(_direction) + ")";
        return s;
    }

    @Override
    public boolean allowsAllExecTypes()
    {
        return true;
    }

    /**
     * Estimates the output memory as an exact-sparsity matrix estimate of the
     * given dimensions.
     */
    @Override
    protected double computeOutputMemEstimate( long dim1, long dim2, long nnz )
    {
        double sparsity = OptimizerUtils.getSparsity(dim1, dim2, nnz);
        return OptimizerUtils.estimateSizeExactSparsity(dim1, dim2, sparsity);
    }

    /**
     * Estimates intermediate memory, accounting for per-operator correction
     * rows/columns (sum/mean/var) and temporary count arrays (min/max).
     */
    @Override
    protected double computeIntermediateMemEstimate( long dim1, long dim2, long nnz )
    {
        //default: no additional memory required
        double val = 0;
        double sparsity = OptimizerUtils.getSparsity(dim1, dim2, nnz);

        switch( _op ) //see MatrixAggLib for runtime operations
        {
            case MAX:
            case MIN:
                //worst-case: column-wise, sparse (temp int count arrays)
                if( _direction == Direction.Col )
                    val = dim2 * OptimizerUtils.INT_SIZE;
                break;
            case SUM:
            case SUM_SQ:
                //worst-case correction LASTROW / LASTCOLUMN
                if( _direction == Direction.Col ) //(potentially sparse)
                    val = OptimizerUtils.estimateSizeExactSparsity(1, dim2, sparsity);
                else if( _direction == Direction.Row ) //(always dense)
                    val = OptimizerUtils.estimateSizeExactSparsity(dim1, 1, 1.0);
                break;
            case MEAN:
                //worst-case correction LASTTWOROWS / LASTTWOCOLUMNS
                if( _direction == Direction.Col ) //(potentially sparse)
                    val = OptimizerUtils.estimateSizeExactSparsity(2, dim2, sparsity);
                else if( _direction == Direction.Row ) //(always dense)
                    val = OptimizerUtils.estimateSizeExactSparsity(dim1, 2, 1.0);
                break;
            case VAR:
                //worst-case correction LASTFOURROWS / LASTFOURCOLUMNS
                if( _direction == Direction.Col ) //(potentially sparse)
                    val = OptimizerUtils.estimateSizeExactSparsity(4, dim2, sparsity);
                else if( _direction == Direction.Row ) //(always dense)
                    val = OptimizerUtils.estimateSizeExactSparsity(dim1, 4, 1.0);
                break;
            case MAXINDEX:
            case MININDEX:
                Hop hop = getInput().get(0);
                //outer CP rewrite keeps three dense row vectors of the input's width
                if(isUnaryAggregateOuterCPRewriteApplicable())
                    val = 3 * OptimizerUtils.estimateSizeExactSparsity(1, hop._dim2, 1.0);
                else
                    //worst-case correction LASTCOLUMN
                    val = OptimizerUtils.estimateSizeExactSparsity(dim1, 1, 1.0);
                break;
            default:
                //no intermediate memory consumption
                val = 0;
        }

        return val;
    }

    /**
     * Infers output dimensions from the input's characteristics: colSums yields
     * 1 x ncol, rowSums yields nrow x 1; null if direction/dims are unknown.
     */
    @Override
    protected long[] inferOutputCharacteristics( MemoTable memo )
    {
        long[] ret = null;

        Hop input = getInput().get(0);
        MatrixCharacteristics mc = memo.getAllInputStats(input);
        if( _direction == Direction.Col && mc.colsKnown() )
            ret = new long[]{1, mc.getCols(), -1};
        else if( _direction == Direction.Row && mc.rowsKnown() )
            ret = new long[]{mc.getRows(), 1, -1};

        return ret;
    }

    /**
     * Chooses the execution type: forced platform first, then memory- or
     * threshold-based selection, followed by a spark-specific refinement that
     * pulls cheap unary aggregates into spark to reduce data transfer.
     */
    @Override
    protected ExecType optFindExecType() throws HopsException {
        checkAndSetForcedPlatform();

        ExecType REMOTE = OptimizerUtils.isSparkExecutionMode() ? ExecType.SPARK : ExecType.MR;

        //forced / memory-based / threshold-based decision
        if( _etypeForced != null )
        {
            _etype = _etypeForced;
        }
        else
        {
            if ( OptimizerUtils.isMemoryBasedOptLevel() )
            {
                _etype = findExecTypeByMemEstimate();
            }
            // Choose CP, if the input dimensions are below threshold or if the input is a vector
            else if ( getInput().get(0).areDimsBelowThreshold() || getInput().get(0).isVector() )
            {
                _etype = ExecType.CP;
            }
            else
            {
                _etype = REMOTE;
            }

            //check for valid CP dimensions and matrix size
            checkAndSetInvalidCPDimsAndSize();
        }

        //spark-specific decision refinement (execute unary aggregate w/ spark input and
        //single parent also in spark because it's likely cheap and reduces data transfer)
        if( _etype == ExecType.CP && _etypeForced != ExecType.CP
            && !(getInput().get(0) instanceof DataOp)  //input is not checkpoint
            && (getInput().get(0).getParent().size()==1 //uagg is only parent, or
            || !requiresAggregation(getInput().get(0), _direction)) //w/o agg
            && getInput().get(0).optFindExecType() == ExecType.SPARK )
        {
            //pull unary aggregate into spark
            _etype = ExecType.SPARK;
        }

        //mark for recompile (forever)
        if( ConfigurationManager.isDynamicRecompilation() && !dimsKnown(true) && _etype==REMOTE ) {
            setRequiresRecompile();
        }

        return _etype;
    }

    /**
     * Determines whether a final aggregation stage is required; it is not when
     * a single block already spans the aggregated dimension.
     */
    private boolean requiresAggregation( Hop input, Direction dir )
    {
        if( !ALLOW_UNARYAGG_WO_FINAL_AGG )
            return false; //customization not allowed

        boolean noAggRequired =
              ( input.getDim1()>1 && input.getDim1()<=input.getRowsInBlock() && dir==Direction.Col ) //e.g., colSums(X) with nrow(X)<=1000
            ||( input.getDim2()>1 && input.getDim2()<=input.getColsInBlock() && dir==Direction.Row ); //e.g., rowSums(X) with ncol(X)<=1000

        return !noAggRequired;
    }

    /**
     * Maps the aggregation requirement to a spark aggregation type:
     * NONE (no agg), SINGLE_BLOCK (output fits one block), or MULTI_BLOCK.
     */
    private SparkAggType getSparkUnaryAggregationType( boolean agg )
    {
        if( !agg )
            return SparkAggType.NONE;

        if( getDataType()==DataType.SCALAR //in case of scalars the block dims are not set
            || dimsKnown() && getDim1()<=getRowsInBlock() && getDim2()<=getColsInBlock() )
            return SparkAggType.SINGLE_BLOCK;
        else
            return SparkAggType.MULTI_BLOCK;
    }

    /**
     * Checks applicability of the fused ternary-aggregate rewrite, i.e.,
     * sum over a chain of element-wise multiplies of equally-sized matrices.
     *
     * @param et target execution type (unused for applicability of the pattern
     *           itself; MR inputs are excluded explicitly)
     * @return true if the rewrite may be applied
     * @throws HopsException if exec-type resolution of the input fails
     */
    private boolean isTernaryAggregateRewriteApplicable(ExecType et)
        throws HopsException
    {
        boolean ret = false;

        //currently we support only sum over binary multiply but potentially
        //it can be generalized to any RC aggregate over two common binary operations
        if( OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES && _op == AggOp.SUM &&
            (_direction == Direction.RowCol || _direction == Direction.Col) )
        {
            Hop input1 = getInput().get(0);
            if( input1.getParent().size() == 1 && //sum single consumer
                input1 instanceof BinaryOp && ((BinaryOp)input1).getOp()==OpOp2.MULT
                // As unary agg instruction is not implemented in MR and since MR is in maintenance mode, postponed it.
                && input1.optFindExecType() != ExecType.MR)
            {
                Hop input11 = input1.getInput().get(0);
                Hop input12 = input1.getInput().get(1);

                if( input11 instanceof BinaryOp && ((BinaryOp)input11).getOp()==OpOp2.MULT ) {
                    //ternary, arbitrary matrices but no mv/outer operations.
                    ret = HopRewriteUtils.isEqualSize(input11.getInput().get(0), input1)
                        && HopRewriteUtils.isEqualSize(input11.getInput().get(1), input1)
                        && HopRewriteUtils.isEqualSize(input12, input1);
                }
                else if( input12 instanceof BinaryOp && ((BinaryOp)input12).getOp()==OpOp2.MULT ) {
                    //ternary, arbitrary matrices but no mv/outer operations.
                    ret = HopRewriteUtils.isEqualSize(input12.getInput().get(0), input1)
                        && HopRewriteUtils.isEqualSize(input12.getInput().get(1), input1)
                        && HopRewriteUtils.isEqualSize(input11, input1);
                }
                else {
                    //binary, arbitrary matrices but no mv/outer operations.
                    ret = HopRewriteUtils.isEqualSize(input11, input12);
                }
            }
        }

        return ret;
    }

    /** @return true if the operator is one of the six comparison operators */
    private static boolean isCompareOperator(OpOp2 opOp2)
    {
        return (opOp2 == OpOp2.LESS || opOp2 == OpOp2.LESSEQUAL
            || opOp2 == OpOp2.GREATER || opOp2 == OpOp2.GREATEREQUAL
            || opOp2 == OpOp2.EQUAL || opOp2 == OpOp2.NOTEQUAL);
    }

    /**
     * Checks applicability of the fused outer-vector unary-aggregate rewrite
     * for MR, including whether the rhs broadcast fits the remote map budget
     * (held twice for some comparison-based aggregates to enable binary search).
     */
    private boolean isUnaryAggregateOuterRewriteApplicable()
    {
        boolean ret = false;
        Hop input = getInput().get(0);

        if( input instanceof BinaryOp && ((BinaryOp)input).isOuterVectorOperator() )
        {
            //for special cases, we need to hold the broadcast twice in order to allow for
            //an efficient binary search over a plain java array
            double factor = (isCompareOperator(((BinaryOp)input).getOp())
                && (_direction == Direction.Row || _direction == Direction.Col || _direction == Direction.RowCol)
                && (_op == AggOp.SUM)) ? 2.0 : 1.0;

            factor += (isCompareOperator(((BinaryOp)input).getOp())
                && (_direction == Direction.Row || _direction == Direction.Col)
                && (_op == AggOp.MAXINDEX || _op == AggOp.MININDEX))
                ? 1.0 : 0.0;

            //note: memory constraint only needs to take the rhs into account because the output
            //is guaranteed to be an aggregate of <=16KB
            Hop right = input.getInput().get(1);
            if( (right.dimsKnown() && factor*OptimizerUtils.estimateSize(right.getDim1(), right.getDim2())
                < OptimizerUtils.getRemoteMemBudgetMap(true)) //dims known and estimate fits
                ||(!right.dimsKnown() && factor*right.getOutputMemEstimate()
                < OptimizerUtils.getRemoteMemBudgetMap(true)))//dims unknown but worst-case estimate fits
            {
                ret = true;
            }
        }

        return ret;
    }

    /**
     * This will check if there is sufficient memory locally (twice the size of second matrix, for original and sort data), and remotely (size of second matrix (sorted data)).
     * @return true if sufficient memory
     */
    private boolean isUnaryAggregateOuterSPRewriteApplicable()
    {
        boolean ret = false;
        Hop input = getInput().get(0);

        if( input instanceof BinaryOp && ((BinaryOp)input).isOuterVectorOperator() )
        {
            //note: both cases (partitioned matrix, and sorted double array), require to
            //fit the broadcast twice into the local memory budget. Also, the memory
            //constraint only needs to take the rhs into account because the output is
            //guaranteed to be an aggregate of <=16KB
            Hop right = input.getInput().get(1);

            double size = right.dimsKnown() ?
                OptimizerUtils.estimateSize(right.getDim1(), right.getDim2()) : //dims known and estimate fits
                right.getOutputMemEstimate(); //dims unknown but worst-case estimate fits

            if(_op == AggOp.MAXINDEX || _op == AggOp.MININDEX){
                double memBudgetExec = SparkExecutionContext.getBroadcastMemoryBudget();
                double memBudgetLocal = OptimizerUtils.getLocalMemBudget();

                //basic requirement: the broadcast needs to fit twice in the remote broadcast memory
                //and local memory budget because we have to create a partitioned broadcast
                //memory and hand it over to the spark context as in-memory object
                ret = ( 2*size < memBudgetExec && 2*size < memBudgetLocal );
            } else {
                if( OptimizerUtils.checkSparkBroadcastMemoryBudget(size) ) {
                    ret = true;
                }
            }
        }

        return ret;
    }

    /**
     * This will check if this is one of the operator from supported LibMatrixOuterAgg library.
     * It needs to be Outer, aggregator type SUM, RowIndexMin, RowIndexMax and 6 operators &lt;, &lt;=, &gt;, &gt;=, == and !=
     *
     * @return true if unary aggregate outer
     */
    private boolean isUnaryAggregateOuterCPRewriteApplicable()
    {
        boolean ret = false;
        Hop input = getInput().get(0);

        if(( input instanceof BinaryOp && ((BinaryOp)input).isOuterVectorOperator() )
            && (_op == AggOp.MAXINDEX || _op == AggOp.MININDEX || _op == AggOp.SUM)
            && (isCompareOperator(((BinaryOp)input).getOp())))
            ret = true;

        return ret;
    }

    /**
     * Builds the fused TernaryAggregate lop for sum(A*B*C); when the pattern is
     * a plain binary multiply, the third operand is the literal 1.
     *
     * @param et target execution type requested by the caller; the actual lop
     *           exec type follows the input's exec type to avoid OOM
     * @return the constructed ternary aggregate lop
     * @throws HopsException if input lop construction fails
     * @throws LopsException if lop construction fails
     */
    private Lop constructLopsTernaryAggregateRewrite(ExecType et)
        throws HopsException, LopsException
    {
        Hop input1 = getInput().get(0);
        Hop input11 = input1.getInput().get(0);
        Hop input12 = input1.getInput().get(1);

        Lop in1 = null;
        Lop in2 = null;
        Lop in3 = null;

        if( input11 instanceof BinaryOp && ((BinaryOp)input11).getOp()==OpOp2.MULT )
        {
            in1 = input11.getInput().get(0).constructLops();
            in2 = input11.getInput().get(1).constructLops();
            in3 = input12.constructLops();
        }
        else if( input12 instanceof BinaryOp && ((BinaryOp)input12).getOp()==OpOp2.MULT )
        {
            in1 = input11.constructLops();
            in2 = input12.getInput().get(0).constructLops();
            in3 = input12.getInput().get(1).constructLops();
        }
        else
        {
            in1 = input11.constructLops();
            in2 = input12.constructLops();
            in3 = new LiteralOp(1).constructLops();
        }

        //create new ternary aggregate operator
        int k = OptimizerUtils.getConstrainedNumThreads( _maxNumThreads );
        // The execution type of a unary aggregate instruction should depend on the execution type of inputs to avoid OOM
        // Since we only support matrix-vector and not vector-matrix, checking the execution type of input1 should suffice.
        ExecType et_input = input1.optFindExecType();
        DirectionTypes dir = HopsDirection2Lops.get(_direction);

        return new TernaryAggregate(in1, in2, in3, Aggregate.OperationTypes.KahanSum,
                Binary.OperationTypes.MULTIPLY, dir, getDataType(), ValueType.DOUBLE, et_input, k);
    }

    /** Updates output dims from the input per direction (Col: 1 x n, Row: m x 1). */
    @Override
    public void refreshSizeInformation()
    {
        if (getDataType() != DataType.SCALAR)
        {
            Hop input = getInput().get(0);
            if ( _direction == Direction.Col ) //colwise computations
            {
                setDim1(1);
                setDim2(input.getDim2());
            }
            else if ( _direction == Direction.Row )
            {
                setDim1(input.getDim1());
                setDim2(1);
            }
        }
    }

    @Override
    public boolean isTransposeSafe()
    {
        boolean ret = (_direction == Direction.RowCol) && //full aggregate
            (_op == AggOp.SUM || _op == AggOp.SUM_SQ || //valid aggregation functions
             _op == AggOp.MIN || _op == AggOp.MAX ||
             _op == AggOp.PROD || _op == AggOp.MEAN ||
             _op == AggOp.VAR);
        //note: trace and maxindex are not transpose-safe.

        return ret;
    }

    /** Deep-copies this hop's specific attributes on top of the generic clone. */
    @Override
    public Object clone() throws CloneNotSupportedException
    {
        AggUnaryOp ret = new AggUnaryOp();

        //copy generic attributes
        ret.clone(this, false);

        //copy specific attributes
        ret._op = _op;
        ret._direction = _direction;
        ret._maxNumThreads = _maxNumThreads;

        return ret;
    }

    /**
     * Structural equality for common-subexpression elimination: same operator,
     * direction, thread constraint, and identical (reference-equal) input.
     */
    @Override
    public boolean compare( Hop that )
    {
        if( !(that instanceof AggUnaryOp) )
            return false;

        AggUnaryOp that2 = (AggUnaryOp)that;
        return (   _op == that2._op
                && _direction == that2._direction
                && _maxNumThreads == that2._maxNumThreads
                && getInput().get(0) == that2.getInput().get(0));
    }
}