org.apache.sysml.parser.BuiltinFunctionExpression Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of systemml Show documentation
Show all versions of systemml Show documentation
Declarative Machine Learning
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.sysml.parser;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import org.apache.sysml.parser.LanguageException.LanguageErrorCodes;
import org.apache.sysml.runtime.util.ConvolutionUtils;
public class BuiltinFunctionExpression extends DataIdentifier
{
protected Expression[] _args = null;
private BuiltinFunctionOp _opcode;
public BuiltinFunctionExpression(BuiltinFunctionOp bifop, ArrayList args, String fname, int blp, int bcp, int elp, int ecp) {
_kind = Kind.BuiltinFunctionOp;
_opcode = bifop;
this.setAllPositions(fname, blp, bcp, elp, ecp);
args = expandConvolutionArguments(args);
_args = new Expression[args.size()];
for(int i=0; i < args.size(); i++) {
_args[i] = args.get(i).getExpr();
}
}
public BuiltinFunctionExpression(BuiltinFunctionOp bifop, Expression[] args, String fname, int blp, int bcp, int elp, int ecp) {
_kind = Kind.BuiltinFunctionOp;
_opcode = bifop;
_args = new Expression[args.length];
for(int i=0; i < args.length; i++) {
_args[i] = args[i];
}
this.setAllPositions(fname, blp, bcp, elp, ecp);
}
public Expression rewriteExpression(String prefix) throws LanguageException {
Expression[] newArgs = new Expression[_args.length];
for(int i=0; i < _args.length; i++) {
newArgs[i] = _args[i].rewriteExpression(prefix);
}
BuiltinFunctionExpression retVal = new BuiltinFunctionExpression(this._opcode, newArgs,
this.getFilename(), this.getBeginLine(), this.getBeginColumn(), this.getEndLine(), this.getEndColumn());
return retVal;
}
public BuiltinFunctionOp getOpCode() {
return _opcode;
}
public Expression getFirstExpr() {
return (_args.length >= 1 ? _args[0] : null);
}
public Expression getSecondExpr() {
return (_args.length >= 2 ? _args[1] : null);
}
public Expression getThirdExpr() {
return (_args.length >= 3 ? _args[2] : null);
}
public Expression[] getAllExpr(){
return _args;
}
@Override
public void validateExpression(MultiAssignmentStatement stmt, HashMap ids, HashMap constVars, boolean conditional)
throws LanguageException
{
if (this.getFirstExpr() instanceof FunctionCallIdentifier){
raiseValidateError("UDF function call not supported as parameter to built-in function call", false);
}
this.getFirstExpr().validateExpression(ids, constVars, conditional);
if (getSecondExpr() != null){
if (this.getSecondExpr() instanceof FunctionCallIdentifier){
raiseValidateError("UDF function call not supported as parameter to built-in function call", false);
}
getSecondExpr().validateExpression(ids, constVars, conditional);
}
if (getThirdExpr() != null) {
if (this.getThirdExpr() instanceof FunctionCallIdentifier){
raiseValidateError("UDF function call not supported as parameter to built-in function call", false);
}
getThirdExpr().validateExpression(ids, constVars, conditional);
}
_outputs = new Identifier[stmt.getTargetList().size()];
int count = 0;
for (DataIdentifier outParam: stmt.getTargetList()){
DataIdentifier tmp = new DataIdentifier(outParam);
tmp.setAllPositions(this.getFilename(), this.getBeginLine(), this.getBeginColumn(), this.getEndLine(), this.getEndColumn());
_outputs[count++] = tmp;
}
switch (_opcode) {
case QR:
checkNumParameters(1);
checkMatrixParam(getFirstExpr());
// setup output properties
DataIdentifier qrOut1 = (DataIdentifier) getOutputs()[0];
DataIdentifier qrOut2 = (DataIdentifier) getOutputs()[1];
long rows = getFirstExpr().getOutput().getDim1();
long cols = getFirstExpr().getOutput().getDim2();
// Output1 - Q
qrOut1.setDataType(DataType.MATRIX);
qrOut1.setValueType(ValueType.DOUBLE);
qrOut1.setDimensions(rows, cols);
qrOut1.setBlockDimensions(getFirstExpr().getOutput().getRowsInBlock(), getFirstExpr().getOutput().getColumnsInBlock());
// Output2 - R
qrOut2.setDataType(DataType.MATRIX);
qrOut2.setValueType(ValueType.DOUBLE);
qrOut2.setDimensions(rows, cols);
qrOut2.setBlockDimensions(getFirstExpr().getOutput().getRowsInBlock(), getFirstExpr().getOutput().getColumnsInBlock());
break;
case LU:
checkNumParameters(1);
checkMatrixParam(getFirstExpr());
// setup output properties
DataIdentifier luOut1 = (DataIdentifier) getOutputs()[0];
DataIdentifier luOut2 = (DataIdentifier) getOutputs()[1];
DataIdentifier luOut3 = (DataIdentifier) getOutputs()[2];
long inrows = getFirstExpr().getOutput().getDim1();
long incols = getFirstExpr().getOutput().getDim2();
if ( inrows != incols ) {
raiseValidateError("LU Decomposition can only be done on a square matrix. Input matrix is rectangular (rows=" + inrows + ", cols="+incols+")", conditional);
}
// Output1 - P
luOut1.setDataType(DataType.MATRIX);
luOut1.setValueType(ValueType.DOUBLE);
luOut1.setDimensions(inrows, inrows);
luOut1.setBlockDimensions(getFirstExpr().getOutput().getRowsInBlock(), getFirstExpr().getOutput().getColumnsInBlock());
// Output2 - L
luOut2.setDataType(DataType.MATRIX);
luOut2.setValueType(ValueType.DOUBLE);
luOut2.setDimensions(inrows, inrows);
luOut2.setBlockDimensions(getFirstExpr().getOutput().getRowsInBlock(), getFirstExpr().getOutput().getColumnsInBlock());
// Output3 - U
luOut3.setDataType(DataType.MATRIX);
luOut3.setValueType(ValueType.DOUBLE);
luOut3.setDimensions(inrows, inrows);
luOut3.setBlockDimensions(getFirstExpr().getOutput().getRowsInBlock(), getFirstExpr().getOutput().getColumnsInBlock());
break;
case EIGEN:
checkNumParameters(1);
checkMatrixParam(getFirstExpr());
// setup output properties
DataIdentifier eigenOut1 = (DataIdentifier) getOutputs()[0];
DataIdentifier eigenOut2 = (DataIdentifier) getOutputs()[1];
if ( getFirstExpr().getOutput().getDim1() != getFirstExpr().getOutput().getDim2() ) {
raiseValidateError("Eigen Decomposition can only be done on a square matrix. Input matrix is rectangular (rows=" + getFirstExpr().getOutput().getDim1() + ", cols="+ getFirstExpr().getOutput().getDim2() +")", conditional);
}
// Output1 - Eigen Values
eigenOut1.setDataType(DataType.MATRIX);
eigenOut1.setValueType(ValueType.DOUBLE);
eigenOut1.setDimensions(getFirstExpr().getOutput().getDim1(), 1);
eigenOut1.setBlockDimensions(getFirstExpr().getOutput().getRowsInBlock(), getFirstExpr().getOutput().getColumnsInBlock());
// Output2 - Eigen Vectors
eigenOut2.setDataType(DataType.MATRIX);
eigenOut2.setValueType(ValueType.DOUBLE);
eigenOut2.setDimensions(getFirstExpr().getOutput().getDim1(), getFirstExpr().getOutput().getDim2());
eigenOut2.setBlockDimensions(getFirstExpr().getOutput().getRowsInBlock(), getFirstExpr().getOutput().getColumnsInBlock());
break;
default: //always unconditional
raiseValidateError("Unknown Builtin Function opcode: " + _opcode, false);
}
}
private ArrayList orderConvolutionParams(ArrayList paramExpression,
int skip) throws LanguageException {
ArrayList newParams = new ArrayList();
for(int i = 0; i < skip; i++)
newParams.add(paramExpression.get(i));
String [] orderedParams = {
"stride1", "stride2", "padding1", "padding2",
"input_shape1", "input_shape2", "input_shape3", "input_shape4",
"filter_shape1", "filter_shape2", "filter_shape3", "filter_shape4"
};
for(int i = 0; i < orderedParams.length; i++) {
boolean found = false;
for(ParameterExpression param : paramExpression) {
if(param.getName() != null && param.getName().equals(orderedParams[i])) {
found = true;
newParams.add(param);
}
}
if(!found) {
throw new LanguageException("Incorrect parameters. Expected " + orderedParams[i] + " to be expanded.");
}
}
return newParams;
}
private ArrayList replaceListParams(ArrayList paramExpression,
String inputVarName, String outputVarName, int startIndex) throws LanguageException {
ArrayList newParamExpression = new ArrayList();
int i = startIndex;
int j = 1; // Assumption: sequential ordering pool_size1, pool_size2
for (ParameterExpression expr : paramExpression) {
if(expr.getName() != null && expr.getName().equals(inputVarName + j)) {
newParamExpression.add(new ParameterExpression(outputVarName + i, expr.getExpr()));
i++; j++;
}
else {
newParamExpression.add(expr);
}
}
return newParamExpression;
}
private ArrayList expandListParams(ArrayList paramExpression,
HashSet paramsToExpand) throws LanguageException {
ArrayList newParamExpressions = new ArrayList();
for(ParameterExpression expr : paramExpression) {
if(paramsToExpand.contains(expr.getName())) {
if(expr.getExpr() instanceof ExpressionList) {
int i = 1;
for(Expression e : ((ExpressionList)expr.getExpr()).getValue()) {
newParamExpressions.add(new ParameterExpression(expr.getName() + i, e));
i++;
}
}
}
else if(expr.getExpr() instanceof ExpressionList) {
throw new LanguageException("The parameter " + expr.getName() + " cannot be list or is not supported for the given function");
}
else {
newParamExpressions.add(expr);
}
}
return newParamExpressions;
}
private ArrayList expandConvolutionArguments(ArrayList paramExpression) {
try {
if(_opcode == BuiltinFunctionOp.CONV2D || _opcode == BuiltinFunctionOp.CONV2D_BACKWARD_FILTER
|| _opcode == BuiltinFunctionOp.CONV2D_BACKWARD_DATA) {
HashSet expand = new HashSet();
expand.add("input_shape"); expand.add("filter_shape"); expand.add("stride"); expand.add("padding");
paramExpression = expandListParams(paramExpression, expand);
paramExpression = orderConvolutionParams(paramExpression, 2);
}
else if(_opcode == BuiltinFunctionOp.MAX_POOL ||
_opcode == BuiltinFunctionOp.MAX_POOL_BACKWARD) {
HashSet expand = new HashSet();
expand.add("input_shape"); expand.add("pool_size"); expand.add("stride"); expand.add("padding");
paramExpression = expandListParams(paramExpression, expand);
paramExpression.add(new ParameterExpression("filter_shape1",
new IntIdentifier(1, getFilename(), getBeginLine(), getBeginColumn(), getEndLine(), getEndColumn())));
paramExpression.add(new ParameterExpression("filter_shape2",
new IntIdentifier(1, getFilename(), getBeginLine(), getBeginColumn(), getEndLine(), getEndColumn())));
paramExpression = replaceListParams(paramExpression, "pool_size", "filter_shape", 3);
if(_opcode == BuiltinFunctionOp.MAX_POOL_BACKWARD)
paramExpression = orderConvolutionParams(paramExpression, 2);
else
paramExpression = orderConvolutionParams(paramExpression, 1);
}
}
catch(LanguageException e) {
throw new RuntimeException(e);
}
return paramExpression;
}
/**
* Validate parse tree : Process BuiltinFunction Expression in an assignment
* statement
*
* @throws LanguageException
*/
public void validateExpression(HashMap ids, HashMap constVars, boolean conditional)
throws LanguageException {
for(int i=0; i < _args.length; i++ ) {
if (_args[i] instanceof FunctionCallIdentifier){
raiseValidateError("UDF function call not supported as parameter to built-in function call", false);
}
_args[i].validateExpression(ids, constVars, conditional);
}
// checkIdentifierParams();
String outputName = getTempName();
DataIdentifier output = new DataIdentifier(outputName);
output.setAllPositions(this.getFilename(), this.getBeginLine(), this.getBeginColumn(), this.getEndLine(), this.getEndColumn());
Identifier id = this.getFirstExpr().getOutput();
output.setProperties(this.getFirstExpr().getOutput());
output.setNnz(-1); //conservatively, cannot use input nnz!
this.setOutput(output);
switch (this.getOpCode()) {
case COLSUM:
case COLMAX:
case COLMIN:
case COLMEAN:
case COLSD:
case COLVAR:
// colSums(X);
checkNumParameters(1);
checkMatrixParam(getFirstExpr());
output.setDataType(DataType.MATRIX);
output.setDimensions(1, id.getDim2());
output.setBlockDimensions (id.getRowsInBlock(), id.getColumnsInBlock());
output.setValueType(id.getValueType());
break;
case ROWSUM:
case ROWMAX:
case ROWINDEXMAX:
case ROWMIN:
case ROWINDEXMIN:
case ROWMEAN:
case ROWSD:
case ROWVAR:
//rowSums(X);
checkNumParameters(1);
checkMatrixParam(getFirstExpr());
output.setDataType(DataType.MATRIX);
output.setDimensions(id.getDim1(), 1);
output.setBlockDimensions (id.getRowsInBlock(), id.getColumnsInBlock());
output.setValueType(id.getValueType());
break;
case SUM:
case PROD:
case TRACE:
case SD:
case VAR:
// sum(X);
checkNumParameters(1);
checkMatrixParam(getFirstExpr());
output.setDataType(DataType.SCALAR);
output.setDimensions(0, 0);
output.setBlockDimensions (0, 0);
output.setValueType(id.getValueType());
break;
case MEAN:
//checkNumParameters(2, false); // mean(Y) or mean(Y,W)
if (getSecondExpr() != null) {
checkNumParameters (2);
}
else {
checkNumParameters (1);
}
checkMatrixParam(getFirstExpr());
if ( getSecondExpr() != null ) {
// x = mean(Y,W);
checkMatchingDimensions(getFirstExpr(), getSecondExpr());
}
output.setDataType(DataType.SCALAR);
output.setDimensions(0, 0);
output.setBlockDimensions (0, 0);
output.setValueType(id.getValueType());
break;
case MIN:
case MAX:
//min(X), min(X,s), min(s,X), min(s,r), min(X,Y)
//unary aggregate
if (getSecondExpr() == null)
{
checkNumParameters(1);
checkMatrixParam(getFirstExpr());
output.setDataType( DataType.SCALAR );
output.setDimensions(0, 0);
output.setBlockDimensions (0, 0);
}
//binary operation
else
{
checkNumParameters(2);
DataType dt1 = getFirstExpr().getOutput().getDataType();
DataType dt2 = getSecondExpr().getOutput().getDataType();
DataType dtOut = (dt1==DataType.MATRIX || dt2==DataType.MATRIX)?
DataType.MATRIX : DataType.SCALAR;
if( dt1==DataType.MATRIX && dt2==DataType.MATRIX )
checkMatchingDimensions(getFirstExpr(), getSecondExpr(), true);
//determine output dimensions
long[] dims = getBinaryMatrixCharacteristics(getFirstExpr(), getSecondExpr());
output.setDataType( dtOut );
output.setDimensions(dims[0], dims[1]);
output.setBlockDimensions (dims[2], dims[3]);
}
output.setValueType(id.getValueType());
break;
case CUMSUM:
case CUMPROD:
case CUMMIN:
case CUMMAX:
// cumsum(X);
checkNumParameters(1);
checkMatrixParam(getFirstExpr());
output.setDataType(DataType.MATRIX);
output.setDimensions(id.getDim1(), id.getDim2());
output.setBlockDimensions (id.getRowsInBlock(), id.getColumnsInBlock());
output.setValueType(id.getValueType());
break;
case CAST_AS_SCALAR:
checkNumParameters(1);
checkMatrixFrameParam(getFirstExpr());
if (( getFirstExpr().getOutput().getDim1() != -1 && getFirstExpr().getOutput().getDim1() !=1) || ( getFirstExpr().getOutput().getDim2() != -1 && getFirstExpr().getOutput().getDim2() !=1)) {
raiseValidateError("dimension mismatch while casting matrix to scalar: dim1: " + getFirstExpr().getOutput().getDim1() + " dim2 " + getFirstExpr().getOutput().getDim2(),
conditional, LanguageErrorCodes.INVALID_PARAMETERS);
}
output.setDataType(DataType.SCALAR);
output.setDimensions(0, 0);
output.setBlockDimensions (0, 0);
output.setValueType(id.getValueType());
break;
case CAST_AS_MATRIX:
checkNumParameters(1);
checkScalarFrameParam(getFirstExpr());
output.setDataType(DataType.MATRIX);
output.setDimensions(id.getDim1(), id.getDim2());
if( getFirstExpr().getOutput().getDataType()==DataType.SCALAR )
output.setDimensions(1, 1); //correction scalars
output.setBlockDimensions(id.getRowsInBlock(), id.getColumnsInBlock());
output.setValueType(id.getValueType());
break;
case CAST_AS_FRAME:
checkNumParameters(1);
checkMatrixScalarParam(getFirstExpr());
output.setDataType(DataType.FRAME);
output.setDimensions(id.getDim1(), id.getDim2());
if( getFirstExpr().getOutput().getDataType()==DataType.SCALAR )
output.setDimensions(1, 1); //correction scalars
output.setBlockDimensions(id.getRowsInBlock(), id.getColumnsInBlock());
output.setValueType(id.getValueType());
break;
case CAST_AS_DOUBLE:
checkNumParameters(1);
checkScalarParam(getFirstExpr());
output.setDataType(DataType.SCALAR);
//output.setDataType(id.getDataType()); //TODO whenever we support multiple matrix value types, currently noop.
output.setDimensions(0, 0);
output.setBlockDimensions (0, 0);
output.setValueType(ValueType.DOUBLE);
break;
case CAST_AS_INT:
checkNumParameters(1);
checkScalarParam(getFirstExpr());
output.setDataType(DataType.SCALAR);
//output.setDataType(id.getDataType()); //TODO whenever we support multiple matrix value types, currently noop.
output.setDimensions(0, 0);
output.setBlockDimensions (0, 0);
output.setValueType(ValueType.INT);
break;
case CAST_AS_BOOLEAN:
checkNumParameters(1);
checkScalarParam(getFirstExpr());
output.setDataType(DataType.SCALAR);
//output.setDataType(id.getDataType()); //TODO whenever we support multiple matrix value types, currently noop.
output.setDimensions(0, 0);
output.setBlockDimensions (0, 0);
output.setValueType(ValueType.BOOLEAN);
break;
case CBIND:
case RBIND:
checkNumParameters(2);
//scalar string append (string concatenation with \n)
if( getFirstExpr().getOutput().getDataType()==DataType.SCALAR ) {
checkScalarParam(getFirstExpr());
checkScalarParam(getSecondExpr());
checkValueTypeParam(getFirstExpr(), ValueType.STRING);
checkValueTypeParam(getSecondExpr(), ValueType.STRING);
}
//matrix append (rbind/cbind)
else {
checkMatrixFrameParam(getFirstExpr());
checkMatrixFrameParam(getSecondExpr());
}
output.setDataType(id.getDataType());
output.setValueType(id.getValueType());
// set output dimensions and validate consistency
long appendDim1 = -1, appendDim2 = -1;
long m1rlen = getFirstExpr().getOutput().getDim1();
long m1clen = getFirstExpr().getOutput().getDim2();
long m2rlen = getSecondExpr().getOutput().getDim1();
long m2clen = getSecondExpr().getOutput().getDim2();
if( getOpCode() == BuiltinFunctionOp.CBIND ) {
if (m1rlen > 0 && m2rlen > 0 && m1rlen!=m2rlen) {
raiseValidateError("inputs to cbind must have same number of rows: input 1 rows: " +
m1rlen+", input 2 rows: "+m2rlen, conditional, LanguageErrorCodes.INVALID_PARAMETERS);
}
appendDim1 = (m1rlen>0) ? m1rlen : m2rlen;
appendDim2 = (m1clen>0 && m2clen>0)? m1clen + m2clen : -1;
}
else if( getOpCode() == BuiltinFunctionOp.RBIND ) {
if (m1clen > 0 && m2clen > 0 && m1clen!=m2clen) {
raiseValidateError("inputs to rbind must have same number of columns: input 1 columns: " +
m1clen+", input 2 columns: "+m2clen, conditional, LanguageErrorCodes.INVALID_PARAMETERS);
}
appendDim1 = (m1rlen>0 && m2rlen>0)? m1rlen + m2rlen : -1;
appendDim2 = (m1clen>0) ? m1clen : m2clen;
}
output.setDimensions(appendDim1, appendDim2);
output.setBlockDimensions (id.getRowsInBlock(), id.getColumnsInBlock());
break;
case PPRED:
// TODO: remove this when ppred has been removed from DML
raiseValidateError("ppred() has been deprecated. Please use the operator directly.", true);
// ppred (X,Y, "<"); ppred (X,y, "<"); ppred (y,X, "<");
checkNumParameters(3);
DataType dt1 = getFirstExpr().getOutput().getDataType();
DataType dt2 = getSecondExpr().getOutput().getDataType();
//check input data types
if( dt1 == DataType.SCALAR && dt2 == DataType.SCALAR ) {
raiseValidateError("ppred() requires at least one matrix input.", conditional, LanguageErrorCodes.INVALID_PARAMETERS);
}
if( dt1 == DataType.MATRIX )
checkMatrixParam(getFirstExpr());
if( dt2 == DataType.MATRIX )
checkMatrixParam(getSecondExpr());
if( dt1==DataType.MATRIX && dt2==DataType.MATRIX ) //dt1==dt2
checkMatchingDimensions(getFirstExpr(), getSecondExpr(), true);
//check operator
if (getThirdExpr().getOutput().getDataType() != DataType.SCALAR ||
getThirdExpr().getOutput().getValueType() != ValueType.STRING)
{
raiseValidateError("Third argument in ppred() is not an operator ", conditional, LanguageErrorCodes.INVALID_PARAMETERS);
}
//determine output dimensions
long[] dims = getBinaryMatrixCharacteristics(getFirstExpr(), getSecondExpr());
output.setDataType(DataType.MATRIX);
output.setDimensions(dims[0], dims[1]);
output.setBlockDimensions(dims[2], dims[3]);
output.setValueType(id.getValueType());
break;
case TRANS:
checkNumParameters(1);
checkMatrixParam(getFirstExpr());
output.setDataType(DataType.MATRIX);
output.setDimensions(id.getDim2(), id.getDim1());
output.setBlockDimensions (id.getColumnsInBlock(), id.getRowsInBlock());
output.setValueType(id.getValueType());
break;
case REV:
checkNumParameters(1);
checkMatrixParam(getFirstExpr());
output.setDataType(DataType.MATRIX);
output.setDimensions(id.getDim1(), id.getDim2());
output.setBlockDimensions (id.getColumnsInBlock(), id.getRowsInBlock());
output.setValueType(id.getValueType());
break;
case DIAG:
checkNumParameters(1);
checkMatrixParam(getFirstExpr());
output.setDataType(DataType.MATRIX);
if( id.getDim2() != -1 ) { //type known
if ( id.getDim2() == 1 )
{
//diag V2M
output.setDimensions(id.getDim1(), id.getDim1());
}
else
{
if (id.getDim1() != id.getDim2()) {
raiseValidateError("Invoking diag on matrix with dimensions ("
+ id.getDim1() + "," + id.getDim2()
+ ") in " + this.toString(), conditional, LanguageErrorCodes.INVALID_PARAMETERS);
}
//diag M2V
output.setDimensions(id.getDim1(), 1);
}
}
output.setBlockDimensions (id.getRowsInBlock(), id.getColumnsInBlock());
output.setValueType(id.getValueType());
break;
case NROW:
case NCOL:
case LENGTH:
checkNumParameters(1);
checkMatrixFrameParam(getFirstExpr());
output.setDataType(DataType.SCALAR);
output.setDimensions(0, 0);
output.setBlockDimensions (0, 0);
output.setValueType(ValueType.INT);
break;
// Contingency tables
case TABLE:
/*
* Allowed #of arguments: 2,3,4,5
* table(A,B)
* table(A,B,W)
* table(A,B,1)
* table(A,B,dim1,dim2)
* table(A,B,W,dim1,dim2)
* table(A,B,1,dim1,dim2)
*/
// Check for validity of input arguments, and setup output dimensions
// First input: is always of type MATRIX
checkMatrixParam(getFirstExpr());
if ( getSecondExpr() == null )
raiseValidateError("Invalid number of arguments to table(): "
+ this.toString(), conditional, LanguageErrorCodes.INVALID_PARAMETERS);
// Second input: can be MATRIX or SCALAR
// cases: table(A,B) or table(A,1)
if ( getSecondExpr().getOutput().getDataType() == DataType.MATRIX)
checkMatchingDimensions(getFirstExpr(),getSecondExpr());
long outputDim1=-1, outputDim2=-1;
switch(_args.length) {
case 2:
// nothing to do
break;
case 3:
// case - table w/ weights
// - weights specified as a matrix: table(A,B,W) or table(A,1,W)
// - weights specified as a scalar: table(A,B,1) or table(A,1,1)
if ( getThirdExpr().getOutput().getDataType() == DataType.MATRIX)
checkMatchingDimensions(getFirstExpr(),getThirdExpr());
break;
case 4:
// case - table w/ output dimensions: table(A,B,dim1,dim2) or table(A,1,dim1,dim2)
// third and fourth arguments must be scalars
if ( getThirdExpr().getOutput().getDataType() != DataType.SCALAR || _args[3].getOutput().getDataType() != DataType.SCALAR ) {
raiseValidateError("Invalid argument types to table(): output dimensions must be of type scalar: "
+ this.toString(), conditional, LanguageErrorCodes.INVALID_PARAMETERS);
}
else {
// constant propagation
if( getThirdExpr() instanceof DataIdentifier && constVars.containsKey(((DataIdentifier)getThirdExpr()).getName()) )
_args[2] = constVars.get(((DataIdentifier)getThirdExpr()).getName());
if( _args[3] instanceof DataIdentifier && constVars.containsKey(((DataIdentifier)_args[3]).getName()) )
_args[3] = constVars.get(((DataIdentifier)_args[3]).getName());
if ( getThirdExpr().getOutput() instanceof ConstIdentifier )
outputDim1 = ((ConstIdentifier) getThirdExpr().getOutput()).getLongValue();
if ( _args[3].getOutput() instanceof ConstIdentifier )
outputDim2 = ((ConstIdentifier) _args[3].getOutput()).getLongValue();
}
break;
case 5:
// case - table w/ weights and output dimensions:
// - table(A,B,W,dim1,dim2) or table(A,1,W,dim1,dim2)
// - table(A,B,1,dim1,dim2) or table(A,1,1,dim1,dim2)
if ( getThirdExpr().getOutput().getDataType() == DataType.MATRIX)
checkMatchingDimensions(getFirstExpr(),getThirdExpr());
// fourth and fifth arguments must be scalars
if ( _args[3].getOutput().getDataType() != DataType.SCALAR || _args[4].getOutput().getDataType() != DataType.SCALAR ) {
raiseValidateError("Invalid argument types to table(): output dimensions must be of type scalar: "
+ this.toString(), conditional, LanguageErrorCodes.INVALID_PARAMETERS);
}
else {
// constant propagation
if( _args[3] instanceof DataIdentifier && constVars.containsKey(((DataIdentifier)_args[3]).getName()) )
_args[3] = constVars.get(((DataIdentifier)_args[3]).getName());
if( _args[4] instanceof DataIdentifier && constVars.containsKey(((DataIdentifier)_args[4]).getName()) )
_args[4] = constVars.get(((DataIdentifier)_args[4]).getName());
if ( _args[3].getOutput() instanceof ConstIdentifier )
outputDim1 = ((ConstIdentifier) _args[3].getOutput()).getLongValue();
if ( _args[4].getOutput() instanceof ConstIdentifier )
outputDim2 = ((ConstIdentifier) _args[4].getOutput()).getLongValue();
}
break;
default:
raiseValidateError("Invalid number of arguments to table(): "
+ this.toString(), conditional, LanguageErrorCodes.INVALID_PARAMETERS);
}
// The dimensions for the output matrix will be known only at the
// run time
output.setDimensions(outputDim1, outputDim2);
output.setBlockDimensions (-1, -1);
output.setDataType(DataType.MATRIX);
output.setValueType(ValueType.DOUBLE);
break;
case MOMENT:
/*
* x = centralMoment(V,order) or xw = centralMoment(V,W,order)
*/
checkMatrixParam(getFirstExpr());
if (getThirdExpr() != null) {
checkNumParameters(3);
checkMatrixParam(getSecondExpr());
checkMatchingDimensions(getFirstExpr(),getSecondExpr());
checkScalarParam(getThirdExpr());
}
else {
checkNumParameters(2);
checkScalarParam(getSecondExpr());
}
// output is a scalar
output.setDataType(DataType.SCALAR);
output.setValueType(ValueType.DOUBLE);
output.setDimensions(0, 0);
output.setBlockDimensions(0,0);
break;
case COV:
/*
* x = cov(V1,V2) or xw = cov(V1,V2,W)
*/
if (getThirdExpr() != null) {
checkNumParameters(3);
}
else {
checkNumParameters(2);
}
checkMatrixParam(getFirstExpr());
checkMatrixParam(getSecondExpr());
checkMatchingDimensions(getFirstExpr(),getSecondExpr());
if (getThirdExpr() != null) {
checkMatrixParam(getThirdExpr());
checkMatchingDimensions(getFirstExpr(), getThirdExpr());
}
// output is a scalar
output.setDataType(DataType.SCALAR);
output.setValueType(ValueType.DOUBLE);
output.setDimensions(0, 0);
output.setBlockDimensions(0,0);
break;
case QUANTILE:
/*
* q = quantile(V1,0.5) computes median in V1
* or Q = quantile(V1,P) computes the vector of quantiles as specified by P
* or qw = quantile(V1,W,0.5) computes median when weights (W) are given
* or QW = quantile(V1,W,P) computes the vector of quantiles as specified by P, when weights (W) are given
*/
if(getThirdExpr() != null) {
checkNumParameters(3);
}
else {
checkNumParameters(2);
}
// first parameter must always be a 1D matrix
check1DMatrixParam(getFirstExpr());
// check for matching dimensions for other matrix parameters
if (getThirdExpr() != null) {
checkMatrixParam(getSecondExpr());
checkMatchingDimensions(getFirstExpr(), getSecondExpr());
}
// set the properties for _output expression
// output dimensions = dimensions of second, if third is null
// = dimensions of the third, otherwise.
if (getThirdExpr() != null) {
output.setDimensions(getThirdExpr().getOutput().getDim1(), getThirdExpr().getOutput()
.getDim2());
output.setBlockDimensions(getThirdExpr().getOutput().getRowsInBlock(),
getThirdExpr().getOutput().getColumnsInBlock());
output.setDataType(getThirdExpr().getOutput().getDataType());
} else {
output.setDimensions(getSecondExpr().getOutput().getDim1(), getSecondExpr().getOutput()
.getDim2());
output.setBlockDimensions(getSecondExpr().getOutput().getRowsInBlock(),
getSecondExpr().getOutput().getColumnsInBlock());
output.setDataType(getSecondExpr().getOutput().getDataType());
}
break;
case INTERQUANTILE:
if (getThirdExpr() != null) {
checkNumParameters(3);
}
else {
checkNumParameters(2);
}
checkMatrixParam(getFirstExpr());
if (getThirdExpr() != null) {
// i.e., second input is weight vector
checkMatrixParam(getSecondExpr());
checkMatchingDimensionsQuantile();
}
if ((getThirdExpr() == null && getSecondExpr().getOutput().getDataType() != DataType.SCALAR)
&& (getThirdExpr() != null && getThirdExpr().getOutput().getDataType() != DataType.SCALAR)) {
raiseValidateError("Invalid parameters to "+ this.getOpCode(), conditional, LanguageErrorCodes.INVALID_PARAMETERS);
}
output.setValueType(id.getValueType());
// output dimensions are unknown
output.setDimensions(-1, -1);
output.setBlockDimensions(-1,-1);
output.setDataType(DataType.MATRIX);
break;
case IQM:
/*
* Usage: iqm = InterQuartileMean(A,W); iqm = InterQuartileMean(A);
*/
if (getSecondExpr() != null){
checkNumParameters(2);
}
else {
checkNumParameters(1);
}
checkMatrixParam(getFirstExpr());
if (getSecondExpr() != null) {
// i.e., second input is weight vector
checkMatrixParam(getSecondExpr());
checkMatchingDimensions(getFirstExpr(), getSecondExpr());
}
// Output is a scalar
output.setValueType(id.getValueType());
output.setDimensions(0, 0);
output.setBlockDimensions(0,0);
output.setDataType(DataType.SCALAR);
break;
case MEDIAN:
if (getSecondExpr() != null){
checkNumParameters(2);
}
else {
checkNumParameters(1);
}
checkMatrixParam(getFirstExpr());
if (getSecondExpr() != null) {
// i.e., second input is weight vector
checkMatrixParam(getSecondExpr());
checkMatchingDimensions(getFirstExpr(), getSecondExpr());
}
// Output is a scalar
output.setValueType(id.getValueType());
output.setDimensions(0, 0);
output.setBlockDimensions(0,0);
output.setDataType(DataType.SCALAR);
break;
case SAMPLE:
{
Expression[] in = getAllExpr();
for(Expression e : in)
checkScalarParam(e);
if (in[0].getOutput().getValueType() != ValueType.DOUBLE && in[0].getOutput().getValueType() != ValueType.INT)
throw new LanguageException("First argument to sample() must be a number.");
if (in[1].getOutput().getValueType() != ValueType.DOUBLE && in[1].getOutput().getValueType() != ValueType.INT)
throw new LanguageException("Second argument to sample() must be a number.");
boolean check = false;
if ( isConstant(in[0]) && isConstant(in[1]) )
{
long range = ((ConstIdentifier)in[0]).getLongValue();
long size = ((ConstIdentifier)in[1]).getLongValue();
if ( range < size )
check = true;
}
if(in.length == 4 )
{
checkNumParameters(4);
if (in[3].getOutput().getValueType() != ValueType.INT)
throw new LanguageException("Fourth arugment, seed, to sample() must be an integer value.");
if (in[2].getOutput().getValueType() != ValueType.BOOLEAN )
throw new LanguageException("Third arugment to sample() must either denote replacement policy (boolean) or seed (integer).");
}
else if(in.length == 3)
{
checkNumParameters(3);
if (in[2].getOutput().getValueType() != ValueType.BOOLEAN
&& in[2].getOutput().getValueType() != ValueType.INT )
throw new LanguageException("Third arugment to sample() must either denote replacement policy (boolean) or seed (integer).");
}
if ( check && in.length >= 3
&& isConstant(in[2])
&& in[2].getOutput().getValueType() == ValueType.BOOLEAN
&& !((BooleanIdentifier)in[2]).getValue() )
throw new LanguageException("Sample (size=" + ((ConstIdentifier)in[0]).getLongValue()
+ ") larger than population (size=" + ((ConstIdentifier)in[1]).getLongValue()
+ ") can only be generated with replacement.");
// Output is a column vector
output.setDataType(DataType.MATRIX);
output.setValueType(ValueType.DOUBLE);
if ( isConstant(in[1]) )
output.setDimensions(((ConstIdentifier)in[1]).getLongValue(), 1);
else
output.setDimensions(-1, 1);
setBlockDimensions(id.getRowsInBlock(), id.getColumnsInBlock());
break;
}
case SEQ:
//basic parameter validation
checkScalarParam(getFirstExpr());
checkScalarParam(getSecondExpr());
if ( getThirdExpr() != null ) {
checkNumParameters(3);
checkScalarParam(getThirdExpr());
}
else
checkNumParameters(2);
// constant propagation (from, to, incr)
if( getFirstExpr() instanceof DataIdentifier && constVars.containsKey(((DataIdentifier)getFirstExpr()).getName()) )
_args[0] = constVars.get(((DataIdentifier)getFirstExpr()).getName());
if( getSecondExpr() instanceof DataIdentifier && constVars.containsKey(((DataIdentifier)getSecondExpr()).getName()) )
_args[1] = constVars.get(((DataIdentifier)getSecondExpr()).getName());
if( getThirdExpr()!=null && getThirdExpr() instanceof DataIdentifier && constVars.containsKey(((DataIdentifier)getThirdExpr()).getName()) )
_args[2] = constVars.get(((DataIdentifier)getThirdExpr()).getName());
// check if dimensions can be inferred
long dim1=-1, dim2=1;
if ( isConstant(getFirstExpr()) && isConstant(getSecondExpr()) && (getThirdExpr() != null ? isConstant(getThirdExpr()) : true) ) {
double from, to, incr;
try {
from = getDoubleValue(getFirstExpr());
to = getDoubleValue(getSecondExpr());
// Setup the value of increment
// default value: 1 if from <= to; -1 if from > to
if(getThirdExpr() == null) {
expandArguments();
_args[2] = new DoubleIdentifier(((from > to) ? -1.0 : 1.0),
this.getFilename(), this.getBeginLine(), this.getBeginColumn(),
this.getEndLine(), this.getEndColumn());
}
incr = getDoubleValue(getThirdExpr());
}
catch (LanguageException e) {
throw new LanguageException("Arguments for seq() must be numeric.");
}
if( (from > to) && (incr >= 0) )
throw new LanguageException("Wrong sign for the increment in a call to seq()");
// Both end points of the range must included i.e., [from,to] both inclusive.
// Note that, "to" is included only if (to-from) is perfectly divisible by incr
// For example, seq(0,1,0.5) produces (0.0 0.5 1.0) whereas seq(0,1,0.6) produces only (0.0 0.6) but not (0.0 0.6 1.0)
dim1 = 1 + (long)Math.floor((to-from)/incr);
}
output.setDataType(DataType.MATRIX);
output.setValueType(ValueType.DOUBLE);
output.setDimensions(dim1, dim2);
output.setBlockDimensions(0, 0);
break;
case SOLVE:
checkNumParameters(2);
checkMatrixParam(getFirstExpr());
checkMatrixParam(getSecondExpr());
if ( getSecondExpr().getOutput().dimsKnown() && !is1DMatrix(getSecondExpr()) )
raiseValidateError("Second input to solve() must be a vector", conditional);
if ( getFirstExpr().getOutput().dimsKnown() && getSecondExpr().getOutput().dimsKnown() &&
getFirstExpr().getOutput().getDim1() != getSecondExpr().getOutput().getDim1() )
raiseValidateError("Dimension mismatch in a call to solve()", conditional);
output.setDataType(DataType.MATRIX);
output.setValueType(ValueType.DOUBLE);
output.setDimensions(getFirstExpr().getOutput().getDim2(), 1);
output.setBlockDimensions(0, 0);
break;
case INVERSE:
checkNumParameters(1);
checkMatrixParam(getFirstExpr());
output.setDataType(DataType.MATRIX);
output.setValueType(ValueType.DOUBLE);
Identifier in = getFirstExpr().getOutput();
if(in.dimsKnown() && in.getDim1() != in.getDim2())
raiseValidateError("Input to inv() must be square matrix -- given: a " + in.getDim1() + "x" + in.getDim2() + " matrix.", conditional);
output.setDimensions(in.getDim1(), in.getDim2());
output.setBlockDimensions(in.getRowsInBlock(), in.getColumnsInBlock());
break;
case CHOLESKY:
{
// A = L%*%t(L) where L is the lower triangular matrix
checkNumParameters(1);
checkMatrixParam(getFirstExpr());
output.setDataType(DataType.MATRIX);
output.setValueType(ValueType.DOUBLE);
Identifier inA = getFirstExpr().getOutput();
if(inA.dimsKnown() && inA.getDim1() != inA.getDim2())
raiseValidateError("Input to cholesky() must be square matrix -- given: a " + inA.getDim1() + "x" + inA.getDim2() + " matrix.", conditional);
output.setDimensions(inA.getDim1(), inA.getDim2());
output.setBlockDimensions(inA.getRowsInBlock(), inA.getColumnsInBlock());
break;
}
case OUTER:
Identifier id2 = this.getSecondExpr().getOutput();
//check input types and characteristics
checkNumParameters(3);
checkMatrixParam(getFirstExpr());
checkMatrixParam(getSecondExpr());
checkScalarParam(getThirdExpr());
checkValueTypeParam(getThirdExpr(), ValueType.STRING);
if( id.getDim2() > 1 || id2.getDim1()>1 ) {
raiseValidateError("Outer vector operations require a common dimension of one: " +
id.getDim1()+"x"+id.getDim2()+" o "+id2.getDim1()+"x"+id2.getDim2()+".", false);
}
//set output characteristics
output.setDataType(id.getDataType());
output.setDimensions(id.getDim1(), id2.getDim2());
output.setBlockDimensions(id.getRowsInBlock(), id.getColumnsInBlock());
break;
case CONV2D:
case CONV2D_BACKWARD_FILTER:
case CONV2D_BACKWARD_DATA:
case MAX_POOL:
case AVG_POOL:
case MAX_POOL_BACKWARD:
{
// At DML level:
// output = conv2d(input, filter, input_shape=[1, 3, 2, 2], filter_shape=[1, 3, 2, 2],
// strides=[1, 1], padding=[1,1])
//
// Converted to following in constructor (only supported NCHW):
// output = conv2d(input, filter, stride1, stride2, padding1,padding2,
// input_shape1, input_shape2, input_shape3, input_shape4,
// filter_shape1, filter_shape2, filter_shape3, filter_shape4)
//
// Similarly,
// conv2d_backward_filter and conv2d_backward_data
Expression input = _args[0]; // For conv2d_backward_filter, this is input and for conv2d_backward_data, this is filter
Expression filter = null;
if(!(this.getOpCode() == BuiltinFunctionOp.MAX_POOL || this.getOpCode() == BuiltinFunctionOp.AVG_POOL)) {
filter = _args[1]; // For conv2d_backward functions, this is dout
checkMatrixParam(filter);
}
output.setDataType(DataType.MATRIX);
output.setValueType(ValueType.DOUBLE);
output.setBlockDimensions(input.getOutput().getRowsInBlock(), input.getOutput().getColumnsInBlock());
// stride1, stride2, padding1, padding2, numImg, numChannels, imgSize, imgSize,
// filter_shape1=1, filter_shape2=1, filterSize/poolSize1, filterSize/poolSize1
if(this.getOpCode() == BuiltinFunctionOp.MAX_POOL_BACKWARD ||
this.getOpCode() == BuiltinFunctionOp.CONV2D_BACKWARD_DATA) {
output.setDimensions(input.getOutput().getDim1(), input.getOutput().getDim2());
}
else if(this.getOpCode() == BuiltinFunctionOp.CONV2D_BACKWARD_FILTER) {
output.setDimensions(filter.getOutput().getDim1(), filter.getOutput().getDim2());
}
else if(this.getOpCode() == BuiltinFunctionOp.CONV2D || this.getOpCode() == BuiltinFunctionOp.MAX_POOL) {
try {
int start = 1;
if(this.getOpCode() == BuiltinFunctionOp.CONV2D) {
start = 2;
}
long stride_h = (long) getDoubleValue(_args[start++]);
long stride_w = (long) getDoubleValue(_args[start++]);
long pad_h = (long) getDoubleValue(_args[start++]);
long pad_w = (long) getDoubleValue(_args[start++]);
start++;
long C = (long) getDoubleValue(_args[start++]);
long H = (long) getDoubleValue(_args[start++]);
long W = (long) getDoubleValue(_args[start++]);
long K = -1;
if(this.getOpCode() == BuiltinFunctionOp.CONV2D) {
K = (long) getDoubleValue(_args[start]);
}
start++; start++;
long R = (long) getDoubleValue(_args[start++]);
long S = (long) getDoubleValue(_args[start++]);
long P = ConvolutionUtils.getP(H, R, stride_h, pad_h);
long Q = ConvolutionUtils.getP(W, S, stride_w, pad_w);
if(this.getOpCode() == BuiltinFunctionOp.CONV2D)
output.setDimensions(input.getOutput().getDim1(), K*P*Q);
else
output.setDimensions(input.getOutput().getDim1(), C*P*Q);
}
catch(Exception e) {
output.setDimensions(input.getOutput().getDim1(), -1); // To make sure that output dimensions are not incorrect
}
}
else
throw new LanguageException("Unsupported op: " + this.getOpCode());
checkMatrixParam(input);
break;
}
default:
if (this.isMathFunction()) {
// datatype and dimensions are same as this.getExpr()
if (this.getOpCode() == BuiltinFunctionOp.ABS) {
output.setValueType(getFirstExpr().getOutput().getValueType());
} else {
output.setValueType(ValueType.DOUBLE);
}
checkMathFunctionParam();
output.setDataType(id.getDataType());
output.setDimensions(id.getDim1(), id.getDim2());
output.setBlockDimensions(id.getRowsInBlock(), id.getColumnsInBlock());
}
else {
// always unconditional (because unsupported operation)
BuiltinFunctionOp op = getOpCode();
if( op==BuiltinFunctionOp.EIGEN || op==BuiltinFunctionOp.LU || op==BuiltinFunctionOp.QR )
raiseValidateError("Function "+op+" needs to be called with multi-return assignment.", false, LanguageErrorCodes.INVALID_PARAMETERS);
else
raiseValidateError("Unsupported function "+op, false, LanguageErrorCodes.INVALID_PARAMETERS);
}
}
return;
}
private void expandArguments() {
if ( _args == null ) {
_args = new Expression[1];
return;
}
Expression [] temp = _args.clone();
_args = new Expression[_args.length + 1];
System.arraycopy(temp, 0, _args, 0, temp.length);
}
@Override
public boolean multipleReturns() {
switch(_opcode) {
case QR:
case LU:
case EIGEN:
return true;
default:
return false;
}
}
/**
*
* @param expr
* @return
*/
private boolean isConstant(Expression expr) {
return ( expr != null && expr instanceof ConstIdentifier );
}
/**
*
* @param expr
* @return
* @throws LanguageException
*/
private double getDoubleValue(Expression expr)
throws LanguageException
{
if ( expr instanceof DoubleIdentifier )
return ((DoubleIdentifier)expr).getValue();
else if ( expr instanceof IntIdentifier)
return ((IntIdentifier)expr).getValue();
else
throw new LanguageException("Expecting a numeric value.");
}
private boolean isMathFunction() {
switch (this.getOpCode()) {
case COS:
case SIN:
case TAN:
case ACOS:
case ASIN:
case ATAN:
case SIGN:
case SQRT:
case ABS:
case LOG:
case EXP:
case ROUND:
case CEIL:
case FLOOR:
case MEDIAN:
return true;
default:
return false;
}
}
private void checkMathFunctionParam() throws LanguageException {
switch (this.getOpCode()) {
case COS:
case SIN:
case TAN:
case ACOS:
case ASIN:
case ATAN:
case SIGN:
case SQRT:
case ABS:
case EXP:
case ROUND:
case CEIL:
case FLOOR:
case MEDIAN:
checkNumParameters(1);
break;
case LOG:
if (getSecondExpr() != null) {
checkNumParameters(2);
}
else {
checkNumParameters(1);
}
break;
default:
//always unconditional
raiseValidateError("Unknown math function "+ this.getOpCode(), false);
}
}
public String toString() {
StringBuilder sb = new StringBuilder(_opcode.toString() + "(" + _args[0].toString());
for(int i=1; i < _args.length; i++) {
sb.append(",");
sb.append(_args[i].toString());
}
sb.append(")");
return sb.toString();
}
@Override
// third part of expression IS NOT a variable -- it is the OP to be applied
public VariableSet variablesRead() {
VariableSet result = new VariableSet();
for(int i=0; i<_args.length; i++) {
result.addVariables(_args[i].variablesRead());
}
return result;
}
@Override
public VariableSet variablesUpdated() {
VariableSet result = new VariableSet();
// result.addVariables(_first.variablesUpdated());
return result;
}
/**
*
* @param count
* @throws LanguageException
*/
protected void checkNumParameters(int count) //always unconditional
throws LanguageException
{
if (getFirstExpr() == null){
raiseValidateError("Missing parameter for function "+ this.getOpCode(), false, LanguageErrorCodes.INVALID_PARAMETERS);
}
if (((count == 1) && (getSecondExpr()!= null || getThirdExpr() != null)) ||
((count == 2) && (getThirdExpr() != null))){
raiseValidateError("Invalid number of parameters for function "+ this.getOpCode(), false, LanguageErrorCodes.INVALID_PARAMETERS);
}
else if (((count == 2) && (getSecondExpr() == null)) ||
((count == 3) && (getSecondExpr() == null || getThirdExpr() == null))){
raiseValidateError( "Missing parameter for function "+this.getOpCode(), false, LanguageErrorCodes.INVALID_PARAMETERS);
}
}
/**
*
* @param e
* @throws LanguageException
*/
protected void checkMatrixParam(Expression e) //always unconditional
throws LanguageException
{
if (e.getOutput().getDataType() != DataType.MATRIX) {
raiseValidateError("Expecting matrix parameter for function "+ this.getOpCode(), false, LanguageErrorCodes.UNSUPPORTED_PARAMETERS);
}
}
/**
*
* @param e
* @throws LanguageException
*/
protected void checkMatrixFrameParam(Expression e) //always unconditional
throws LanguageException
{
if (e.getOutput().getDataType() != DataType.MATRIX && e.getOutput().getDataType() != DataType.FRAME) {
raiseValidateError("Expecting matrix or frame parameter for function "+ getOpCode(), false, LanguageErrorCodes.UNSUPPORTED_PARAMETERS);
}
}
/**
*
* @param e
* @throws LanguageException
*/
protected void checkMatrixScalarParam(Expression e) //always unconditional
throws LanguageException
{
if (e.getOutput().getDataType() != DataType.MATRIX && e.getOutput().getDataType() != DataType.SCALAR) {
raiseValidateError("Expecting matrix or scalar parameter for function "+ getOpCode(), false, LanguageErrorCodes.UNSUPPORTED_PARAMETERS);
}
}
/**
*
* @param e
* @throws LanguageException
*/
private void checkScalarParam(Expression e) //always unconditional
throws LanguageException
{
if (e.getOutput().getDataType() != DataType.SCALAR) {
raiseValidateError("Expecting scalar parameter for function " + this.getOpCode(), false, LanguageErrorCodes.UNSUPPORTED_PARAMETERS);
}
}
/**
*
* @param e
* @throws LanguageException
*/
private void checkScalarFrameParam(Expression e) //always unconditional
throws LanguageException
{
if (e.getOutput().getDataType() != DataType.SCALAR && e.getOutput().getDataType() != DataType.FRAME) {
raiseValidateError("Expecting scalar parameter for function " + this.getOpCode(), false, LanguageErrorCodes.UNSUPPORTED_PARAMETERS);
}
}
/**
*
* @param e
* @param vt
* @throws LanguageException
*/
private void checkValueTypeParam(Expression e, ValueType vt) //always unconditional
throws LanguageException
{
if (e.getOutput().getValueType() != vt) {
raiseValidateError("Expecting parameter of different value type " + this.getOpCode(), false, LanguageErrorCodes.UNSUPPORTED_PARAMETERS);
}
}
private boolean is1DMatrix(Expression e) {
return (e.getOutput().getDim1() == 1 || e.getOutput().getDim2() == 1 );
}
private boolean dimsKnown(Expression e) {
return (e.getOutput().getDim1() != -1 && e.getOutput().getDim2() != -1);
}
/**
*
* @param e
* @throws LanguageException
*/
private void check1DMatrixParam(Expression e) //always unconditional
throws LanguageException
{
checkMatrixParam(e);
// throw an exception, when e's output is NOT a one-dimensional matrix
// the check must be performed only when the dimensions are known at compilation time
if ( dimsKnown(e) && !is1DMatrix(e)) {
raiseValidateError("Expecting one-dimensional matrix parameter for function "
+ this.getOpCode(), false, LanguageErrorCodes.UNSUPPORTED_PARAMETERS);
}
}
/**
*
* @param expr1
* @param expr2
* @throws LanguageException
*/
private void checkMatchingDimensions(Expression expr1, Expression expr2)
throws LanguageException
{
checkMatchingDimensions(expr1, expr2, false);
}
/**
*
* @param expr1
* @param expr2
* @throws LanguageException
*/
private void checkMatchingDimensions(Expression expr1, Expression expr2, boolean allowsMV)
throws LanguageException
{
if (expr1 != null && expr2 != null) {
// if any matrix has unknown dimensions, simply return
if( expr1.getOutput().getDim1() == -1 || expr2.getOutput().getDim1() == -1
||expr1.getOutput().getDim2() == -1 || expr2.getOutput().getDim2() == -1 )
{
return;
}
else if( (!allowsMV && expr1.getOutput().getDim1() != expr2.getOutput().getDim1())
|| (allowsMV && expr1.getOutput().getDim1() != expr2.getOutput().getDim1() && expr2.getOutput().getDim1() != 1)
|| (!allowsMV && expr1.getOutput().getDim2() != expr2.getOutput().getDim2())
|| (allowsMV && expr1.getOutput().getDim2() != expr2.getOutput().getDim2() && expr2.getOutput().getDim2() != 1) )
{
raiseValidateError("Mismatch in matrix dimensions of parameters for function "
+ this.getOpCode(), false, LanguageErrorCodes.INVALID_PARAMETERS);
}
}
}
/**
*
* @throws LanguageException
*/
private void checkMatchingDimensionsQuantile()
throws LanguageException
{
if (getFirstExpr().getOutput().getDim1() != getSecondExpr().getOutput().getDim1()) {
raiseValidateError("Mismatch in matrix dimensions for "
+ this.getOpCode(), false, LanguageErrorCodes.INVALID_PARAMETERS);
}
}
public static BuiltinFunctionExpression getBuiltinFunctionExpression(
String functionName, ArrayList paramExprsPassed,
String filename, int blp, int bcp, int elp, int ecp) {
if (functionName == null || paramExprsPassed == null)
return null;
// check if the function name is built-in function
// (assign built-in function op if function is built-in
Expression.BuiltinFunctionOp bifop = null;
if (functionName.equals("avg"))
bifop = Expression.BuiltinFunctionOp.MEAN;
else if (functionName.equals("cos"))
bifop = Expression.BuiltinFunctionOp.COS;
else if (functionName.equals("sin"))
bifop = Expression.BuiltinFunctionOp.SIN;
else if (functionName.equals("tan"))
bifop = Expression.BuiltinFunctionOp.TAN;
else if (functionName.equals("acos"))
bifop = Expression.BuiltinFunctionOp.ACOS;
else if (functionName.equals("asin"))
bifop = Expression.BuiltinFunctionOp.ASIN;
else if (functionName.equals("atan"))
bifop = Expression.BuiltinFunctionOp.ATAN;
else if (functionName.equals("diag"))
bifop = Expression.BuiltinFunctionOp.DIAG;
else if (functionName.equals("exp"))
bifop = Expression.BuiltinFunctionOp.EXP;
else if (functionName.equals("abs"))
bifop = Expression.BuiltinFunctionOp.ABS;
else if (functionName.equals("min"))
bifop = Expression.BuiltinFunctionOp.MIN;
else if (functionName.equals("max"))
bifop = Expression.BuiltinFunctionOp.MAX;
//NOTE: pmin and pmax are just kept for compatibility to R
// min and max is capable of handling all unary and binary
// operations (in contrast to R)
else if (functionName.equals("pmin"))
bifop = Expression.BuiltinFunctionOp.MIN;
else if (functionName.equals("pmax"))
bifop = Expression.BuiltinFunctionOp.MAX;
else if (functionName.equals("ppred"))
bifop = Expression.BuiltinFunctionOp.PPRED;
else if (functionName.equals("log"))
bifop = Expression.BuiltinFunctionOp.LOG;
else if (functionName.equals("length"))
bifop = Expression.BuiltinFunctionOp.LENGTH;
else if (functionName.equals("ncol"))
bifop = Expression.BuiltinFunctionOp.NCOL;
else if (functionName.equals("nrow"))
bifop = Expression.BuiltinFunctionOp.NROW;
else if (functionName.equals("sign"))
bifop = Expression.BuiltinFunctionOp.SIGN;
else if (functionName.equals("sqrt"))
bifop = Expression.BuiltinFunctionOp.SQRT;
else if (functionName.equals("sum"))
bifop = Expression.BuiltinFunctionOp.SUM;
else if (functionName.equals("mean"))
bifop = Expression.BuiltinFunctionOp.MEAN;
else if (functionName.equals("sd"))
bifop = Expression.BuiltinFunctionOp.SD;
else if (functionName.equals("var"))
bifop = Expression.BuiltinFunctionOp.VAR;
else if (functionName.equals("trace"))
bifop = Expression.BuiltinFunctionOp.TRACE;
else if (functionName.equals("t"))
bifop = Expression.BuiltinFunctionOp.TRANS;
else if (functionName.equals("rev"))
bifop = Expression.BuiltinFunctionOp.REV;
else if (functionName.equals("cbind") || functionName.equals("append"))
bifop = Expression.BuiltinFunctionOp.CBIND;
else if (functionName.equals("rbind"))
bifop = Expression.BuiltinFunctionOp.RBIND;
else if (functionName.equals("range"))
bifop = Expression.BuiltinFunctionOp.RANGE;
else if (functionName.equals("prod"))
bifop = Expression.BuiltinFunctionOp.PROD;
else if (functionName.equals("rowSums"))
bifop = Expression.BuiltinFunctionOp.ROWSUM;
else if (functionName.equals("colSums"))
bifop = Expression.BuiltinFunctionOp.COLSUM;
else if (functionName.equals("rowMins"))
bifop = Expression.BuiltinFunctionOp.ROWMIN;
else if (functionName.equals("colMins"))
bifop = Expression.BuiltinFunctionOp.COLMIN;
else if (functionName.equals("rowMaxs"))
bifop = Expression.BuiltinFunctionOp.ROWMAX;
else if (functionName.equals("rowIndexMax"))
bifop = Expression.BuiltinFunctionOp.ROWINDEXMAX;
else if (functionName.equals("rowIndexMin"))
bifop = Expression.BuiltinFunctionOp.ROWINDEXMIN;
else if (functionName.equals("colMaxs"))
bifop = Expression.BuiltinFunctionOp.COLMAX;
else if (functionName.equals("rowMeans"))
bifop = Expression.BuiltinFunctionOp.ROWMEAN;
else if (functionName.equals("colMeans"))
bifop = Expression.BuiltinFunctionOp.COLMEAN;
else if (functionName.equals("rowSds"))
bifop = Expression.BuiltinFunctionOp.ROWSD;
else if (functionName.equals("colSds"))
bifop = Expression.BuiltinFunctionOp.COLSD;
else if (functionName.equals("rowVars"))
bifop = Expression.BuiltinFunctionOp.ROWVAR;
else if (functionName.equals("colVars"))
bifop = Expression.BuiltinFunctionOp.COLVAR;
else if (functionName.equals("cummax"))
bifop = Expression.BuiltinFunctionOp.CUMMAX;
else if (functionName.equals("cummin"))
bifop = Expression.BuiltinFunctionOp.CUMMIN;
else if (functionName.equals("cumprod"))
bifop = Expression.BuiltinFunctionOp.CUMPROD;
else if (functionName.equals("cumsum"))
bifop = Expression.BuiltinFunctionOp.CUMSUM;
//'castAsScalar' for backwards compatibility
else if (functionName.equals("as.scalar") || functionName.equals("castAsScalar"))
bifop = Expression.BuiltinFunctionOp.CAST_AS_SCALAR;
else if (functionName.equals("as.matrix"))
bifop = Expression.BuiltinFunctionOp.CAST_AS_MATRIX;
else if (functionName.equals("as.frame"))
bifop = Expression.BuiltinFunctionOp.CAST_AS_FRAME;
else if (functionName.equals("as.double"))
bifop = Expression.BuiltinFunctionOp.CAST_AS_DOUBLE;
else if (functionName.equals("as.integer"))
bifop = Expression.BuiltinFunctionOp.CAST_AS_INT;
else if (functionName.equals("as.logical")) //alternative: as.boolean
bifop = Expression.BuiltinFunctionOp.CAST_AS_BOOLEAN;
else if (functionName.equals("quantile"))
bifop= Expression.BuiltinFunctionOp.QUANTILE;
else if (functionName.equals("interQuantile"))
bifop= Expression.BuiltinFunctionOp.INTERQUANTILE;
else if (functionName.equals("interQuartileMean"))
bifop= Expression.BuiltinFunctionOp.IQM;
//'ctable' for backwards compatibility
else if (functionName.equals("table") || functionName.equals("ctable"))
bifop = Expression.BuiltinFunctionOp.TABLE;
else if (functionName.equals("round"))
bifop = Expression.BuiltinFunctionOp.ROUND;
//'centralMoment' for backwards compatibility
else if (functionName.equals("moment") || functionName.equals("centralMoment"))
bifop = Expression.BuiltinFunctionOp.MOMENT;
else if (functionName.equals("cov"))
bifop = Expression.BuiltinFunctionOp.COV;
else if (functionName.equals("seq"))
bifop = Expression.BuiltinFunctionOp.SEQ;
else if (functionName.equals("qr"))
bifop = Expression.BuiltinFunctionOp.QR;
else if (functionName.equals("lu"))
bifop = Expression.BuiltinFunctionOp.LU;
else if (functionName.equals("eigen"))
bifop = Expression.BuiltinFunctionOp.EIGEN;
else if (functionName.equals("conv2d"))
bifop = Expression.BuiltinFunctionOp.CONV2D;
else if (functionName.equals("conv2d_backward_filter"))
bifop = Expression.BuiltinFunctionOp.CONV2D_BACKWARD_FILTER;
else if (functionName.equals("conv2d_backward_data"))
bifop = Expression.BuiltinFunctionOp.CONV2D_BACKWARD_DATA;
else if (functionName.equals("max_pool"))
bifop = Expression.BuiltinFunctionOp.MAX_POOL;
else if (functionName.equals("max_pool_backward"))
bifop = Expression.BuiltinFunctionOp.MAX_POOL_BACKWARD;
else if (functionName.equals("avg_pool"))
bifop = Expression.BuiltinFunctionOp.AVG_POOL;
else if (functionName.equals("solve"))
bifop = Expression.BuiltinFunctionOp.SOLVE;
else if (functionName.equals("ceil"))
bifop = Expression.BuiltinFunctionOp.CEIL;
else if (functionName.equals("floor"))
bifop = Expression.BuiltinFunctionOp.FLOOR;
else if (functionName.equals("median"))
bifop = Expression.BuiltinFunctionOp.MEDIAN;
else if (functionName.equals("inv"))
bifop = Expression.BuiltinFunctionOp.INVERSE;
else if (functionName.equals("cholesky"))
bifop = Expression.BuiltinFunctionOp.CHOLESKY;
else if (functionName.equals("sample"))
bifop = Expression.BuiltinFunctionOp.SAMPLE;
else if ( functionName.equals("outer") )
bifop = Expression.BuiltinFunctionOp.OUTER;
else
return null;
BuiltinFunctionExpression retVal = new BuiltinFunctionExpression(bifop, paramExprsPassed,
filename, blp, bcp, elp, ecp);
return retVal;
} // end method getBuiltinFunctionExpression
/**
* Convert a value type (double, int, or boolean) to a built-in function operator.
*
* @param vt Value type ({@code ValueType.DOUBLE}, {@code ValueType.INT}, or {@code ValueType.BOOLEAN}).
* @return Built-in function operator ({@code BuiltinFunctionOp.AS_DOUBLE},
* {@code BuiltinFunctionOp.AS_INT}, or {@code BuiltinFunctionOp.AS_BOOLEAN}).
* @throws LanguageException thrown if ValueType not accepted
*/
public static BuiltinFunctionOp getValueTypeCastOperator( ValueType vt )
throws LanguageException
{
switch( vt )
{
case DOUBLE:
return BuiltinFunctionOp.CAST_AS_DOUBLE;
case INT:
return BuiltinFunctionOp.CAST_AS_INT;
case BOOLEAN:
return BuiltinFunctionOp.CAST_AS_BOOLEAN;
default:
throw new LanguageException("No cast for value type "+vt);
}
}
}