All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.jpmml.rexp.CaretEnsembleConverter Maven / Gradle / Ivy

The newest version!
/*
 * Copyright (c) 2018 Villu Ruusmann
 *
 * This file is part of JPMML-R
 *
 * JPMML-R is free software: you can redistribute it and/or modify
 * it under the terms of the GNU Affero General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * JPMML-R is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU Affero General Public License for more details.
 *
 * You should have received a copy of the GNU Affero General Public License
 * along with JPMML-R.  If not, see <http://www.gnu.org/licenses/>.
 */
package org.jpmml.rexp;

import java.util.ArrayList;
import java.util.List;
import java.util.function.Function;

import org.dmg.pmml.DataType;
import org.dmg.pmml.MiningFunction;
import org.dmg.pmml.Model;
import org.dmg.pmml.OpType;
import org.dmg.pmml.Output;
import org.dmg.pmml.OutputField;
import org.dmg.pmml.PMML;
import org.dmg.pmml.mining.MiningModel;
import org.dmg.pmml.mining.Segmentation;
import org.jpmml.converter.CategoricalLabel;
import org.jpmml.converter.ContinuousLabel;
import org.jpmml.converter.Label;
import org.jpmml.converter.ModelUtil;
import org.jpmml.converter.Schema;
import org.jpmml.converter.SchemaUtil;
import org.jpmml.converter.mining.MiningModelUtil;

/**
 * Converts an R ensemble object that holds a list of member models (element <code>models</code>)
 * and a combining model (element <code>ens_model</code>) into a single PMML model chain.
 *
 * <p>
 * Every member model is encoded as a separate segment that exposes exactly one non-final output field,
 * named after the member model. The combining model is appended as the last segment of the chain.
 * </p>
 */
public class CaretEnsembleConverter extends Converter<RGenericVector> {

	public CaretEnsembleConverter(RGenericVector caretEnsemble){
		super(caretEnsemble);
	}

	/**
	 * Encodes the ensemble as a model chain.
	 *
	 * @param encoder The top-level encoder. It receives the fields of every member model encoder.
	 *
	 * @return The PMML document produced by the top-level encoder for the model chain.
	 *
	 * @throws IllegalArgumentException If a member model is neither a regression nor a classification model,
	 * or if its label is neither continuous nor categorical.
	 */
	@Override
	public PMML encodePMML(RExpEncoder encoder){
		RGenericVector caretEnsemble = getObject();

		RGenericVector models = caretEnsemble.getGenericElement("models");
		RGenericVector ensModel = caretEnsemble.getGenericElement("ens_model");

		RStringVector modelNames = models.names();

		List<Model> segmentationModels = new ArrayList<>();

		Function<Schema, Schema> segmentSchemaFunction = CaretEnsembleConverter::toSegmentSchema;

		for(int i = 0; i < models.size(); i++){
			RGenericVector model = models.getGenericValue(i);

			Conversion conversion = encodeTrainModel(model, segmentSchemaFunction);

			// Member models are encoded against private encoders; merge their fields into the top-level encoder
			RExpEncoder segmentEncoder = conversion.getEncoder();

			encoder.addFields(segmentEncoder);

			Schema segmentSchema = conversion.getSchema();
			Model segmentModel = conversion.getModel();

			String name = modelNames.getValue(i);

			OutputField outputField = createSegmentOutputField(name, segmentModel, segmentSchema);

			Output output = new Output()
				.addOutputFields(outputField);

			// Replaces whatever output the member model converter may have set up
			segmentModel.setOutput(output);

			segmentationModels.add(segmentModel);
		}

		// The combining model is encoded with its schema left as-is.
		// NOTE(review): the fields of its private encoder are not merged into the top-level encoder,
		// presumably because its inputs are the output fields of the preceding segments - confirm.
		Conversion conversion = encodeTrainModel(ensModel, null);

		Model model = conversion.getModel();

		segmentationModels.add(model);

		MiningModel miningModel = MiningModelUtil.createModelChain(segmentationModels, Segmentation.MissingPredictionTreatment.CONTINUE);

		PMML pmml = encoder.encodePMML(miningModel);

		return pmml;
	}

	/**
	 * Encodes the <code>finalModel</code> element of a trained model object against a fresh encoder.
	 *
	 * @param train The trained model object.
	 * @param schemaFunction An optional schema transformation that is applied before encoding the model. May be <code>null</code>.
	 *
	 * @return The encoder, the (possibly transformed) schema and the encoded model.
	 */
	private Conversion encodeTrainModel(RGenericVector train, Function<Schema, Schema> schemaFunction){
		RExp finalModel = train.getElement("finalModel");

		ModelConverter<?> converter = (ModelConverter<?>)newConverter(finalModel);

		RExpEncoder encoder = new RExpEncoder();

		converter.encodeSchema(encoder);

		Schema schema = encoder.createSchema();

		if(schemaFunction != null){
			schema = schemaFunction.apply(schema);
		}

		Model model = converter.encode(schema);

		return new Conversion(encoder, schema, model);
	}

	/**
	 * Adapts the schema of a member model for use inside the model chain.
	 *
	 * @param schema The schema of a member model.
	 *
	 * @return The anonymous schema for a continuous label; the unchanged schema for a categorical label.
	 *
	 * @throws IllegalArgumentException If the label is neither continuous nor categorical.
	 */
	static
	private Schema toSegmentSchema(Schema schema){
		Label label = schema.getLabel();

		if(label instanceof ContinuousLabel){
			return schema.toAnonymousSchema();
		}

		// XXX: Ideally, the categorical target field should also be anonymized
		if(label instanceof CategoricalLabel){
			return schema;
		}

		throw new IllegalArgumentException("Expected a continuous or categorical label, got " + (label != null ? label.getClass().getName() : null));
	}

	/**
	 * Creates the single non-final output field that a member model exposes to the subsequent segments.
	 *
	 * @param name The name of the output field.
	 * @param segmentModel The encoded member model.
	 * @param segmentSchema The schema that the member model was encoded against.
	 *
	 * @return A continuous double predicted field for a regression model;
	 * a double probability field for the second category of a two-category label for a classification model.
	 *
	 * @throws IllegalArgumentException If the mining function is neither regression nor classification.
	 */
	static
	private OutputField createSegmentOutputField(String name, Model segmentModel, Schema segmentSchema){
		MiningFunction miningFunction = segmentModel.requireMiningFunction();

		switch(miningFunction){
			case REGRESSION:
				{
					return ModelUtil.createPredictedField(name, OpType.CONTINUOUS, DataType.DOUBLE)
						.setFinalResult(Boolean.FALSE);
				}
			case CLASSIFICATION:
				{
					CategoricalLabel categoricalLabel = (CategoricalLabel)segmentSchema.getLabel();

					// Only two-category labels are accepted; the probability of the second category is exposed
					SchemaUtil.checkSize(2, categoricalLabel);

					return ModelUtil.createProbabilityField(name, DataType.DOUBLE, categoricalLabel.getValue(1))
						.setFinalResult(Boolean.FALSE);
				}
			default:
				throw new IllegalArgumentException("Expected a regression or classification model, got mining function " + miningFunction);
		}
	}

	/**
	 * An immutable holder for the results of encoding one trained model.
	 */
	static
	private class Conversion {

		private final RExpEncoder encoder;

		private final Schema schema;

		private final Model model;


		private Conversion(RExpEncoder encoder, Schema schema, Model model){
			this.encoder = encoder;
			this.schema = schema;
			this.model = model;
		}

		public RExpEncoder getEncoder(){
			return this.encoder;
		}

		public Schema getSchema(){
			return this.schema;
		}

		public Model getModel(){
			return this.model;
		}
	}
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy