All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.jpmml.rexp.ModelConverter Maven / Gradle / Ivy

The newest version!
/*
 * Copyright (c) 2016 Villu Ruusmann
 *
 * This file is part of JPMML-R
 *
 * JPMML-R is free software: you can redistribute it and/or modify
 * it under the terms of the GNU Affero General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * JPMML-R is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU Affero General Public License for more details.
 *
 * You should have received a copy of the GNU Affero General Public License
 * along with JPMML-R.  If not, see <http://www.gnu.org/licenses/>.
 */
package org.jpmml.rexp;

import java.util.Collection;
import java.util.Collections;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

import com.google.common.base.Function;
import com.google.common.collect.Lists;
import org.dmg.pmml.Model;
import org.dmg.pmml.PMML;
import org.dmg.pmml.VerificationField;
import org.jpmml.converter.Feature;
import org.jpmml.converter.FeatureImportanceMap;
import org.jpmml.converter.ModelEncoder;
import org.jpmml.converter.ModelUtil;
import org.jpmml.converter.ScalarLabel;
import org.jpmml.converter.Schema;

abstract
public class ModelConverter extends Converter {

	public ModelConverter(R object){
		super(object);
	}

	abstract
	public void encodeSchema(RExpEncoder encoder);

	abstract
	public Model encodeModel(Schema schema);

	public Model encode(Schema schema){
		Model model = encodeModel(schema);

		if(this instanceof HasFeatureImportances){
			HasFeatureImportances hasFeatureImportances = (HasFeatureImportances)this;

			FeatureImportanceMap featureImportances = hasFeatureImportances.getFeatureImportances(schema);
			if(featureImportances != null && !featureImportances.isEmpty()){
				ModelEncoder encoder = schema.getEncoder();

				Collection> entries = featureImportances.entrySet();
				for(Map.Entry entry : entries){
					encoder.addFeatureImportance(model, entry.getKey(), entry.getValue());
				}
			}
		}

		return model;
	}

	@Override
	public PMML encodePMML(RExpEncoder encoder){
		RExp object = getObject();

		RGenericVector verification = null;

		if(object instanceof S4Object){
			S4Object model = (S4Object)object;

			verification = model.getGenericAttribute("verification", false);
		} else

		if(object instanceof RGenericVector){
			RGenericVector model = (RGenericVector)object;

			verification = model.getGenericElement("verification", false);
		}

		encodeSchema(encoder);

		Schema schema = encoder.createSchema();

		Model model = encode(schema);

		verification:
		if(verification != null){
			RDoubleVector precision = verification.getDoubleElement("precision");
			RDoubleVector zeroThreshold = verification.getDoubleElement("zeroThreshold");

			VerificationMap data = new VerificationMap(precision.asScalar(), zeroThreshold.asScalar());

			RGenericVector activeValues = verification.getGenericElement("active_values");
			RGenericVector targetValues = verification.getGenericElement("target_values", false);
			RGenericVector outputValues = verification.getGenericElement("output_values", false);

			if(activeValues != null){
				data.putInputData(encodeActiveValues(activeValues));
			} // End if

			if(targetValues != null && outputValues == null){
				ScalarLabel scalarLabel = (ScalarLabel)schema.getLabel();

				String name = scalarLabel.getName();

				Collection verificationFields = data.keySet();
				for(Iterator verificationFieldIt = verificationFields.iterator(); verificationFieldIt.hasNext(); ){
					VerificationField verificationField = verificationFieldIt.next();

					if((verificationField.requireField()).equals(name)){
						verificationFieldIt.remove();
					}
				}

				data.putResultData(encodeTargetValues(targetValues, scalarLabel));
			} else

			if(outputValues != null){
				data.putResultData(encodeOutputValues(outputValues));
			} else

			{
				break verification;
			}

			model.setModelVerification(ModelUtil.createModelVerification(data));
		}

		PMML pmml = encoder.encodePMML(model);

		return pmml;
	}

	protected Map> encodeActiveValues(RGenericVector dataFrame){
		return encodeVerificationData(dataFrame);
	}

	protected Map> encodeTargetValues(RGenericVector dataFrame, ScalarLabel scalarLabel){
		List columns = dataFrame.getValues();
		String name = scalarLabel.getName();

		return encodeVerificationData(columns, Collections.singletonList(name));
	}

	protected Map> encodeOutputValues(RGenericVector dataFrame){
		return encodeVerificationData(dataFrame);
	}

	static
	protected Map> encodeVerificationData(RGenericVector dataFrame){
		List columns = dataFrame.getValues();
		RStringVector columnNames = dataFrame.names();

		return encodeVerificationData(columns, columnNames.getDequotedValues());
	}

	static
	protected Map> encodeVerificationData(List columns, List names){
		Map> result = new LinkedHashMap<>();

		for(int i = 0; i < columns.size(); i++){
			String name = names.get(i);
			RVector column = (RVector)columns.get(i);

			List values;

			if(column instanceof RDoubleVector){
				Function function = new Function(){

					@Override
					public Double apply(Double value){

						if(value.isNaN()){
							return null;
						}

						return value;
					}
				};

				values = Lists.transform((List)column.getValues(), function);
			} else

			if(column instanceof RFactorVector){
				RFactorVector factor = (RFactorVector)column;

				values = factor.getFactorValues();
			} else

			{
				values = column.getValues();
			}

			VerificationField verificationField = ModelUtil.createVerificationField(name);

			result.put(verificationField, values);
		}

		return result;
	}

	static
	private class VerificationMap extends LinkedHashMap> {

		private Double precision = null;

		private Double zeroThreshold = null;


		public VerificationMap(Double precision, Double zeroThreshold){
			setPrecision(precision);
			setZeroThreshold(zeroThreshold);
		}

		public void putInputData(Map> map){
			putAll(map);
		}

		public void putResultData(Map> map){
			Double precision = getPrecision();
			Double zeroThreshold = getZeroThreshold();

			Collection verificationFields = map.keySet();
			for(VerificationField verificationField : verificationFields){
				verificationField
					.setPrecision(precision)
					.setZeroThreshold(zeroThreshold);
			}

			putAll(map);
		}

		public double getPrecision(){
			return this.precision;
		}

		private void setPrecision(double precision){
			this.precision = precision;
		}

		public double getZeroThreshold(){
			return this.zeroThreshold;
		}

		private void setZeroThreshold(double zeroThreshold){
			this.zeroThreshold = zeroThreshold;
		}
	}
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy