// org.jpmml.xgboost.FeatureMap (artifact: jpmml-xgboost)
// Java library and command-line application for converting XGBoost models to PMML
/*
* Copyright (c) 2016 Villu Ruusmann
*
* This file is part of JPMML-XGBoost
*
* JPMML-XGBoost is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* JPMML-XGBoost is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with JPMML-XGBoost. If not, see <http://www.gnu.org/licenses/>.
*/
package org.jpmml.xgboost;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.EnumMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import org.dmg.pmml.DataField;
import org.dmg.pmml.DataType;
import org.dmg.pmml.FieldName;
import org.dmg.pmml.OpType;
import org.dmg.pmml.Value;
import org.jpmml.converter.BinaryFeature;
import org.jpmml.converter.ContinuousFeature;
import org.jpmml.converter.Feature;
import org.jpmml.converter.PMMLEncoder;
import org.jpmml.converter.PMMLUtil;
public class FeatureMap {
private List entries = new ArrayList<>();
private Map> valueMap = new EnumMap<>(Value.Property.class);
public FeatureMap(){
}
public List encodeFeatures(PMMLEncoder encoder){
List result = new ArrayList<>();
Set dataFields = new LinkedHashSet<>();
List entries = getEntries();
for(Entry entry : entries){
FieldName name = FieldName.create(entry.getName());
String value = entry.getValue();
DataField dataField = encoder.getDataField(name);
if(dataField == null){
Entry.Type type = entry.getType();
switch(type){
case BINARY_INDICATOR:
dataField = encoder.createDataField(name, OpType.CATEGORICAL, DataType.STRING);
break;
case FLOAT:
dataField = encoder.createDataField(name, OpType.CONTINUOUS, DataType.FLOAT);
break;
case INTEGER:
dataField = encoder.createDataField(name, OpType.CONTINUOUS, DataType.INTEGER);
break;
default:
throw new IllegalArgumentException(String.valueOf(type));
}
}
if(value != null){
PMMLUtil.addValues(dataField, Collections.singletonList(value));
}
dataFields.add(dataField);
Feature feature;
OpType opType = dataField.getOpType();
switch(opType){
case CATEGORICAL:
feature = new BinaryFeature(encoder, dataField, value);
break;
case CONTINUOUS:
feature = new ContinuousFeature(encoder, dataField);
break;
default:
throw new IllegalArgumentException("Expected categorical or continuous operational type, got " + opType.value() + " operational type");
}
result.add(feature);
}
Collection>> valueEntries = this.valueMap.entrySet();
for(DataField dataField : dataFields){
for(Map.Entry> valueEntry : valueEntries){
PMMLUtil.addValues(dataField, valueEntry.getValue(), valueEntry.getKey());
}
}
return result;
}
public void addEntry(String name, String type){
String value = null;
if(("i").equals(type)){
int equals = name.indexOf('=');
if(equals < 0){
throw new IllegalArgumentException(name);
}
value = name.substring(equals + 1);
name = name.substring(0, equals);
}
Entry entry = new Entry(name, value, Entry.Type.fromString(type));
addEntry(entry);
}
public void addEntry(Entry entry){
List entries = getEntries();
entries.add(entry);
}
public List getEntries(){
return this.entries;
}
public void addValidValue(String value){
addValue(Value.Property.VALID, value);
}
public void addInvalidValue(String value){
addValue(Value.Property.INVALID, value);
}
public void addMissingValue(String value){
addValue(Value.Property.MISSING, value);
}
private void addValue(Value.Property property, String value){
if(value == null){
return;
}
List values = this.valueMap.get(property);
if(values == null){
values = new ArrayList<>();
this.valueMap.put(property, values);
}
values.add(value);
}
static
public class Entry {
private String name = null;
private String value = null;
private Type type = null;
public Entry(String name, String value, Type type){
setName(name);
setValue(value);
setType(type);
}
public String getName(){
return this.name;
}
private void setName(String name){
this.name = Objects.requireNonNull(name);
}
public String getValue(){
return this.value;
}
private void setValue(String value){
this.value = value;
}
public Type getType(){
return this.type;
}
private void setType(Type type){
this.type = Objects.requireNonNull(type);
}
static
public enum Type {
BINARY_INDICATOR,
FLOAT,
INTEGER,
;
static
public Type fromString(String string){
switch(string){
case "i":
return Type.BINARY_INDICATOR;
case "q":
return Type.FLOAT;
case "int":
return Type.INTEGER;
default:
throw new IllegalArgumentException(string);
}
}
}
}
}