All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.jpmml.rexp.MixtureModelConverter Maven / Gradle / Ivy

There is a newer version: 1.6.6
Show newest version
/*
 * Copyright (c) 2024 Villu Ruusmann
 *
 * This file is part of JPMML-R
 *
 * JPMML-R is free software: you can redistribute it and/or modify
 * it under the terms of the GNU Affero General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * JPMML-R is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU Affero General Public License for more details.
 *
 * You should have received a copy of the GNU Affero General Public License
 * along with JPMML-R.  If not, see <http://www.gnu.org/licenses/>.
 */
package org.jpmml.rexp;

import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Map;

import org.dmg.pmml.DataField;
import org.dmg.pmml.DataType;
import org.dmg.pmml.DerivedField;
import org.dmg.pmml.FieldRef;
import org.dmg.pmml.Model;
import org.dmg.pmml.Output;
import org.dmg.pmml.OutputField;
import org.dmg.pmml.ResultFeature;
import org.dmg.pmml.regression.RegressionModel;
import org.jpmml.converter.ContinuousFeature;
import org.jpmml.converter.ContinuousLabel;
import org.jpmml.converter.Feature;
import org.jpmml.converter.Schema;
import org.jpmml.converter.SchemaUtil;
import org.jpmml.converter.regression.RegressionModelUtil;

/**
 * Base class for converters that encode an R model object as two separately
 * encoded components, named {@link #NAME_ZERO "zero"} and {@link #NAME_COUNT "count"}.
 *
 * <p>
 * The shared logic in {@link #encodeComponent(String, RExpEncoder)} reads the
 * <code>coefficients</code>, <code>terms</code> and <code>model</code> elements of the
 * underlying R object, prepares the features and coefficients of the requested component,
 * and delegates the actual model construction to the subclass.
 * </p>
 *
 * <p>
 * NOTE(review): the element layout matches what R packages such as <code>pscl</code>
 * produce for zero-inflated and hurdle models - confirm against the concrete subclasses.
 * </p>
 */
abstract
public class MixtureModelConverter extends Converter<RGenericVector> {

	/**
	 * The target label captured while encoding the "count" component.
	 * Stays <code>null</code> until {@link #encodeComponent(String, RExpEncoder)}
	 * has been invoked with {@link #NAME_COUNT}.
	 */
	private ContinuousLabel label = null;


	/**
	 * @param object The R model object to be converted.
	 */
	public MixtureModelConverter(RGenericVector object){
		super(object);
	}

	/**
	 * Encodes the "zero" component.
	 *
	 * @param features The features of the component, excluding the intercept.
	 * @param coefficients The coefficients, one per feature, in feature order.
	 * @param intercept The <code>(Intercept)</code> coefficient; <code>null</code> if the component has none.
	 * @param schema The schema created by the encoder for this component.
	 *
	 * @return The model that represents the "zero" component.
	 */
	abstract
	protected Model encodeZeroComponent(List<? extends Feature> features, List<Double> coefficients, Double intercept, Schema schema);

	/**
	 * Encodes the "count" component.
	 *
	 * @param features The features of the component, excluding the intercept.
	 * @param coefficients The coefficients, one per feature, in feature order.
	 * @param intercept The <code>(Intercept)</code> coefficient; <code>null</code> if the component has none.
	 * @param schema The schema created by the encoder for this component.
	 *
	 * @return The model that represents the "count" component.
	 */
	abstract
	protected Model encodeCountComponent(List<? extends Feature> features, List<Double> coefficients, Double intercept, Schema schema);

	/**
	 * Prepares the features and coefficients of the named component, and delegates
	 * to {@link #encodeZeroComponent(List, List, Double, Schema)} or
	 * {@link #encodeCountComponent(List, List, Double, Schema)}.
	 *
	 * <p>
	 * Side effects: the label of the encoder is replaced with an anonymous double label,
	 * and the feature list of the encoder is cleared and re-populated.
	 * When encoding the "count" component, the data type of the target data field is
	 * changed to double, and the resulting label is stored for {@link #getLabel()}.
	 * </p>
	 *
	 * @param name The name of the component; either {@link #NAME_ZERO} or {@link #NAME_COUNT}.
	 * @param encoder The encoder that holds the fields, the label and the features.
	 *
	 * @return The model returned by the component-specific encoder method.
	 *
	 * @throws IllegalArgumentException If the name is neither {@link #NAME_ZERO} nor {@link #NAME_COUNT}.
	 */
	protected Model encodeComponent(String name, RExpEncoder encoder){
		RGenericVector object = getObject();

		// Coefficients and terms are looked up per component; the model frame is shared
		RDoubleVector coefficients = object.getGenericElement("coefficients").getDoubleElement(name);
		RExp terms = object.getGenericElement("terms").getElement(name);
		RGenericVector model = object.getGenericElement("model");

		RStringVector coefficientNames = coefficients.names();

		FormulaContext context = new ModelFrameFormulaContext(model);

		Formula formula = FormulaUtil.createFormula(terms, context, encoder);

		switch(name){
			case MixtureModelConverter.NAME_COUNT:
				FormulaUtil.setLabel(formula, terms, null, encoder);

				ContinuousLabel continuousLabel = (ContinuousLabel)encoder.getLabel();

				// XXX
				// Force the target data field to the double data type, and remember the label for encodeTarget
				DataField dataField = (DataField)encoder.getField(continuousLabel.getName());
				dataField.setDataType(DataType.DOUBLE);

				setLabel(new ContinuousLabel(dataField));

				break;
			case MixtureModelConverter.NAME_ZERO:
				break;
			default:
				throw new IllegalArgumentException(name);
		}

		// The component itself is encoded against an anonymous double label
		encoder.setLabel(new ContinuousLabel(DataType.DOUBLE));

		// Discard the features that are left over from an earlier component
		List<? extends Feature> features = encoder.getFeatures();
		if(!features.isEmpty()){
			features.clear();
		}

		List<String> names = FormulaUtil.removeSpecialSymbol(coefficientNames.getDequotedValues(), "(Intercept)");

		FormulaUtil.addFeatures(formula, names, true, encoder);

		features = encoder.getFeatures();

		Schema schema = encoder.createSchema();

		// Optional lookup (second argument is false); null-checked below
		Double intercept = coefficients.getElement("(Intercept)", false);

		// Every non-intercept coefficient must be matched by exactly one feature
		SchemaUtil.checkSize(coefficients.size() - (intercept != null ? 1 : 0), features);

		List<Double> featureCoefficients = new ArrayList<>();

		for(Feature feature : features){
			Double coefficient = formula.getCoefficient(feature, coefficients);

			featureCoefficients.add(coefficient);
		}

		switch(name){
			case MixtureModelConverter.NAME_ZERO:
				return encodeZeroComponent(features, featureCoefficients, intercept, schema);
			case MixtureModelConverter.NAME_COUNT:
				return encodeCountComponent(features, featureCoefficients, intercept, schema);
			default:
				throw new IllegalArgumentException(name);
		}
	}

	/**
	 * Creates a pass-through regression model that exposes a derived field as the target value.
	 *
	 * <p>
	 * The regression model has the derived field as its only feature, with a coefficient of one,
	 * no intercept and no normalization. Its label is the one captured by
	 * {@link #encodeComponent(String, RExpEncoder)} for the "count" component.
	 * NOTE(review): this label is <code>null</code> if the "count" component has not been encoded yet -
	 * callers are presumably expected to encode the components first.
	 * </p>
	 *
	 * @param derivedField The derived field that holds the final prediction.
	 * @param outputFields Existing output fields, keyed by the name under which each of them is re-exported.
	 * Every entry yields a new transformed-value output field that references the existing output field.
	 * @param encoder The encoder that holds the fields.
	 *
	 * @return The regression model, complete with the re-exported output fields.
	 */
	protected Model encodeTarget(DerivedField derivedField, Map<String, OutputField> outputFields, RExpEncoder encoder){
		ContinuousLabel label = getLabel();

		Feature feature = new ContinuousFeature(encoder, derivedField);

		Schema targetSchema = new Schema(encoder, label, Collections.emptyList());

		Output output = new Output();

		Collection<Map.Entry<String, OutputField>> entries = outputFields.entrySet();
		for(Map.Entry<String, OutputField> entry : entries){
			String name = entry.getKey();
			OutputField outputField = entry.getValue();

			// Same operational type and data type as the referenced output field, but under a new name
			OutputField targetOutputField = new OutputField(name, outputField.requireOpType(), outputField.requireDataType())
				.setResultFeature(ResultFeature.TRANSFORMED_VALUE)
				.setExpression(new FieldRef(outputField));

			output.addOutputFields(targetOutputField);
		}

		RegressionModel regressionModel = RegressionModelUtil.createRegression(Collections.singletonList(feature), Collections.singletonList(1d), null, RegressionModel.NormalizationMethod.NONE, targetSchema)
			.setOutput(output);

		return regressionModel;
	}

	/**
	 * @return The label captured while encoding the "count" component;
	 * <code>null</code> if that component has not been encoded yet.
	 */
	protected ContinuousLabel getLabel(){
		return this.label;
	}

	private void setLabel(ContinuousLabel label){
		this.label = label;
	}

	/** The name of the "count" component. */
	protected static final String NAME_COUNT = "count";
	/** Not referenced within this class. */
	protected static final String NAME_FULL = "full";
	/** The name of the "zero" component. */
	protected static final String NAME_ZERO = "zero";
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy