All downloads are free. Search and download functionality uses the official Maven repository.

org.jpmml.converter.PMMLUtil Maven / Gradle / Ivy

There is a newer version: 1.5.10
Show newest version
/*
 * Copyright (c) 2014 Villu Ruusmann
 *
 * This file is part of JPMML-Converter
 *
 * JPMML-Converter is free software: you can redistribute it and/or modify
 * it under the terms of the GNU Affero General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * JPMML-Converter is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU Affero General Public License for more details.
 *
 * You should have received a copy of the GNU Affero General Public License
 * along with JPMML-Converter.  If not, see <http://www.gnu.org/licenses/>.
 */
package org.jpmml.converter;

import java.text.DateFormat;
import java.text.SimpleDateFormat;
import java.time.Instant;
import java.time.ZoneOffset;
import java.time.format.DateTimeFormatter;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Date;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.TimeZone;
import java.util.function.Function;

import javax.xml.bind.JAXBElement;
import javax.xml.namespace.QName;

import org.dmg.pmml.Application;
import org.dmg.pmml.Apply;
import org.dmg.pmml.Array;
import org.dmg.pmml.ComplexArray;
import org.dmg.pmml.Constant;
import org.dmg.pmml.DataField;
import org.dmg.pmml.DataType;
import org.dmg.pmml.Expression;
import org.dmg.pmml.FieldColumnPair;
import org.dmg.pmml.FieldName;
import org.dmg.pmml.Header;
import org.dmg.pmml.InlineTable;
import org.dmg.pmml.Interval;
import org.dmg.pmml.MapValues;
import org.dmg.pmml.RealSparseArray;
import org.dmg.pmml.Row;
import org.dmg.pmml.Timestamp;
import org.dmg.pmml.Value;
import org.jpmml.model.inlinetable.InputCell;
import org.jpmml.model.inlinetable.OutputCell;

public class PMMLUtil {

	private PMMLUtil(){
	}

	static
	public Header createHeader(Class clazz){
		Package _package = clazz.getPackage();

		return createHeader(_package.getImplementationTitle(), _package.getImplementationVersion());
	}

	static
	public Header createHeader(String name, String version){
		Application application = new Application()
			.setName(name)
			.setVersion(version);

		return createHeader(application);
	}

	static
	public Header createHeader(Application application){
		Date now = new Date();

		// XML Schema "dateTime" data format (corresponds roughly to ISO 8601)
		DateFormat dateFormat = new SimpleDateFormat("yyyy-MM-dd'T'HH:mm:ss'Z'");
		dateFormat.setTimeZone(PMMLUtil.UTC);

		Timestamp timestamp = new Timestamp()
			.addContent(dateFormat.format(now));

		Header header = new Header()
			.setApplication(application)
			.setTimestamp(timestamp);

		return header;
	}

	static
	public List getValues(DataField dataField){
		return getValues(dataField, null);
	}

	static
	public List getValues(DataField dataField, Value.Property property){
		List result = new ArrayList<>();

		if(property == null){
			property = Value.Property.VALID;
		}

		List pmmlValues = dataField.getValues();
		for(Value pmmlValue : pmmlValues){

			if((property).equals(pmmlValue.getProperty())){
				result.add(pmmlValue.getValue());
			}
		}

		return result;
	}

	static
	public void addValues(DataField dataField, List values){
		addValues(dataField, values, null);
	}

	static
	public void addValues(DataField dataField, List values, Value.Property property){

		if((Value.Property.VALID).equals(property)){
			property = null;
		}

		List pmmlValues = dataField.getValues();
		for(Object value : values){
			Value pmmlValue = new Value(value)
				.setProperty(property);

			pmmlValues.add(pmmlValue);
		}
	}

	static
	public void addIntervals(DataField dataField, List intervals){
		(dataField.getIntervals()).addAll(intervals);
	}

	static
	public Apply createApply(String function, Expression... expressions){
		Apply apply = new Apply(function)
			.addExpressions(expressions);

		return apply;
	}

	static
	public Constant createConstant(Number value){
		return createConstant(value, TypeUtil.getDataType(value));
	}

	static
	public Constant createConstant(Object value, DataType dataType){
		Constant constant = new Constant(value)
			.setDataType(dataType);

		return constant;
	}

	static
	public MapValues createMapValues(FieldName name, Map mapping){
		List inputValues = new ArrayList<>();
		List outputValues = new ArrayList<>();

		Collection> entries = mapping.entrySet();
		for(Map.Entry entry : entries){
			inputValues.add(entry.getKey());
			outputValues.add(entry.getValue());
		}

		return createMapValues(name, inputValues, outputValues);
	}

	static
	public MapValues createMapValues(FieldName name, List inputValues, List outputValues){
		String inputColumn = "data:input";
		String outputColumn = "data:output";

		Map> data = new LinkedHashMap<>();
		data.put(inputColumn, inputValues);
		data.put(outputColumn, outputValues);

		MapValues mapValues = new MapValues()
			.addFieldColumnPairs(new FieldColumnPair(name, inputColumn))
			.setOutputColumn(outputColumn)
			.setInlineTable(PMMLUtil.createInlineTable(data));

		return mapValues;
	}

	static
	public Array createStringArray(List values){
		Array array = new ComplexArray()
			.setType(Array.Type.STRING)
			.setValue(values);

		return array;
	}

	static
	public Array createIntArray(List values){
		Array array = new ComplexArray()
			.setType(Array.Type.INT)
			.setValue(values);

		return array;
	}

	static
	public Array createRealArray(List values){
		Array array = new ComplexArray()
			.setType(Array.Type.REAL)
			.setValue(values);

		return array;
	}

	static
	public RealSparseArray createRealSparseArray(List values, Double defaultValue){
		RealSparseArray sparseArray = new RealSparseArray()
			.setN(values.size())
			.setDefaultValue(defaultValue);

		List indices = sparseArray.getIndices();
		List entries = sparseArray.getEntries();

		int index = 1;

		for(Number value : values){

			if(!ValueUtil.equals(value, defaultValue)){
				indices.add(index);
				entries.add(ValueUtil.asDouble(value));
			}

			index++;
		}

		return sparseArray;
	}

	static
	public InlineTable createInlineTable(Map> data){
		return createInlineTable(Function.identity(), data);
	}

	static
	public  InlineTable createInlineTable(Function function, Map> data){
		int rows = 0;

		Map columns = new LinkedHashMap<>();

		{
			Collection>> entries = data.entrySet();
			for(Map.Entry> entry : entries){
				K column = entry.getKey();
				List columnData = entry.getValue();

				if(rows == 0){
					rows = columnData.size();
				} else

				{
					if(rows != columnData.size()){
						throw new IllegalArgumentException();
					}
				}

				QName columnName;

				String tagName = function.apply(column);
				if(tagName.startsWith("data:")){
					columnName = new QName("http://jpmml.org/jpmml-model/InlineTable", tagName.substring("data:".length()), "data");
				} else

				{
					if(tagName.indexOf(':') > -1){
						throw new IllegalArgumentException(tagName);
					}

					columnName = new QName("http://www.dmg.org/PMML-4_3", tagName);
				}

				columns.put(column, columnName);
			}
		}

		QName inputColumnName = InputCell.QNAME;
		QName outputColumnName = OutputCell.QNAME;

		InlineTable inlineTable = new InlineTable();

		for(int i = 0; i < rows; i++){
			Row row = new Row();

			Collection> entries = columns.entrySet();
			for(Map.Entry entry : entries){
				List columnData = data.get(entry.getKey());

				Object value = columnData.get(i);
				if(value == null){
					continue;
				}

				QName columName = entry.getValue();

				Object cell;

				if((inputColumnName).equals(columName)){
					cell = new InputCell(value);
				} else

				if((outputColumnName).equals(columName)){
					cell = new OutputCell(value);
				} else

				{
					cell = new JAXBElement<>(columName, String.class, org.jpmml.model.ValueUtil.toString(value));
				}

				row.addContent(cell);
			}

			inlineTable.addRows(row);
		}

		return inlineTable;
	}

	private static final TimeZone UTC = TimeZone.getTimeZone("UTC");
}