All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.jpmml.converter.PMMLEncoder Maven / Gradle / Ivy

There is a newer version: 1.5.10
Show newest version
/*
 * Copyright (c) 2016 Villu Ruusmann
 *
 * This file is part of JPMML-Converter
 *
 * JPMML-Converter is free software: you can redistribute it and/or modify
 * it under the terms of the GNU Affero General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * JPMML-Converter is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU Affero General Public License for more details.
 *
 * You should have received a copy of the GNU Affero General Public License
 * along with JPMML-Converter.  If not, see <http://www.gnu.org/licenses/>.
 */
package org.jpmml.converter;

import java.util.ArrayList;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.function.Supplier;

import org.dmg.pmml.DataDictionary;
import org.dmg.pmml.DataField;
import org.dmg.pmml.DataType;
import org.dmg.pmml.DefineFunction;
import org.dmg.pmml.DerivedField;
import org.dmg.pmml.Expression;
import org.dmg.pmml.Field;
import org.dmg.pmml.FieldName;
import org.dmg.pmml.Header;
import org.dmg.pmml.OpType;
import org.dmg.pmml.PMML;
import org.dmg.pmml.TransformationDictionary;

/**
 * Collects the data fields, derived fields and define functions of a PMML
 * document while it is being built, and assembles them into a {@link PMML}
 * object on request.
 *
 * <p>Data fields and derived fields share a single name space: a name may be
 * registered as either kind, but not both. The backing maps are
 * {@link LinkedHashMap}s, so encoded dictionaries list their elements in
 * registration order.
 */
public class PMMLEncoder {

	private final Map<FieldName, DataField> dataFields = new LinkedHashMap<>();

	private final Map<FieldName, DerivedField> derivedFields = new LinkedHashMap<>();

	private final Map<String, DefineFunction> defineFunctions = new LinkedHashMap<>();


	/**
	 * Assembles the registered fields and functions into a PMML 4.3 document.
	 *
	 * <p>Data fields go into the {@link DataDictionary}. Derived fields and
	 * define functions go into a {@link TransformationDictionary}, which is
	 * omitted (left {@code null}) when both are empty. The header is obtained
	 * from {@link #encodeHeader()}.
	 *
	 * @return a new PMML object
	 * @throws IllegalStateException if some name is registered both as a data
	 * field and as a derived field (possible when the maps returned by the
	 * getters were modified directly)
	 */
	public PMML encodePMML(){

		if(!Collections.disjoint(this.dataFields.keySet(), this.derivedFields.keySet())){
			throw new IllegalStateException("Data fields and derived fields have overlapping names");
		}

		DataDictionary dataDictionary = new DataDictionary();

		// Element lists are only accessed when there is something to add
		// (NOTE(review): presumably to avoid materializing empty lists in the model objects)
		if(!this.dataFields.isEmpty()){
			(dataDictionary.getDataFields()).addAll(this.dataFields.values());
		}

		TransformationDictionary transformationDictionary = null;

		if(!this.derivedFields.isEmpty() || !this.defineFunctions.isEmpty()){
			transformationDictionary = new TransformationDictionary();

			if(!this.derivedFields.isEmpty()){
				(transformationDictionary.getDerivedFields()).addAll(this.derivedFields.values());
			} // End if

			if(!this.defineFunctions.isEmpty()){
				(transformationDictionary.getDefineFunctions()).addAll(this.defineFunctions.values());
			}
		}

		Header header = encodeHeader();

		PMML pmml = new PMML("4.3", header, dataDictionary)
			.setTransformationDictionary(transformationDictionary);

		return pmml;
	}

	/**
	 * Creates the header for the encoded document.
	 *
	 * <p>Delegates to {@code PMMLUtil.createHeader} with the runtime class of
	 * this encoder. Subclasses may override it.
	 *
	 * @return the document header
	 */
	public Header encodeHeader(){
		return PMMLUtil.createHeader(getClass());
	}

	/**
	 * @param name the field name
	 * @return the data field registered under {@code name}, or {@code null} if there is none
	 */
	public DataField getDataField(FieldName name){
		return this.dataFields.get(name);
	}

	/**
	 * Registers a data field under its own name.
	 *
	 * @param dataField the field to register
	 * @throws IllegalArgumentException if the field has no name, or the name is
	 * already registered as a data field or a derived field
	 */
	public void addDataField(DataField dataField){
		FieldName name = checkName(dataField);

		this.dataFields.put(name, dataField);
	}

	/**
	 * Creates and registers a data field that has no value list.
	 *
	 * @return the newly registered data field
	 * @throws IllegalArgumentException if {@code name} is already registered
	 */
	public DataField createDataField(FieldName name, OpType opType, DataType dataType){
		return createDataField(name, opType, dataType, null);
	}

	/**
	 * Creates and registers a data field.
	 *
	 * @param values the field values to attach; ignored when {@code null} or empty
	 * @return the newly registered data field
	 * @throws IllegalArgumentException if {@code name} is already registered
	 */
	public DataField createDataField(FieldName name, OpType opType, DataType dataType, List<String> values){
		DataField dataField = new DataField(name, opType, dataType);

		if(values != null && !values.isEmpty()){
			PMMLUtil.addValues(dataField, values);
		}

		addDataField(dataField);

		return dataField;
	}

	/**
	 * Unregisters a data field.
	 *
	 * @return the removed data field
	 * @throws IllegalArgumentException if no data field is registered under {@code name}
	 */
	public DataField removeDataField(FieldName name){
		DataField dataField = this.dataFields.remove(name);

		if(dataField == null){
			throw undefinedField(name);
		}

		return dataField;
	}

	/**
	 * Returns the derived field registered under {@code name}, creating and
	 * registering it first if it does not exist yet.
	 *
	 * <p>The supplier is only invoked when a new field has to be created. When
	 * the field already exists, its operational type and data type are NOT
	 * compared with {@code opType} and {@code dataType}.
	 *
	 * @param expressionSupplier supplies the expression of a newly created field
	 * @return the existing or newly created derived field
	 * @throws IllegalArgumentException if {@code name} is already registered as a data field
	 */
	public DerivedField ensureDerivedField(FieldName name, OpType opType, DataType dataType, Supplier<? extends Expression> expressionSupplier){
		DerivedField derivedField = getDerivedField(name);

		if(derivedField == null){
			Expression expression = expressionSupplier.get();

			derivedField = createDerivedField(name, opType, dataType, expression);
		}

		return derivedField;
	}

	/**
	 * @param name the field name
	 * @return the derived field registered under {@code name}, or {@code null} if there is none
	 */
	public DerivedField getDerivedField(FieldName name){
		return this.derivedFields.get(name);
	}

	/**
	 * Registers a derived field under its own name.
	 *
	 * @param derivedField the field to register
	 * @throws IllegalArgumentException if the field has no name, or the name is
	 * already registered as a data field or a derived field
	 */
	public void addDerivedField(DerivedField derivedField){
		FieldName name = checkName(derivedField);

		this.derivedFields.put(name, derivedField);
	}

	/**
	 * Creates and registers a derived field.
	 *
	 * @return the newly registered derived field
	 * @throws IllegalArgumentException if {@code name} is already registered
	 */
	public DerivedField createDerivedField(FieldName name, OpType opType, DataType dataType, Expression expression){
		DerivedField derivedField = new DerivedField(opType, dataType)
			.setName(name)
			.setExpression(expression);

		addDerivedField(derivedField);

		return derivedField;
	}

	/**
	 * Unregisters a derived field.
	 *
	 * @return the removed derived field
	 * @throws IllegalArgumentException if no derived field is registered under {@code name}
	 */
	public DerivedField removeDerivedField(FieldName name){
		DerivedField derivedField = this.derivedFields.remove(name);

		if(derivedField == null){
			throw undefinedField(name);
		}

		return derivedField;
	}

	/**
	 * Looks up a field by name among both the data fields and the derived fields.
	 *
	 * @return the data field or derived field registered under {@code name}
	 * @throws IllegalStateException if {@code name} is registered as both kinds
	 * @throws IllegalArgumentException if {@code name} is not registered at all
	 */
	public Field getField(FieldName name){
		DataField dataField = getDataField(name);
		DerivedField derivedField = getDerivedField(name);

		if(dataField != null && derivedField != null){
			throw new IllegalStateException("Field " + name.getValue() + " is defined both as a data field and as a derived field");
		} else if(dataField != null){
			return dataField;
		} else if(derivedField != null){
			return derivedField;
		}

		throw undefinedField(name);
	}

	/**
	 * Marks a numeric field as continuous.
	 *
	 * @return the updated field
	 * @throws IllegalArgumentException if the field is undefined, or its data
	 * type is not one of {@code INTEGER}, {@code FLOAT} or {@code DOUBLE}
	 * (including a missing data type)
	 */
	public Field toContinuous(FieldName name){
		Field field = getField(name);

		DataType dataType = field.getDataType();
		if(!isNumeric(dataType)){
			throw new IllegalArgumentException("Field " + name.getValue() + " has data type " + dataType);
		}

		field.setOpType(OpType.CONTINUOUS);

		return field;
	}

	/**
	 * Marks a field as categorical.
	 *
	 * <p>For a data field, {@code values} must equal its existing value list if
	 * it already has one. Otherwise {@code values} (when non-null and non-empty)
	 * is attached to it. For a derived field, {@code values} is not used.
	 *
	 * @param values the category values
	 * @return the updated field
	 * @throws IllegalArgumentException if the field is undefined, or it is a data
	 * field whose existing values differ from {@code values}
	 */
	public Field toCategorical(FieldName name, List<String> values){
		Field field = getField(name);

		if(field instanceof DataField){
			DataField dataField = (DataField)field;

			List<String> existingValues = PMMLUtil.getValues(dataField);
			if(existingValues != null && !existingValues.isEmpty()){

				if(!existingValues.equals(values)){
					throw new IllegalArgumentException("Field " + name.getValue() + " has valid values " + existingValues);
				}
			} else

			// Same guard as createDataField: nothing to attach for null/empty values
			if(values != null && !values.isEmpty()){
				PMMLUtil.addValues(dataField, values);
			}
		}

		field.setOpType(OpType.CATEGORICAL);

		return field;
	}

	/**
	 * @param name the function name
	 * @return the function registered under {@code name}, or {@code null} if there is none
	 */
	public DefineFunction getDefineFunction(String name){
		return this.defineFunctions.get(name);
	}

	/**
	 * Registers a function under its own name.
	 *
	 * @param defineFunction the function to register
	 * @throws NullPointerException if the function has no name
	 * @throws IllegalArgumentException if a function with the same name is already registered
	 */
	public void addDefineFunction(DefineFunction defineFunction){
		String name = defineFunction.getName();

		if(name == null){
			throw new NullPointerException("Function name is null");
		} // End if

		if(this.defineFunctions.containsKey(name)){
			throw new IllegalArgumentException("Function " + name + " is already defined");
		}

		this.defineFunctions.put(name, defineFunction);
	}

	/**
	 * @return the live, mutable backing map of data fields (not a copy)
	 */
	public Map<FieldName, DataField> getDataFields(){
		return this.dataFields;
	}

	/**
	 * @return the live, mutable backing map of derived fields (not a copy)
	 */
	public Map<FieldName, DerivedField> getDerivedFields(){
		return this.derivedFields;
	}

	/**
	 * @return the live, mutable backing map of define functions (not a copy)
	 */
	public Map<String, DefineFunction> getDefineFunctions(){
		return this.defineFunctions;
	}

	/**
	 * Validates that a field can be registered, and returns its name.
	 *
	 * @throws IllegalArgumentException if the field has no name, or the name is
	 * already taken by a data field or a derived field
	 */
	private FieldName checkName(Field field){
		FieldName name = field.getName();

		if(name == null){
			throw new IllegalArgumentException("Field name is null");
		} // End if

		if(this.dataFields.containsKey(name) || this.derivedFields.containsKey(name)){
			throw new IllegalArgumentException("Field " + name.getValue() + " is already defined");
		}

		return name;
	}

	/**
	 * Returns {@code true} for the numeric data types accepted by
	 * {@link #toContinuous(FieldName)}, and {@code false} otherwise (including {@code null}).
	 */
	private static boolean isNumeric(DataType dataType){

		// Switching on a null enum would throw a NullPointerException
		if(dataType == null){
			return false;
		}

		switch(dataType){
			case INTEGER:
			case FLOAT:
			case DOUBLE:
				return true;
			default:
				return false;
		}
	}

	/**
	 * Builds the exception reported when a field lookup or removal finds nothing.
	 */
	private static IllegalArgumentException undefinedField(FieldName name){
		return new IllegalArgumentException("Field " + name.getValue() + " is undefined");
	}
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy