org.jpmml.rexp.Formula Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of pmml-rexp Show documentation
Show all versions of pmml-rexp Show documentation
JPMML R to PMML converter
The newest version!
/*
* Copyright (c) 2016 Villu Ruusmann
*
* This file is part of JPMML-R
*
* JPMML-R is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* JPMML-R is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with JPMML-R. If not, see <http://www.gnu.org/licenses/>.
*/
package org.jpmml.rexp;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import com.google.common.collect.BiMap;
import com.google.common.collect.HashBiMap;
import org.dmg.pmml.Apply;
import org.dmg.pmml.Constant;
import org.dmg.pmml.DataType;
import org.dmg.pmml.DerivedField;
import org.dmg.pmml.Expression;
import org.dmg.pmml.Field;
import org.dmg.pmml.FieldRef;
import org.dmg.pmml.PMMLFunctions;
import org.jpmml.converter.BinaryFeature;
import org.jpmml.converter.BooleanFeature;
import org.jpmml.converter.CategoricalFeature;
import org.jpmml.converter.ContinuousFeature;
import org.jpmml.converter.Feature;
import org.jpmml.converter.HasDerivedName;
import org.jpmml.converter.InteractionFeature;
import org.jpmml.converter.PowerFeature;
import org.jpmml.converter.ValueUtil;
public class Formula {
private RExpEncoder encoder = null;
private Map validNames = new HashMap<>();
private BiMap features = HashBiMap.create();
private List> fields = new ArrayList<>();
public Formula(RExpEncoder encoder){
setEncoder(encoder);
}
public Feature resolveComplexFeature(String name){
RExpEncoder encoder = getEncoder();
List variables = split(name);
if(variables.size() == 1){
return resolveFeature(name);
} else
{
List variableFeatures = new ArrayList<>();
for(String variable : variables){
Feature variableFeature = resolveFeature(variable);
variableFeatures.add(variableFeature);
}
return new InteractionFeature(encoder, name, DataType.DOUBLE, variableFeatures);
}
}
public Feature resolveFeature(String name){
Feature feature = getFeature(name);
if(feature == null){
throw new IllegalArgumentException(name);
}
return feature;
}
public Double getCoefficient(Feature feature, RDoubleVector coefficients){
String name = feature.getName();
if(feature instanceof HasDerivedName){
BiMap inverseFeatures = this.features.inverse();
name = inverseFeatures.get(feature);
}
return coefficients.getElement(name);
}
public Field> getField(int index){
return this.fields.get(index);
}
public void addField(Field> field){
RExpEncoder encoder = getEncoder();
Feature feature = new ContinuousFeature(encoder, field);
if(field instanceof DerivedField){
DerivedField derivedField = (DerivedField)field;
Expression expression = derivedField.requireExpression();
if(expression instanceof Apply){
Apply apply = (Apply)expression;
if(checkApply(apply, PMMLFunctions.POW, FieldRef.class, Constant.class)){
List expressions = apply.getExpressions();
FieldRef fieldRef = (FieldRef)expressions.get(0);
Constant constant = (Constant)expressions.get(1);
try {
String string = ValueUtil.asString(constant.getValue());
int power = Integer.parseInt(string);
feature = new PowerFeature(encoder, fieldRef.requireField(), DataType.DOUBLE, power);
} catch(NumberFormatException nfe){
// Ignored
}
}
}
}
putFeature(field.requireName(), feature);
this.fields.add(field);
}
public void addField(Field> field, List categoryNames, List> categoryValues){
RExpEncoder encoder = getEncoder();
if(categoryNames.size() != categoryValues.size()){
throw new IllegalArgumentException();
}
CategoricalFeature categoricalFeature;
if((field.requireDataType() == DataType.BOOLEAN) && (BooleanFeature.VALUES).equals(categoryValues)){
categoricalFeature = new BooleanFeature(encoder, field);
} else
{
categoricalFeature = new CategoricalFeature(encoder, field, categoryValues);
}
putFeature(field.requireName(), categoricalFeature);
for(int i = 0; i < categoryNames.size(); i++){
String categoryName = categoryNames.get(i);
Object categoryValue = categoryValues.get(i);
BinaryFeature binaryFeature = new BinaryFeature(encoder, field, categoryValue);
putFeature((field.requireName() + categoryName), binaryFeature);
}
this.fields.add(field);
}
private Feature getFeature(String name){
Feature feature = this.features.get(name);
if(feature == null){
if(this.validNames.containsKey(name)){
feature = this.features.get(this.validNames.get(name));
}
}
return feature;
}
private void putFeature(String name, Feature feature){
String validName = RExpUtil.makeName(name);
if(!(name).equals(validName)){
this.validNames.put(validName, name);
}
this.features.put(name, feature);
}
public RExpEncoder getEncoder(){
return this.encoder;
}
private void setEncoder(RExpEncoder encoder){
this.encoder = encoder;
}
/**
* Splits a string by single colon characters (':'), ignoring sequences of two or three colon characters ("::" and ":::").
*/
static
List split(String string){
List result = new ArrayList<>();
int pos = 0;
for(int i = 0; i < string.length(); ){
if(string.charAt(i) == ':'){
int delimBegin = i;
int delimEnd = i;
while((delimEnd + 1) < string.length() && string.charAt(delimEnd + 1) == ':'){
delimEnd++;
}
if(delimBegin == delimEnd){
result.add(string.substring(pos, delimBegin));
pos = (delimEnd + 1);
}
i = (delimEnd + 1);
} else
{
i++;
}
}
if(pos <= string.length()){
result.add(string.substring(pos));
}
return result;
}
static
private boolean checkApply(Apply apply, String function, Class extends Expression>... expressionClazzes){
if((function).equals(apply.requireFunction())){
List expressions = apply.getExpressions();
if(expressionClazzes.length == expressions.size()){
for(int i = 0; i < expressionClazzes.length; i++){
Class extends Expression> expressionClazz = expressionClazzes[i];
Expression expression = expressions.get(i);
if(!(expressionClazz).isInstance(expression)){
return false;
}
}
return true;
}
}
return false;
}
}