All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.jpmml.xgboost.FeatureMap Maven / Gradle / Ivy

Go to download

Java library and command-line application for converting XGBoost models to PMML

There is a newer version: 1.8.7
Show newest version
/*
 * Copyright (c) 2016 Villu Ruusmann
 *
 * This file is part of JPMML-XGBoost
 *
 * JPMML-XGBoost is free software: you can redistribute it and/or modify
 * it under the terms of the GNU Affero General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * JPMML-XGBoost is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU Affero General Public License for more details.
 *
 * You should have received a copy of the GNU Affero General Public License
 * along with JPMML-XGBoost.  If not, see <http://www.gnu.org/licenses/>.
 */
package org.jpmml.xgboost;

import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.EnumMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;

import org.dmg.pmml.DataField;
import org.dmg.pmml.DataType;
import org.dmg.pmml.FieldName;
import org.dmg.pmml.OpType;
import org.dmg.pmml.Value;
import org.jpmml.converter.BinaryFeature;
import org.jpmml.converter.ContinuousFeature;
import org.jpmml.converter.Feature;
import org.jpmml.converter.PMMLEncoder;
import org.jpmml.converter.PMMLUtil;

public class FeatureMap {

	private List entries = new ArrayList<>();

	private Map> valueMap = new EnumMap<>(Value.Property.class);


	public FeatureMap(){
	}

	public List encodeFeatures(PMMLEncoder encoder){
		List result = new ArrayList<>();

		Set dataFields = new LinkedHashSet<>();

		List entries = getEntries();
		for(Entry entry : entries){
			FieldName name = FieldName.create(entry.getName());
			String value = entry.getValue();

			DataField dataField = encoder.getDataField(name);
			if(dataField == null){
				Entry.Type type = entry.getType();

				switch(type){
					case BINARY_INDICATOR:
						dataField = encoder.createDataField(name, OpType.CATEGORICAL, DataType.STRING);
						break;
					case FLOAT:
						dataField = encoder.createDataField(name, OpType.CONTINUOUS, DataType.FLOAT);
						break;
					case INTEGER:
						dataField = encoder.createDataField(name, OpType.CONTINUOUS, DataType.INTEGER);
						break;
					default:
						throw new IllegalArgumentException(String.valueOf(type));
				}
			}

			if(value != null){
				PMMLUtil.addValues(dataField, Collections.singletonList(value));
			}

			dataFields.add(dataField);

			Feature feature;

			OpType opType = dataField.getOpType();
			switch(opType){
				case CATEGORICAL:
					feature = new BinaryFeature(encoder, dataField, value);
					break;
				case CONTINUOUS:
					feature = new ContinuousFeature(encoder, dataField);
					break;
				default:
					throw new IllegalArgumentException("Expected categorical or continuous operational type, got " + opType.value() + " operational type");
			}

			result.add(feature);
		}

		Collection>> valueEntries = this.valueMap.entrySet();

		for(DataField dataField : dataFields){

			for(Map.Entry> valueEntry : valueEntries){
				PMMLUtil.addValues(dataField, valueEntry.getValue(), valueEntry.getKey());
			}
		}

		return result;
	}

	public void addEntry(String name, String type){
		String value = null;

		if(("i").equals(type)){
			int equals = name.indexOf('=');

			if(equals < 0){
				throw new IllegalArgumentException(name);
			}

			value = name.substring(equals + 1);
			name = name.substring(0, equals);
		}

		Entry entry = new Entry(name, value, Entry.Type.fromString(type));

		addEntry(entry);
	}

	public void addEntry(Entry entry){
		List entries = getEntries();

		entries.add(entry);
	}

	public List getEntries(){
		return this.entries;
	}

	public void addValidValue(String value){
		addValue(Value.Property.VALID, value);
	}

	public void addInvalidValue(String value){
		addValue(Value.Property.INVALID, value);
	}

	public void addMissingValue(String value){
		addValue(Value.Property.MISSING, value);
	}

	private void addValue(Value.Property property, String value){

		if(value == null){
			return;
		}

		List values = this.valueMap.get(property);
		if(values == null){
			values = new ArrayList<>();

			this.valueMap.put(property, values);
		}

		values.add(value);
	}

	static
	public class Entry {

		private String name = null;

		private String value = null;

		private Type type = null;


		public Entry(String name, String value, Type type){
			setName(name);
			setValue(value);
			setType(type);
		}

		public String getName(){
			return this.name;
		}

		private void setName(String name){
			this.name = Objects.requireNonNull(name);
		}

		public String getValue(){
			return this.value;
		}

		private void setValue(String value){
			this.value = value;
		}

		public Type getType(){
			return this.type;
		}

		private void setType(Type type){
			this.type = Objects.requireNonNull(type);
		}

		static
		public enum Type {
			BINARY_INDICATOR,
			FLOAT,
			INTEGER,
			;

			static
			public Type fromString(String string){

				switch(string){
					case "i":
						return Type.BINARY_INDICATOR;
					case "q":
						return Type.FLOAT;
					case "int":
						return Type.INTEGER;
					default:
						throw new IllegalArgumentException(string);
				}
			}
		}
	}
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy