org.jpmml.converter.ModelEncoder Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of pmml-converter Show documentation
Show all versions of pmml-converter Show documentation
JPMML class model converters
The newest version!
/*
* Copyright (c) 2017 Villu Ruusmann
*
* This file is part of JPMML-Converter
*
* JPMML-Converter is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* JPMML-Converter is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with JPMML-Converter.  If not, see <http://www.gnu.org/licenses/>.
*/
package org.jpmml.converter;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Comparator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.function.Function;
import java.util.stream.Collectors;
import com.google.common.collect.ArrayListMultimap;
import com.google.common.collect.ListMultimap;
import org.dmg.pmml.Field;
import org.dmg.pmml.InlineTable;
import org.dmg.pmml.MathContext;
import org.dmg.pmml.MiningField;
import org.dmg.pmml.MiningSchema;
import org.dmg.pmml.Model;
import org.dmg.pmml.ModelStats;
import org.dmg.pmml.NamespacePrefixes;
import org.dmg.pmml.PMML;
import org.dmg.pmml.UnivariateStats;
import org.dmg.pmml.mining.MiningModel;
import org.dmg.pmml.mining.Segmentation;
import org.jpmml.converter.mining.MiningModelUtil;
import org.jpmml.converter.visitors.FeatureExpander;
import org.jpmml.converter.visitors.ModelCleanerBattery;
import org.jpmml.converter.visitors.PMMLCleanerBattery;
import org.jpmml.model.visitors.VisitorBattery;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
public class ModelEncoder extends PMMLEncoder {
private List transformers = new ArrayList<>();
private Map> decorators = new LinkedHashMap<>();
private Map> featureImportances = new LinkedHashMap<>();
private Map> univariateStats = new LinkedHashMap<>();
public PMML encodePMML(Model model){
PMML pmml = encodePMML();
model = encodeModel(model);
if(model != null){
pmml.addModels(model);
VisitorBattery modelCleanerBattery = new ModelCleanerBattery();
modelCleanerBattery.applyTo(pmml);
encodeDecorators(pmml);
encodeFeatureImportances(pmml);
encodeUnivariateStats(pmml);
}
VisitorBattery pmmlCleanerBattery = new PMMLCleanerBattery();
pmmlCleanerBattery.applyTo(pmml);
return pmml;
}
public Model encodeModel(Model model){
List transformers = getTransformers();
if(model != null){
transferContent(null, model);
} // End if
if(!transformers.isEmpty()){
List models = new ArrayList<>(transformers);
if(model != null){
models.add(model);
}
MiningModel miningModel = MiningModelUtil.createModelChain(models, Segmentation.MissingPredictionTreatment.CONTINUE);
transferUnivariateStats(model, miningModel);
return miningModel;
}
return model;
}
public List getTransformers(){
return this.transformers;
}
public void addTransformer(Model transformer){
this.transformers.add(transformer);
}
public Map> getDecorators(){
return this.decorators;
}
public void addDecorator(Field> field, Decorator decorator){
addDecorator(null, field, decorator);
}
public void addDecorator(Model model, Field> field, Decorator decorator){
Map> modelDecorators = getDecorators();
ListMultimap decorators = modelDecorators.get(model);
if(decorators == null){
decorators = ArrayListMultimap.create();
modelDecorators.put(model, decorators);
}
decorators.put(field.requireName(), decorator);
}
public Map> getFeatureImportances(){
return this.featureImportances;
}
public void addFeatureImportance(Feature feature, Number importance){
addFeatureImportance(null, feature, importance);
}
public void addFeatureImportance(Model model, Feature feature, Number importance){
Map> modelFeatureImportances = getFeatureImportances();
ListMultimap featureImportances = modelFeatureImportances.get(model);
if(featureImportances == null){
featureImportances = ArrayListMultimap.create();
modelFeatureImportances.put(model, featureImportances);
}
featureImportances.put(feature, importance);
}
public Map> getUnivariateStats(){
return this.univariateStats;
}
public void addUnivariateStats(UnivariateStats pmmlUnivariateStats){
addUnivariateStats(null, pmmlUnivariateStats);
}
public void addUnivariateStats(Model model, UnivariateStats pmmlUnivariateStats){
Map> modelUnivariateStats = getUnivariateStats();
List univariateStats = modelUnivariateStats.get(model);
if(univariateStats == null){
univariateStats = new ArrayList<>();
modelUnivariateStats.put(model, univariateStats);
}
univariateStats.add(pmmlUnivariateStats);
}
public void transferContent(Model left, Model right){
transferDecorators(left, right);
transferFeatureImportances(left, right);
transferUnivariateStats(left, right);
}
public void transferDecorators(Model left, Model right){
transferValue(this.decorators, left, right);
}
public void transferFeatureImportances(Model left, Model right){
transferValue(this.featureImportances, left, right);
}
public void transferUnivariateStats(Model left, Model right){
transferValue(this.univariateStats, left, right);
}
private void encodeDecorators(PMML pmml){
Map> modelDecorators = getDecorators();
if(modelDecorators.isEmpty()){
return;
} // End if
if(modelDecorators.containsKey(null)){
throw new IllegalStateException();
}
Collection>> entries = modelDecorators.entrySet();
for(Map.Entry> entry : entries){
Model model = entry.getKey();
ListMultimap decorators = entry.getValue();
MiningSchema miningSchema = model.requireMiningSchema();
if(miningSchema.hasMiningFields()){
List miningFields = miningSchema.getMiningFields();
for(MiningField miningField : miningFields){
String fieldName = miningField.getName();
List fieldDecorators = decorators.get(fieldName);
if(fieldDecorators != null && !fieldDecorators.isEmpty()){
for(Decorator fieldDecorator : fieldDecorators){
fieldDecorator.decorate(miningField);
}
}
}
}
}
}
private void encodeFeatureImportances(PMML pmml){
Map> modelFeatureImportances = getFeatureImportances();
if(modelFeatureImportances.isEmpty()){
return;
} // End if
if(modelFeatureImportances.containsKey(null)){
throw new IllegalStateException();
}
Map> expandableFeatures = (modelFeatureImportances.entrySet()).stream()
.collect(Collectors.toMap(entry -> entry.getKey(), entry -> entry.getValue().keySet().stream()
.map(feature -> feature.getName())
.collect(Collectors.toSet())
));
FeatureExpander featureExpander = new FeatureExpander(expandableFeatures);
featureExpander.applyTo(pmml);
Collection extends Map.Entry>> entries = modelFeatureImportances.entrySet();
for(Map.Entry> entry : entries){
Model model = entry.getKey();
ListMultimap featureImportances = entry.getValue();
MathContext mathContext = model.getMathContext();
Collection> featureImportanceEntries = featureImportances.entries();
Map>> featureFields = featureExpander.getExpandedFeatures(model);
if(featureFields == null){
throw new IllegalArgumentException();
}
ListMultimap fieldImportances = ArrayListMultimap.create();
for(Map.Entry featureImportanceEntry : featureImportanceEntries){
String name = (featureImportanceEntry.getKey()).getName();
Number importance = featureImportanceEntry.getValue();
if(ValueUtil.isZero(importance)){
continue;
}
Set> fields = featureFields.get(name);
if(fields == null){
logger.warn("Unused feature \'" + name + "\' has non-zero importance");
continue;
}
Number fieldImportance = ValueUtil.divide(mathContext, importance, fields.size());
for(Field> field : fields){
fieldImportances.put(field.requireName(), fieldImportance);
}
}
MiningSchema miningSchema = model.requireMiningSchema();
if(miningSchema.hasMiningFields()){
List miningFields = miningSchema.getMiningFields();
for(MiningField miningField : miningFields){
String fieldName = miningField.getName();
MiningField.UsageType usageType = miningField.getUsageType();
switch(usageType){
case ACTIVE:
break;
default:
continue;
}
List fieldImportance = fieldImportances.get(fieldName);
if(fieldImportance != null && !fieldImportance.isEmpty()){
miningField.setImportance(ValueUtil.sum(mathContext, fieldImportance));
}
}
List names = new ArrayList<>();
List importances = new ArrayList<>();
for(Map.Entry featureImportanceEntry : featureImportanceEntries){
names.add(FeatureUtil.getName(featureImportanceEntry.getKey()));
importances.add(featureImportanceEntry.getValue());
}
Map> nativeFeatureImportances = new LinkedHashMap<>();
nativeFeatureImportances.put(NamespacePrefixes.JPMML_INLINETABLE + ":name", names);
nativeFeatureImportances.put(NamespacePrefixes.JPMML_INLINETABLE + ":importance", importances);
List nonZeroImportances = importances.stream()
.filter(importance -> !ValueUtil.isZero(importance))
.collect(Collectors.toList());
InlineTable inlineTable = PMMLUtil.createInlineTable(nativeFeatureImportances)
.addExtensions(PMMLUtil.createExtension("numberOfImportances", String.valueOf(importances.size())))
.addExtensions(PMMLUtil.createExtension("numberOfNonZeroImportances", String.valueOf(nonZeroImportances.size())))
.addExtensions(PMMLUtil.createExtension("sumOfImportances", String.valueOf(ValueUtil.sum(mathContext, importances))));
if(!nonZeroImportances.isEmpty()){
Comparator comparator = new Comparator(){
@Override
public int compare(Number left, Number right){
return Double.compare(left.doubleValue(), right.doubleValue());
}
};
inlineTable
.addExtensions(PMMLUtil.createExtension("minImportance", String.valueOf(Collections.min(nonZeroImportances, comparator))))
.addExtensions(PMMLUtil.createExtension("maxImportance", String.valueOf(Collections.max(nonZeroImportances, comparator))));
}
miningSchema.addExtensions(PMMLUtil.createExtension(Extensions.FEATURE_IMPORTANCES, inlineTable));
}
}
}
private void encodeUnivariateStats(PMML pmml){
Map> modelUnivariateStats = getUnivariateStats();
if(modelUnivariateStats.isEmpty()){
return;
} // End if
if(modelUnivariateStats.containsKey(null)){
throw new IllegalStateException();
}
Collection>> entries = modelUnivariateStats.entrySet();
for(Map.Entry> entry : entries){
Model model = entry.getKey();
List univariateStats = entry.getValue();
Map fieldUnivariateStats = univariateStats.stream()
// XXX: UnivariateStats::requireField
.collect(Collectors.toMap(UnivariateStats::getField, Function.identity()));
MiningSchema miningSchema = model.requireMiningSchema();
if(miningSchema.hasMiningFields()){
List miningFields = miningSchema.getMiningFields();
for(MiningField miningField : miningFields){
String fieldName = miningField.getName();
UnivariateStats pmmlUnivariateStats = fieldUnivariateStats.get(fieldName);
if(pmmlUnivariateStats != null){
ModelStats modelStats = ModelUtil.ensureModelStats(model);
modelStats.addUnivariateStats(pmmlUnivariateStats);
}
}
}
}
}
static
private void transferValue(Map map, K left, K right){
V value = map.remove(left);
if(value != null){
map.put(right, value);
}
}
private static final Logger logger = LoggerFactory.getLogger(ModelEncoder.class);
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy