org.jpmml.translator.TranslationContext Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of pmml-transpiler Show documentation
Show all versions of pmml-transpiler Show documentation
JPMML class model transpiler
/*
* Copyright (c) 2017 Villu Ruusmann
*
* This file is part of JPMML-Transpiler
*
* JPMML-Transpiler is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* JPMML-Transpiler is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with JPMML-Transpiler. If not, see <http://www.gnu.org/licenses/>.
*/
package org.jpmml.translator;
import java.lang.reflect.Constructor;
import java.lang.reflect.TypeVariable;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.Deque;
import java.util.HashMap;
import java.util.HashSet;
import java.util.IdentityHashMap;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.function.Function;
import java.util.stream.Collectors;
import javax.xml.namespace.QName;
import com.google.common.collect.Iterables;
import com.sun.codemodel.JBlock;
import com.sun.codemodel.JCase;
import com.sun.codemodel.JClass;
import com.sun.codemodel.JCodeModel;
import com.sun.codemodel.JDefinedClass;
import com.sun.codemodel.JExpr;
import com.sun.codemodel.JExpression;
import com.sun.codemodel.JFieldRef;
import com.sun.codemodel.JFieldVar;
import com.sun.codemodel.JInvocation;
import com.sun.codemodel.JMethod;
import com.sun.codemodel.JMods;
import com.sun.codemodel.JOp;
import com.sun.codemodel.JStatement;
import com.sun.codemodel.JSwitch;
import com.sun.codemodel.JType;
import com.sun.codemodel.JTypeVar;
import com.sun.codemodel.JVar;
import org.dmg.pmml.DataType;
import org.dmg.pmml.Field;
import org.dmg.pmml.Model;
import org.dmg.pmml.PMML;
import org.jpmml.evaluator.JavaExpression;
import org.jpmml.evaluator.Value;
import org.jpmml.evaluator.ValueMap;
import org.jpmml.model.PMMLException;
import org.jpmml.model.UnsupportedAttributeException;
import org.jpmml.translator.tree.TreeModelTranslator;
public class TranslationContext {
private PMML pmml = null;
private JCodeModel codeModel = null;
private List issues = new ArrayList<>();
private Deque owners = new ArrayDeque<>();
private Deque scopes = new ArrayDeque<>();
private ArraySetManager xmlNameManager = null;
private List> valueManagers = new ArrayList<>();
private Map translations = new IdentityHashMap<>();
private Set activeFieldNames = new LinkedHashSet<>();
/**
 * Creates a translation context.
 *
 * @param pmml The PMML document being translated.
 * @param codeModel The code model used for type references and generated code.
 */
public TranslationContext(PMML pmml, JCodeModel codeModel){
	this.setPMML(pmml);
	this.setCodeModel(codeModel);
}
public JClass ref(Class> type){
JCodeModel codeModel = getCodeModel();
return codeModel.ref(type);
}
public JClass genericRef(Class> type, Object... typeArgs){
List safeTypeArgs = Arrays.stream(typeArgs)
.map(typeArg -> {
if(typeArg instanceof Class){
return ref((Class>)typeArg);
}
return (JClass)typeArg;
})
.collect(Collectors.toList());
return genericRef(type, safeTypeArgs);
}
public JClass genericRef(Class> type, List typeArgs){
return ref(type).narrow(typeArgs);
}
public JType _ref(Class> type){
JCodeModel codeModel = getCodeModel();
return codeModel._ref(type);
}
/**
 * Returns the wildcard type reference supplied by the current code model.
 */
public JClass wildcard(){
	return getCodeModel().wildcard();
}
/**
 * Returns the most recently pushed owner class.
 *
 * @throws java.util.NoSuchElementException If no owner class has been pushed.
 */
public JDefinedClass getOwner(){
	Deque<JDefinedClass> ownerStack = this.owners;

	return ownerStack.getFirst();
}
public JDefinedClass getOwner(Class> clazz){
Deque owners = getOwners();
JClass superClazz = ref(clazz);
for(Iterator it = owners.iterator(); it.hasNext(); ){
JDefinedClass owner = it.next();
if(superClazz.isAssignableFrom(owner)){
return owner;
}
}
throw new IllegalArgumentException();
}
/**
 * Returns the owner class stack, most recently pushed first.
 *
 * The returned deque is the internal instance, not a copy.
 */
public Deque<JDefinedClass> getOwners(){
	return this.owners;
}
/**
 * Makes the given class the current owner class.
 *
 * If the class is a subclass of {@link PMML}, a new QName constant manager named "xmlNames" is also created for it.
 * That manager is consumed again by {@link #popOwner()}.
 *
 * @param owner The class to push.
 */
public void pushOwner(JDefinedClass owner){

	if(isSubclass(PMML.class, owner)){
		// The type argument is required: with a raw ArraySetManager, createExpression(QName) would not override anything
		this.xmlNameManager = new ArraySetManager<QName>(ref(QName.class), "xmlNames"){

			{
				initArrayVar(owner);
			}

			@Override
			public JExpression createExpression(QName name){
				return _new(QName.class, name.getNamespaceURI(), name.getLocalPart(), name.getPrefix());
			}
		};
	}

	this.owners.addFirst(owner);
}
public void popOwner(){
JDefinedClass owner = this.owners.peekFirst();
if(isSubclass(PMML.class, owner)){
PMML pmml = getPMML();
JBinaryFileInitializer resourceInitializer = new JBinaryFileInitializer(IdentifierUtil.create(PMML.class.getSimpleName(), pmml) + ".data", 0, this);
QName[] xmlNames = this.xmlNameManager.getElements()
.toArray(new QName[this.xmlNameManager.size()]);
JInvocation xmlNamesExpr = resourceInitializer.initQNames(xmlNames);
resourceInitializer.assign(this.xmlNameManager.getArrayVar(), xmlNamesExpr);
List> valueManagers = this.valueManagers;
for(ArrayManager> valueManager : valueManagers){
valueManager.initArrayVar(owner);
Object[] values = valueManager.getElements()
.toArray(new Object[valueManager.size()]);
JInvocation valuesExpr = resourceInitializer.initValues(valueManager.getComponentType(), values);
resourceInitializer.assign(valueManager.getArrayVar(), valuesExpr);
}
}
this.owners.removeFirst();
}
/**
 * Looks up a variable by name.
 *
 * Scopes are consulted in iteration order, and the first non-null match wins.
 *
 * @param name The variable name.
 * @return The variable.
 * @throws IllegalArgumentException If no scope yields a variable with that name.
 */
public JVar getVariable(String name){
	Iterator<Scope> scopeIt = this.scopes.iterator();

	while(scopeIt.hasNext()){
		JVar result = scopeIt.next().getVariable(name);

		if(result != null){
			return result;
		}
	}

	throw new IllegalArgumentException(name);
}
/**
 * Returns a reference to the arguments object.
 *
 * When the current owner class is itself named "Arguments", a synthetic {@link JVar} named "this" is used, typed as
 * the owner class. Otherwise the scoped variable {@link Scope#VAR_ARGUMENTS} is looked up.
 *
 * @throws RuntimeException If the synthetic variable cannot be created reflectively.
 * @throws IllegalArgumentException If the arguments variable is not in scope.
 */
public ArgumentsRef getArgumentsVariable(){
	JDefinedClass owner = getOwner();

	JVar variable;

	// XXX
	if(("Arguments").equals(owner.name())){

		try {
			// NOTE(review): this JVar constructor is presumably not accessible from here, hence reflection
			Constructor<JVar> constructor = JVar.class.getDeclaredConstructor(JMods.class, JType.class, String.class, JExpression.class);

			// setAccessible(true) is idempotent, so the deprecated isAccessible() pre-check is unnecessary
			constructor.setAccessible(true);

			variable = constructor.newInstance(null, owner, "this", null);
		} catch(ReflectiveOperationException roe){
			throw new RuntimeException(roe);
		}
	} else

	{
		variable = getVariable(Scope.VAR_ARGUMENTS);
	}

	return new ArgumentsRef(variable);
}
/**
 * Returns a reference to the evaluation context variable ({@link Scope#VAR_CONTEXT}).
 *
 * @throws IllegalArgumentException If the variable is not in scope.
 */
public EvaluationContextRef getContextVariable(){
	return new EvaluationContextRef(getVariable(Scope.VAR_CONTEXT));
}
/**
 * Returns a reference to the value factory variable ({@link Scope#VAR_VALUEFACTORY}).
 *
 * @throws IllegalArgumentException If the variable is not in scope.
 */
public ValueFactoryRef getValueFactoryVariable(){
	return new ValueFactoryRef(getVariable(Scope.VAR_VALUEFACTORY));
}
/**
 * Checks whether any scope has marked the given reference as non-missing.
 *
 * @param operableRef The reference to check.
 * @return {@code true} if at least one scope reports the reference as non-missing.
 */
public boolean isNonMissing(OperableRef operableRef){
	Deque<Scope> scopeStack = this.scopes;

	return scopeStack.stream()
		.anyMatch(scope -> scope.isNonMissing(operableRef));
}
/**
 * Records in the open scope (see {@code ensureOpenScope()}) that the given reference is non-missing.
 *
 * @param operableRef The reference to mark.
 */
public void markNonMissing(OperableRef operableRef){
	ensureOpenScope().markNonMissing(operableRef);
}
/**
 * Returns an operable reference to the value of the given field.
 *
 * An existing variable with the field's variable name is reused when one is in scope. Otherwise the value is
 * obtained by invoking the method returned by {@code ArgumentsRef#getMethod} on the arguments object, and the
 * result is optionally stored in a new local variable.
 *
 * @param fieldInfo The field.
 * @param declareAsVariableFunction Given the arguments method, decides whether its result is declared as a local variable.
 * @return The encoder's reference when the field has an encoder; otherwise a string, primitive or number reference
 * chosen by the field's data type.
 * @throws UnsupportedAttributeException If the field has no encoder and its data type is not
 * string, integer, float, double or boolean.
 */
public OperableRef ensureOperable(FieldInfo fieldInfo, Function<JMethod, Boolean> declareAsVariableFunction){
	Field<?> field = fieldInfo.getField();
	Encoder encoder = fieldInfo.getEncoder();

	DataType dataType = field.requireDataType();

	String variableName;

	// With an encoder, the variable name comes from the FieldInfo returned by Encoder#follow
	if(encoder != null){
		FieldInfo finalFieldInfo = encoder.follow(fieldInfo);

		variableName = finalFieldInfo.getVariableName();
	} else

	{
		variableName = fieldInfo.getVariableName();
	}

	JExpression expression;
	JType type;

	try {
		JVar variable = getVariable(variableName);

		expression = variable;
		type = variable.type();
	} catch(IllegalArgumentException iae){
		// No such variable in scope; obtain the value from the arguments object instead
		JExpression[] initArgExprs = new JExpression[0];

		if(encoder instanceof TermFrequencyEncoder){
			TermFrequencyEncoder termFrequencyEncoder = (TermFrequencyEncoder)encoder;

			TreeModelTranslator.ensureTextIndexFields(fieldInfo, termFrequencyEncoder, this);
		} // End if

		if(encoder instanceof ArrayEncoder){
			ArrayEncoder arrayEncoder = (ArrayEncoder)encoder;

			// Array-encoded fields pass their index to the arguments method
			initArgExprs = new JExpression[]{JExpr.lit(arrayEncoder.getIndex())};
		}

		ArgumentsRef argumentsRef = getArgumentsVariable();

		JMethod method = argumentsRef.getMethod(fieldInfo, this);

		expression = argumentsRef.invoke(method, initArgExprs);

		boolean declareAsVariable = declareAsVariableFunction.apply(method);
		if(declareAsVariable){
			expression = declare(method.type(), variableName, expression);
		}

		type = method.type();
	}

	if(encoder != null){
		return encoder.ref(expression);
	}

	switch(dataType){
		case STRING:
			return new StringRef(expression);
		case INTEGER:
		case FLOAT:
		case DOUBLE:
		case BOOLEAN:
			{
				if(type.isPrimitive()){
					return new PrimitiveRef(expression);
				}

				return new NumberRef(expression);
			}
		default:
			throw new UnsupportedAttributeException(field, dataType);
	}
}
/**
 * Looks up a type variable in the first {@link MethodScope} found among the scopes, in iteration order.
 *
 * Only that one method scope is consulted, and its answer is returned as-is.
 *
 * @param name The type variable name.
 * @return The result of {@code MethodScope#getTypeVariable(String)} on the first method scope.
 * @throws IllegalArgumentException If there is no method scope.
 */
public JTypeVar getTypeVariable(String name){
	MethodScope methodScope = null;

	for(Iterator<Scope> it = this.scopes.iterator(); methodScope == null && it.hasNext(); ){
		Scope candidate = it.next();

		if(candidate instanceof MethodScope){
			methodScope = (MethodScope)candidate;
		}
	}

	if(methodScope == null){
		throw new IllegalArgumentException(name);
	}

	return methodScope.getTypeVariable(name);
}
/**
 * Returns the type variable named {@link MethodScope#TYPEVAR_NUMBER}.
 *
 * @throws IllegalArgumentException If there is no method scope.
 */
public JTypeVar getNumberTypeVariable(){
	return this.getTypeVariable(MethodScope.TYPEVAR_NUMBER);
}
/**
 * Returns {@code Value<V>}, where {@code V} is the number type variable.
 */
public JClass getValueType(){
	return genericRef(Value.class, getNumberTypeVariable());
}
/**
 * Returns {@code ValueMap<Object, V>}, where {@code V} is the number type variable.
 */
public JClass getValueMapType(){
	return genericRef(ValueMap.class, Object.class, getNumberTypeVariable());
}
public JVar declare(Class> type, String name, JExpression init){
return declare(_ref(type), name, init);
}
/**
 * Declares a local variable in the open scope (see {@code ensureOpenScope()}).
 *
 * @param type The variable type.
 * @param name The variable name.
 * @param init The initializer expression.
 * @return The declared variable.
 */
public JVar declare(JType type, String name, JExpression init){
	return ensureOpenScope().declare(type, name, init);
}
/**
 * Appends a statement to the block of the open scope.
 *
 * @param statement The statement to append.
 */
public void add(JStatement statement){
	ensureOpenScope().getBlock().add(statement);
}
/**
 * Appends a single-line comment to the block of the open scope.
 *
 * @param string The comment text, without the leading slashes.
 */
public void _comment(String string){
	ensureOpenScope().getBlock().directStatement("// " + string);
}
/**
 * Appends {@code if(testExpr){ return resultExpr; }} to the block of the open scope.
 *
 * Unlike the {@code _return} methods, this leaves the scope open.
 *
 * @param testExpr The condition.
 * @param resultExpr The value returned when the condition holds.
 */
public void _returnIf(JExpression testExpr, JExpression resultExpr){
	JBlock block = ensureOpenScope().getBlock();

	block._if(testExpr)._then()._return(resultExpr);
}
/**
 * Appends {@code return testExpr ? trueResultExpr : falseResultExpr;} to the open scope, then closes that scope.
 *
 * @param testExpr The condition.
 * @param trueResultExpr The value returned when the condition holds.
 * @param falseResultExpr The value returned otherwise.
 */
public void _return(JExpression testExpr, JExpression trueResultExpr, JExpression falseResultExpr){
	Scope scope = ensureOpenScope();

	try {
		JExpression conditionalExpr = JOp.cond(testExpr, trueResultExpr, falseResultExpr);

		scope.getBlock()._return(conditionalExpr);
	} finally {
		// Close the scope whether or not emitting the return statement succeeded
		scope.close();
	}
}
/**
 * Appends {@code return resultExpr;} to the open scope, then closes that scope.
 *
 * @param resultExpr The value to return.
 */
public void _return(JExpression resultExpr){
	Scope scope = ensureOpenScope();

	try {
		scope.getBlock()._return(resultExpr);
	} finally {
		// Close the scope whether or not emitting the return statement succeeded
		scope.close();
	}
}
public void _return(JExpression valueExpr, Map, V> resultMap, V defaultResult){
Scope scope = ensureOpenScope();
try {
JBlock block = scope.getBlock();
if(resultMap.size() == 1){
Map.Entry, V> entry = Iterables.getOnlyElement(resultMap.entrySet());
JExpression condExpr = staticInvoke(Objects.class, "equals", valueExpr, PMMLObjectUtil.createExpression(entry.getKey(), this));
block._return(JOp.cond(condExpr, PMMLObjectUtil.createExpression(entry.getValue(), this), PMMLObjectUtil.createExpression(defaultResult, this)));
} else
if((resultMap.size() == 2) && (resultMap.containsKey(Boolean.TRUE) && resultMap.containsKey(Boolean.FALSE))){
V trueValue = resultMap.get(Boolean.TRUE);
V falseValue = resultMap.get(Boolean.FALSE);
JClass booleanClass = ref(Boolean.class);
JBlock trueBlock = block._if(staticInvoke(Objects.class, "equals", valueExpr, booleanClass.staticRef("TRUE")))._then();
trueBlock._return(PMMLObjectUtil.createExpression(trueValue, this));
JBlock falseBlock = block._if(staticInvoke(Objects.class, "equals", valueExpr, booleanClass.staticRef("FALSE")))._then();
falseBlock._return(PMMLObjectUtil.createExpression(falseValue, this));
block._return(PMMLObjectUtil.createExpression(defaultResult, this));
} else
if((resultMap.size() > 64) && JBinaryFileInitializer.isExternalizable(resultMap.keySet())){
JBinaryFileInitializer resourceInitializer = new JBinaryFileInitializer(IdentifierUtil.create(Map.class.getSimpleName(), Collections.singletonList(resultMap)) + ".data", this);
JFieldVar mapField = resourceInitializer.initNumbersMap("map$" + System.identityHashCode(Collections.singletonList(resultMap)), (Map)resultMap);
JBlock thenBlock = block._if(mapField.invoke("containsKey").arg(valueExpr))._then();
thenBlock._return(mapField.invoke("get").arg(valueExpr));
block._return(PMMLObjectUtil.createExpression(defaultResult, this));
} else
{
boolean stringKeys = true;
Map
© 2015 - 2024 Weber Informatics LLC | Privacy Policy