All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.jpmml.sparkml.ExpressionTranslator Maven / Gradle / Ivy

The newest version!
/*
 * Copyright (c) 2018 Villu Ruusmann
 *
 * This file is part of JPMML-SparkML
 *
 * JPMML-SparkML is free software: you can redistribute it and/or modify
 * it under the terms of the GNU Affero General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * JPMML-SparkML is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU Affero General Public License for more details.
 *
 * You should have received a copy of the GNU Affero General Public License
 * along with JPMML-SparkML.  If not, see <http://www.gnu.org/licenses/>.
 */
package org.jpmml.sparkml;

import java.util.Iterator;
import java.util.List;
import java.util.Objects;
import java.util.function.Function;

import org.apache.spark.sql.catalyst.expressions.Abs;
import org.apache.spark.sql.catalyst.expressions.Acos;
import org.apache.spark.sql.catalyst.expressions.Add;
import org.apache.spark.sql.catalyst.expressions.Alias;
import org.apache.spark.sql.catalyst.expressions.And;
import org.apache.spark.sql.catalyst.expressions.Asin;
import org.apache.spark.sql.catalyst.expressions.Atan;
import org.apache.spark.sql.catalyst.expressions.AttributeReference;
import org.apache.spark.sql.catalyst.expressions.BinaryArithmetic;
import org.apache.spark.sql.catalyst.expressions.BinaryComparison;
import org.apache.spark.sql.catalyst.expressions.BinaryMathExpression;
import org.apache.spark.sql.catalyst.expressions.BinaryOperator;
import org.apache.spark.sql.catalyst.expressions.CaseWhen;
import org.apache.spark.sql.catalyst.expressions.Cast;
import org.apache.spark.sql.catalyst.expressions.Ceil;
import org.apache.spark.sql.catalyst.expressions.Concat;
import org.apache.spark.sql.catalyst.expressions.Cos;
import org.apache.spark.sql.catalyst.expressions.Cosh;
import org.apache.spark.sql.catalyst.expressions.Divide;
import org.apache.spark.sql.catalyst.expressions.EqualTo;
import org.apache.spark.sql.catalyst.expressions.Exp;
import org.apache.spark.sql.catalyst.expressions.Expm1;
import org.apache.spark.sql.catalyst.expressions.Expression;
import org.apache.spark.sql.catalyst.expressions.Floor;
import org.apache.spark.sql.catalyst.expressions.GreaterThan;
import org.apache.spark.sql.catalyst.expressions.GreaterThanOrEqual;
import org.apache.spark.sql.catalyst.expressions.Greatest;
import org.apache.spark.sql.catalyst.expressions.Hypot;
import org.apache.spark.sql.catalyst.expressions.If;
import org.apache.spark.sql.catalyst.expressions.In;
import org.apache.spark.sql.catalyst.expressions.IsNaN;
import org.apache.spark.sql.catalyst.expressions.IsNotNull;
import org.apache.spark.sql.catalyst.expressions.IsNull;
import org.apache.spark.sql.catalyst.expressions.Least;
import org.apache.spark.sql.catalyst.expressions.Length;
import org.apache.spark.sql.catalyst.expressions.LessThan;
import org.apache.spark.sql.catalyst.expressions.LessThanOrEqual;
import org.apache.spark.sql.catalyst.expressions.Literal;
import org.apache.spark.sql.catalyst.expressions.Log;
import org.apache.spark.sql.catalyst.expressions.Log10;
import org.apache.spark.sql.catalyst.expressions.Log1p;
import org.apache.spark.sql.catalyst.expressions.Lower;
import org.apache.spark.sql.catalyst.expressions.Multiply;
import org.apache.spark.sql.catalyst.expressions.Not;
import org.apache.spark.sql.catalyst.expressions.Or;
import org.apache.spark.sql.catalyst.expressions.Pow;
import org.apache.spark.sql.catalyst.expressions.RLike;
import org.apache.spark.sql.catalyst.expressions.RegExpReplace;
import org.apache.spark.sql.catalyst.expressions.Rint;
import org.apache.spark.sql.catalyst.expressions.Sin;
import org.apache.spark.sql.catalyst.expressions.Sinh;
import org.apache.spark.sql.catalyst.expressions.Sqrt;
import org.apache.spark.sql.catalyst.expressions.StringReplace;
import org.apache.spark.sql.catalyst.expressions.StringTrim;
import org.apache.spark.sql.catalyst.expressions.Substring;
import org.apache.spark.sql.catalyst.expressions.Subtract;
import org.apache.spark.sql.catalyst.expressions.Tan;
import org.apache.spark.sql.catalyst.expressions.Tanh;
import org.apache.spark.sql.catalyst.expressions.UnaryExpression;
import org.apache.spark.sql.catalyst.expressions.UnaryMinus;
import org.apache.spark.sql.catalyst.expressions.UnaryPositive;
import org.apache.spark.sql.catalyst.expressions.Upper;
import org.apache.spark.sql.types.Decimal;
import org.dmg.pmml.Apply;
import org.dmg.pmml.Constant;
import org.dmg.pmml.DataType;
import org.dmg.pmml.DerivedField;
import org.dmg.pmml.FieldRef;
import org.dmg.pmml.HasDataType;
import org.dmg.pmml.OpType;
import org.dmg.pmml.PMMLFunctions;
import org.jpmml.converter.ExpressionUtil;
import org.jpmml.converter.FieldNameUtil;
import org.jpmml.converter.IfElseBuilder;
import org.jpmml.converter.TypeUtil;
import org.jpmml.converter.ValueUtil;
import org.jpmml.converter.visitors.ExpressionCompactor;
import scala.Option;
import scala.Tuple2;
import scala.collection.JavaConversions;

public class ExpressionTranslator {

	// Encoder of this translator; assigned (non-null) through setEncoder() by the constructor
	// and used by translateInternal() to create derived fields for Cast expressions.
	private SparkMLEncoder encoder = null;


	/**
	 * Creates a translator that is bound to the specified encoder.
	 *
	 * <p>
	 * The constructor is private; instances are created by the static {@code translate} methods of this class.
	 * </p>
	 *
	 * @param encoder The encoder. Must not be {@code null}.
	 *
	 * @throws NullPointerException If the encoder is {@code null}.
	 */
	private ExpressionTranslator(SparkMLEncoder encoder){
		this.setEncoder(encoder);
	}

	/**
	 * @return The encoder that this translator was created with.
	 */
	public SparkMLEncoder getEncoder(){
		SparkMLEncoder result = encoder;

		return result;
	}

	/**
	 * Stores the encoder, rejecting {@code null}.
	 *
	 * @param encoder The encoder. Must not be {@code null}.
	 *
	 * @throws NullPointerException If the encoder is {@code null}.
	 */
	private void setEncoder(SparkMLEncoder encoder){
		SparkMLEncoder checkedEncoder = Objects.requireNonNull(encoder);

		this.encoder = checkedEncoder;
	}

	/**
	 * Translates a Spark SQL Catalyst expression to a PMML expression, with the post-processing step enabled.
	 *
	 * <p>
	 * Equivalent to calling the three-argument {@code translate} method with {@code compact} set to {@code true}.
	 * </p>
	 *
	 * @param encoder The encoder. Must not be {@code null}.
	 * @param expression The Catalyst expression to translate.
	 *
	 * @return The PMML expression.
	 */
	static
	public org.dmg.pmml.Expression translate(SparkMLEncoder encoder, Expression expression){
		final boolean compact = true;

		return translate(encoder, expression, compact);
	}

	/**
	 * Translates a Spark SQL Catalyst expression to a PMML expression.
	 *
	 * @param encoder The encoder. Must not be {@code null}.
	 * @param expression The Catalyst expression to translate.
	 * @param compact If {@code true}, then an {@link ExpressionCompactor} is applied to the translated expression before it is returned.
	 *
	 * @return The PMML expression.
	 *
	 * @throws IllegalArgumentException If the Catalyst expression is not supported.
	 */
	static
	public org.dmg.pmml.Expression translate(SparkMLEncoder encoder, Expression expression, boolean compact){
		org.dmg.pmml.Expression result = (new ExpressionTranslator(encoder)).translateInternal(expression);

		// Nothing more to do when post-processing has not been requested
		if(!compact){
			return result;
		}

		(new ExpressionCompactor()).applyTo(result);

		return result;
	}

	/**
	 * Recursively translates a Spark SQL Catalyst expression tree to a PMML expression tree.
	 *
	 * <p>
	 * Dispatch happens on the runtime class of the expression; child expressions are translated by recursive calls.
	 * </p>
	 *
	 * @param expression The Catalyst expression to translate.
	 *
	 * @return The PMML expression.
	 * An {@link Alias} yields an {@link AliasExpression} wrapper around the translated child.
	 * A {@link Cast} whose translated child does not implement {@link HasDataType} yields a {@link FieldRef} to a derived field that is created via the encoder.
	 *
	 * @throws IllegalArgumentException If the expression, or one of its descendants, is not supported.
	 */
	private org.dmg.pmml.Expression translateInternal(Expression expression){
		SparkMLEncoder encoder = getEncoder();

		if(expression instanceof Alias){
			Alias alias = (Alias)expression;

			String name = alias.name();
			Expression child = alias.child();

			org.dmg.pmml.Expression pmmlExpression = translateInternal(child);

			return new AliasExpression(name, pmmlExpression);
		} // End if

		if(expression instanceof AttributeReference){
			AttributeReference attributeReference = (AttributeReference)expression;

			String name = attributeReference.name();

			return new FieldRef(name);
		} else

		if(expression instanceof BinaryMathExpression){
			BinaryMathExpression binaryMathExpression = (BinaryMathExpression)expression;

			Expression left = binaryMathExpression.left();
			Expression right = binaryMathExpression.right();

			String function;

			if(binaryMathExpression instanceof Hypot){
				function = PMMLFunctions.HYPOT;
			} else

			if(binaryMathExpression instanceof Pow){
				function = PMMLFunctions.POW;
			} else

			{
				throw new IllegalArgumentException(formatMessage(binaryMathExpression));
			}

			return ExpressionUtil.createApply(function, translateInternal(left), translateInternal(right));
		} else

		if(expression instanceof BinaryOperator){
			BinaryOperator binaryOperator = (BinaryOperator)expression;

			String symbol = binaryOperator.symbol();

			Expression left = binaryOperator.left();
			Expression right = binaryOperator.right();

			String function;

			// The class check selects the operator family, the symbol selects the PMML function within that family
			if(expression instanceof And || expression instanceof Or){

				switch(symbol){
					case "&&":
						function = PMMLFunctions.AND;
						break;
					case "||":
						function = PMMLFunctions.OR;
						break;
					default:
						throw new IllegalArgumentException(formatMessage(binaryOperator));
				}
			} else

			if(expression instanceof Add || expression instanceof Divide || expression instanceof Multiply || expression instanceof Subtract){

				switch(symbol){
					case "+":
						function = PMMLFunctions.ADD;
						break;
					case "/":
						function = PMMLFunctions.DIVIDE;
						break;
					case "*":
						function = PMMLFunctions.MULTIPLY;
						break;
					case "-":
						function = PMMLFunctions.SUBTRACT;
						break;
					default:
						throw new IllegalArgumentException(formatMessage(binaryOperator));
				}
			} else

			if(expression instanceof EqualTo || expression instanceof GreaterThan || expression instanceof GreaterThanOrEqual || expression instanceof LessThan || expression instanceof LessThanOrEqual){

				switch(symbol){
					case "=":
						function = PMMLFunctions.EQUAL;
						break;
					case ">":
						function = PMMLFunctions.GREATERTHAN;
						break;
					case ">=":
						function = PMMLFunctions.GREATEROREQUAL;
						break;
					case "<":
						function = PMMLFunctions.LESSTHAN;
						break;
					case "<=":
						function = PMMLFunctions.LESSOREQUAL;
						break;
					default:
						throw new IllegalArgumentException(formatMessage(binaryOperator));
				}
			} else

			{
				throw new IllegalArgumentException(formatMessage(binaryOperator));
			}

			return ExpressionUtil.createApply(function, translateInternal(left), translateInternal(right));
		} else

		if(expression instanceof CaseWhen){
			CaseWhen caseWhen = (CaseWhen)expression;

			List<Tuple2<Expression, Expression>> branches = JavaConversions.seqAsJavaList(caseWhen.branches());
			if(branches.isEmpty()){
				// Fail with the same exception type as every other unsupported construct
				throw new IllegalArgumentException(formatMessage(caseWhen));
			}

			Option<Expression> elseValue = caseWhen.elseValue();

			IfElseBuilder applyBuilder = new IfElseBuilder();

			for(Tuple2<Expression, Expression> branch : branches){
				Expression predicate = branch._1();
				Expression value = branch._2();

				applyBuilder.add(translateInternal(predicate), translateInternal(value));
			}

			if(elseValue.isDefined()){
				Expression value = elseValue.get();

				applyBuilder.terminate(translateInternal(value));
			}

			return applyBuilder.build();
		} else

		if(expression instanceof Cast){
			Cast cast = (Cast)expression;

			Expression child = cast.child();

			org.dmg.pmml.Expression pmmlExpression = translateInternal(child);

			DataType dataType = DatasetUtil.translateDataType(cast.dataType());

			// A child that carries its own data type is re-typed in place
			if(pmmlExpression instanceof HasDataType){
				HasDataType hasDataType = (HasDataType)pmmlExpression;

				hasDataType.setDataType(dataType);

				return pmmlExpression;
			} else

			// Any other child is materialized as a derived field of the target data type
			{
				String name;

				if(pmmlExpression instanceof AliasExpression){
					AliasExpression aliasExpression = (AliasExpression)pmmlExpression;

					name = aliasExpression.getName();
				} else

				{
					name = FieldNameUtil.create(dataType, String.valueOf(child));
				}

				OpType opType = TypeUtil.getOpType(dataType);

				pmmlExpression = AliasExpression.unwrap(pmmlExpression);

				DerivedField derivedField = encoder.createDerivedField(name, opType, dataType, pmmlExpression);

				return new FieldRef(derivedField);
			}
		} else

		if(expression instanceof Concat){
			Concat concat = (Concat)expression;

			List<Expression> children = JavaConversions.seqAsJavaList(concat.children());

			Apply apply = ExpressionUtil.createApply(PMMLFunctions.CONCAT);

			for(Expression child : children){
				apply.addExpressions(translateInternal(child));
			}

			return apply;
		} else

		if(expression instanceof Greatest){
			Greatest greatest = (Greatest)expression;

			List<Expression> children = JavaConversions.seqAsJavaList(greatest.children());

			Apply apply = ExpressionUtil.createApply(PMMLFunctions.MAX);

			for(Expression child : children){
				apply.addExpressions(translateInternal(child));
			}

			return apply;
		} else

		if(expression instanceof If){
			If _if = (If)expression;

			Expression predicate = _if.predicate();

			Expression trueValue = _if.trueValue();
			Expression falseValue = _if.falseValue();

			return ExpressionUtil.createApply(PMMLFunctions.IF,
				translateInternal(predicate),
				translateInternal(trueValue),
				translateInternal(falseValue)
			);
		} else

		if(expression instanceof In){
			In in = (In)expression;

			Expression value = in.value();

			List<Expression> elements = JavaConversions.seqAsJavaList(in.list());

			Apply apply = ExpressionUtil.createApply(PMMLFunctions.ISIN, translateInternal(value));

			for(Expression element : elements){
				apply.addExpressions(translateInternal(element));
			}

			return apply;
		} else

		if(expression instanceof Least){
			Least least = (Least)expression;

			List<Expression> children = JavaConversions.seqAsJavaList(least.children());

			Apply apply = ExpressionUtil.createApply(PMMLFunctions.MIN);

			for(Expression child : children){
				apply.addExpressions(translateInternal(child));
			}

			return apply;
		} else

		if(expression instanceof Length){
			Length length = (Length)expression;

			Expression child = length.child();

			return ExpressionUtil.createApply(PMMLFunctions.STRINGLENGTH, translateInternal(child));
		} else

		if(expression instanceof Literal){
			Literal literal = (Literal)expression;

			Object value = literal.value();
			if(value == null){
				return ExpressionUtil.createMissingConstant();
			}

			DataType dataType;

			// XXX
			if(value instanceof Decimal){
				Decimal decimal = (Decimal)value;

				dataType = DataType.STRING;

				value = decimal.toString();
			} else

			{
				dataType = DatasetUtil.translateDataType(literal.dataType());

				value = toSimpleObject(value);
			}

			return ExpressionUtil.createConstant(dataType, value);
		} else

		if(expression instanceof RegExpReplace){
			RegExpReplace regexpReplace = (RegExpReplace)expression;

			Expression subject = regexpReplace.subject();
			Expression regexp = regexpReplace.regexp();
			Expression rep = regexpReplace.rep();

			return ExpressionUtil.createApply(PMMLFunctions.REPLACE, translateInternal(subject), translateInternal(regexp), translateInternal(rep));
		} else

		if(expression instanceof RLike){
			RLike rlike = (RLike)expression;

			Expression left = rlike.left();
			Expression right = rlike.right();

			return ExpressionUtil.createApply(PMMLFunctions.MATCHES, translateInternal(left), translateInternal(right));
		} else

		if(expression instanceof StringReplace){
			StringReplace stringReplace = (StringReplace)expression;

			Expression srcExpr = stringReplace.srcExpr();
			Expression searchExpr = stringReplace.searchExpr();
			Expression replaceExpr = stringReplace.replaceExpr();

			// The search and replacement strings are backslash-escaped before being handed to the PMML "replace" function
			return ExpressionUtil.createApply(PMMLFunctions.REPLACE, translateInternal(srcExpr), transformString(translateInternal(searchExpr), ExpressionTranslator::escapeSearchString), transformString(translateInternal(replaceExpr), ExpressionTranslator::escapeReplacementString));
		} else

		if(expression instanceof StringTrim){
			StringTrim stringTrim = (StringTrim)expression;

			Expression srcStr = stringTrim.srcStr();
			Option<Expression> trimStr = stringTrim.trimStr();
			if(trimStr.isDefined()){
				// Only the no-argument form (trimming of blanks) is translated
				throw new IllegalArgumentException(formatMessage(stringTrim));
			}

			return ExpressionUtil.createApply(PMMLFunctions.TRIMBLANKS, translateInternal(srcStr));
		} else

		if(expression instanceof Substring){
			Substring substring = (Substring)expression;

			Expression str = substring.str();
			Expression pos = substring.pos();
			Expression len = substring.len();

			int posValue = toIntValue(pos);
			if(posValue <= 0){
				throw new IllegalArgumentException("Expected absolute start position, got relative start position " + (pos));
			}

			int lenValue = toIntValue(len);

			// XXX
			lenValue = Math.min(lenValue, MAX_STRING_LENGTH);

			return ExpressionUtil.createApply(PMMLFunctions.SUBSTRING, translateInternal(str), ExpressionUtil.createConstant(posValue), ExpressionUtil.createConstant(lenValue));
		} else

		if(expression instanceof UnaryExpression){
			UnaryExpression unaryExpression = (UnaryExpression)expression;

			Expression child = unaryExpression.child();

			// The sign operators do not map to a single-argument PMML function
			if(expression instanceof UnaryMinus){
				return ExpressionUtil.toNegative(translateInternal(child));
			} else

			if(expression instanceof UnaryPositive){
				return translateInternal(child);
			}

			// Resolve the function first, so that an unsupported expression fails before its child is translated
			String function = toUnaryFunction(unaryExpression);

			return ExpressionUtil.createApply(function, translateInternal(child));
		} else

		{
			throw new IllegalArgumentException(formatMessage(expression));
		}
	}

	/**
	 * Maps a single-argument Catalyst expression to the name of the corresponding PMML function.
	 *
	 * @param unaryExpression The Catalyst expression.
	 *
	 * @return The PMML function name.
	 *
	 * @throws IllegalArgumentException If there is no mapping for the class of the expression.
	 */
	static
	private String toUnaryFunction(UnaryExpression unaryExpression){

		if(unaryExpression instanceof Abs){
			return PMMLFunctions.ABS;
		} // End if

		if(unaryExpression instanceof Acos){
			return PMMLFunctions.ACOS;
		} // End if

		if(unaryExpression instanceof Asin){
			return PMMLFunctions.ASIN;
		} // End if

		if(unaryExpression instanceof Atan){
			return PMMLFunctions.ATAN;
		} // End if

		if(unaryExpression instanceof Ceil){
			return PMMLFunctions.CEIL;
		} // End if

		if(unaryExpression instanceof Cos){
			return PMMLFunctions.COS;
		} // End if

		if(unaryExpression instanceof Cosh){
			return PMMLFunctions.COSH;
		} // End if

		if(unaryExpression instanceof Exp){
			return PMMLFunctions.EXP;
		} // End if

		if(unaryExpression instanceof Expm1){
			return PMMLFunctions.EXPM1;
		} // End if

		if(unaryExpression instanceof Floor){
			return PMMLFunctions.FLOOR;
		} // End if

		if(unaryExpression instanceof Log){
			return PMMLFunctions.LN;
		} // End if

		if(unaryExpression instanceof Log10){
			return PMMLFunctions.LOG10;
		} // End if

		if(unaryExpression instanceof Log1p){
			return PMMLFunctions.LN1P;
		} // End if

		if(unaryExpression instanceof Lower){
			return PMMLFunctions.LOWERCASE;
		} // End if

		if(unaryExpression instanceof IsNaN){
			// XXX
			return PMMLFunctions.ISNOTVALID;
		} // End if

		if(unaryExpression instanceof IsNotNull){
			return PMMLFunctions.ISNOTMISSING;
		} // End if

		if(unaryExpression instanceof IsNull){
			return PMMLFunctions.ISMISSING;
		} // End if

		if(unaryExpression instanceof Not){
			return PMMLFunctions.NOT;
		} // End if

		if(unaryExpression instanceof Rint){
			return PMMLFunctions.RINT;
		} // End if

		if(unaryExpression instanceof Sin){
			return PMMLFunctions.SIN;
		} // End if

		if(unaryExpression instanceof Sinh){
			return PMMLFunctions.SINH;
		} // End if

		if(unaryExpression instanceof Sqrt){
			return PMMLFunctions.SQRT;
		} // End if

		if(unaryExpression instanceof Tan){
			return PMMLFunctions.TAN;
		} // End if

		if(unaryExpression instanceof Tanh){
			return PMMLFunctions.TANH;
		} // End if

		if(unaryExpression instanceof Upper){
			return PMMLFunctions.UPPERCASE;
		}

		throw new IllegalArgumentException(formatMessage(unaryExpression));
	}

	/**
	 * Extracts the integer value of a numeric literal.
	 *
	 * @param expression The Catalyst expression. Expected to be a {@link Literal} that holds a {@link Number}.
	 *
	 * @return The value of the literal as an {@code int}.
	 *
	 * @throws IllegalArgumentException If the expression is not a literal, or if its value is missing or non-numeric.
	 */
	static
	private int toIntValue(Expression expression){

		if(expression instanceof Literal){
			Literal literal = (Literal)expression;

			Object value = literal.value();
			if(value instanceof Number){
				return ValueUtil.asInt((Number)value);
			}
		}

		throw new IllegalArgumentException("Expected numeric literal, got " + String.valueOf(expression));
	}

	/**
	 * Backslash-escapes every occurrence of the characters {@code <([{\^-=$!|]})?*+.>} in a search string.
	 *
	 * <p>
	 * Applied to the search argument of a translated {@code StringReplace} expression;
	 * presumably because the PMML "replace" function interprets its pattern as a regular expression (to be confirmed against the PMML specification).
	 * </p>
	 *
	 * @param string The search string.
	 *
	 * @return The escaped search string.
	 */
	static
	private String escapeSearchString(String string){
		final String specialCharacters = "<([{\\^-=$!|]})?*+.>";

		return escape(string, specialCharacters);
	}

	/**
	 * Backslash-escapes every occurrence of the backslash and dollar sign characters in a replacement string.
	 *
	 * <p>
	 * Applied to the replacement argument of a translated {@code StringReplace} expression.
	 * </p>
	 *
	 * @param string The replacement string.
	 *
	 * @return The escaped replacement string.
	 */
	static
	private String escapeReplacementString(String string){
		final String specialCharacters = "\\$";

		return escape(string, specialCharacters);
	}

	/**
	 * Prefixes every special character of a string with a backslash.
	 *
	 * @param string The string to escape.
	 * @param specialCharacters The characters that need escaping, given as a string.
	 *
	 * @return A copy of the string where each character that occurs in {@code specialCharacters} is preceded by a backslash.
	 */
	static
	private String escape(String string, String specialCharacters){
		StringBuilder escaped = new StringBuilder();

		for(char ch : string.toCharArray()){
			boolean special = (specialCharacters.indexOf(ch) >= 0);

			if(special){
				escaped.append('\\');
			}

			escaped.append(ch);
		}

		return escaped.toString();
	}

	/**
	 * Rewrites the value of a string constant in place.
	 *
	 * @param pmmlExpression The PMML expression. Expected to be a {@link Constant} of data type {@link DataType#STRING}.
	 * @param function The function that maps the current string value to the new string value.
	 *
	 * @return The same constant object, holding the new value.
	 *
	 * @throws IllegalArgumentException If the PMML expression is not a constant, or if the constant is not of string data type.
	 */
	static
	private Constant transformString(org.dmg.pmml.Expression pmmlExpression, Function<String, String> function){

		// Non-constant arguments (e.g. field references) cannot be escaped at translation time
		if(!(pmmlExpression instanceof Constant)){
			throw new IllegalArgumentException("Expected a string constant, got " + String.valueOf(pmmlExpression));
		}

		Constant constant = (Constant)pmmlExpression;

		DataType dataType = constant.getDataType();
		if(dataType != DataType.STRING){
			throw new IllegalArgumentException("Expected a string constant, got a constant of data type " + String.valueOf(dataType));
		}

		constant.setValue(function.apply((String)constant.getValue()));

		return constant;
	}

	/**
	 * Reduces a literal value to a simple representation.
	 *
	 * @param value The value. Must not be {@code null}.
	 *
	 * @return The value itself if its class belongs to the {@code java.lang} package, the {@link Object#toString() string representation} of the value otherwise.
	 */
	static
	private Object toSimpleObject(Object value){
		Package valuePackage = (value.getClass()).getPackage();

		boolean javaLangValue = (ExpressionTranslator.javaLangPackage).equals(valuePackage);

		return (javaLangValue ? value : value.toString());
	}

	/**
	 * Formats the "not supported" error message for a Catalyst expression.
	 *
	 * @param expression The Catalyst expression. May be {@code null}.
	 *
	 * @return The message, which names both the expression and its class; {@code null} if the expression is {@code null}.
	 */
	static
	private String formatMessage(Expression expression){
		return (expression != null)
			? ("Spark SQL function \'" + String.valueOf(expression) + "\' (class " + (expression.getClass()).getName() + ") is not supported")
			: null;
	}

	// The java.lang package; toSimpleObject() compares the package of a literal value against it.
	private static final Package javaLangPackage = Package.getPackage("java.lang");

	// Upper bound that translateInternal() applies to the length argument of a Substring expression.
	private static final int MAX_STRING_LENGTH = 65536;
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy