All downloads are free. Search and download functionality uses the official Maven repository.

org.jpmml.converter.ModelEncoder Maven / Gradle / Ivy

There is a newer version: 1.5.10
Show newest version
/*
 * Copyright (c) 2017 Villu Ruusmann
 *
 * This file is part of JPMML-Converter
 *
 * JPMML-Converter is free software: you can redistribute it and/or modify
 * it under the terms of the GNU Affero General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * JPMML-Converter is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU Affero General Public License for more details.
 *
 * You should have received a copy of the GNU Affero General Public License
 * along with JPMML-Converter.  If not, see <http://www.gnu.org/licenses/>.
 */
package org.jpmml.converter;

import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

import org.dmg.pmml.DataDictionary;
import org.dmg.pmml.DataField;
import org.dmg.pmml.FieldName;
import org.dmg.pmml.MiningField;
import org.dmg.pmml.MiningSchema;
import org.dmg.pmml.Model;
import org.dmg.pmml.ModelStats;
import org.dmg.pmml.PMML;
import org.dmg.pmml.UnivariateStats;
import org.jpmml.model.VisitorBattery;

public class ModelEncoder extends PMMLEncoder {

	private Map> decorators = new LinkedHashMap<>();

	private Map univariateStats = new LinkedHashMap<>();


	public PMML encodePMML(Model model){
		PMML pmml = encodePMML();

		pmml.addModels(model);

		VisitorBattery visitorBattery = new CleanerBattery();
		if(visitorBattery.size() > 0){
			visitorBattery.applyTo(pmml);
		}

		MiningSchema miningSchema = model.getMiningSchema();

		List miningFields = miningSchema.getMiningFields();
		for(MiningField miningField : miningFields){
			FieldName name = miningField.getName();

			List decorators = getDecorators(name);
			if(decorators == null){
				continue;
			}

			DataField dataField = getDataField(name);
			if(dataField == null){
				throw new IllegalArgumentException();
			}

			for(Decorator decorator : decorators){
				decorator.decorate(dataField, miningField);
			}
		}

		DataDictionary dataDictionary = pmml.getDataDictionary();

		List dataFields = dataDictionary.getDataFields();
		for(DataField dataField : dataFields){
			UnivariateStats univariateStats = getUnivariateStats(dataField.getName());

			if(univariateStats == null){
				continue;
			}

			ModelStats modelStats = model.getModelStats();
			if(modelStats == null){
				modelStats = new ModelStats();

				model.setModelStats(modelStats);
			}

			modelStats.addUnivariateStats(univariateStats);
		}

		return pmml;
	}

	public List getDecorators(FieldName name){
		return this.decorators.get(name);
	}

	public void addDecorator(FieldName name, Decorator decorator){
		List decorators = this.decorators.get(name);

		if(decorators == null){
			decorators = new ArrayList<>();

			this.decorators.put(name, decorators);
		}

		decorators.add(decorator);
	}

	public UnivariateStats getUnivariateStats(FieldName name){
		return this.univariateStats.get(name);
	}

	public void putUnivariateStats(UnivariateStats univariateStats){
		putUnivariateStats(univariateStats.getField(), univariateStats);
	}

	public void putUnivariateStats(FieldName name, UnivariateStats univariateStats){
		this.univariateStats.put(name, univariateStats);
	}
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy