All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.jpmml.evaluator.DiscretizationUtil Maven / Gradle / Ivy

There is a newer version: 1.6.6
Show newest version
/*
 * Copyright (c) 2013 Villu Ruusmann
 *
 * This file is part of JPMML-Evaluator
 *
 * JPMML-Evaluator is free software: you can redistribute it and/or modify
 * it under the terms of the GNU Affero General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * JPMML-Evaluator is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU Affero General Public License for more details.
 *
 * You should have received a copy of the GNU Affero General Public License
 * along with JPMML-Evaluator.  If not, see <http://www.gnu.org/licenses/>.
 */
package org.jpmml.evaluator;

import java.util.Collection;
import java.util.EnumMap;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;

import com.google.common.cache.CacheLoader;
import com.google.common.cache.LoadingCache;
import com.google.common.collect.HashMultimap;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableRangeMap;
import com.google.common.collect.ImmutableSetMultimap;
import com.google.common.collect.Iterables;
import com.google.common.collect.Range;
import com.google.common.collect.RangeMap;
import com.google.common.collect.SetMultimap;
import com.google.common.collect.Table;
import com.google.common.collect.TreeRangeMap;
import org.dmg.pmml.DataType;
import org.dmg.pmml.Discretize;
import org.dmg.pmml.DiscretizeBin;
import org.dmg.pmml.InlineTable;
import org.dmg.pmml.Interval;
import org.dmg.pmml.MapValues;
import org.dmg.pmml.OpType;
import org.jpmml.model.InvalidElementException;
import org.jpmml.model.UnsupportedAttributeException;

public class DiscretizationUtil {

	private DiscretizationUtil(){
	}

	static
	public FieldValue discretize(Discretize discretize, FieldValue value){
		Object result = discretize(discretize, value.asDouble());

		return FieldValueUtil.create(OpType.CATEGORICAL, discretize.getDataType(DataType.STRING), result);
	}

	static
	public Object discretize(Discretize discretize, Double value){
		RangeMap binRanges = CacheUtil.getValue(discretize, DiscretizationUtil.binRangeCache);

		Map.Entry, Object> entry = binRanges.getEntry(value);
		if(entry != null){
			return entry.getValue();
		}

		return discretize.getDefaultValue();
	}

	static
	public FieldValue mapValue(MapValues mapValues, Map values){
		String outputColumn = mapValues.requireOutputColumn();
		DataType dataType = mapValues.getDataType(DataType.STRING);

		InlineTable inlineTable = InlineTableUtil.getInlineTable(mapValues);
		if(inlineTable != null){
			Map row = match(inlineTable, values);

			if(row != null){
				Object result = row.get(outputColumn);

				if(result == null){
					throw new InvalidElementException(inlineTable);
				}

				return FieldValueUtil.create(OpType.CATEGORICAL, dataType, result);
			}
		}

		return FieldValueUtil.create(OpType.CATEGORICAL, dataType, mapValues.getDefaultValue());
	}

	static
	public Range toRange(Interval interval){
		Double leftMargin = NumberUtil.asDouble(interval.getLeftMargin());
		Double rightMargin = NumberUtil.asDouble(interval.getRightMargin());

		// "The leftMargin and rightMargin attributes are optional, but at least one value must be defined"
		if(leftMargin == null && rightMargin == null){
			throw new InvalidElementException(interval);
		} // End if

		if(leftMargin != null && rightMargin != null && NumberUtil.compare(leftMargin, rightMargin) > 0){
			throw new InvalidElementException(interval);
		}

		Interval.Closure closure = interval.requireClosure();
		switch(closure){
			case OPEN_OPEN:
				{
					if(leftMargin == null){
						return Range.lessThan(rightMargin);
					} else

					if(rightMargin == null){
						return Range.greaterThan(leftMargin);
					}

					return Range.open(leftMargin, rightMargin);
				}
			case OPEN_CLOSED:
				{
					if(leftMargin == null){
						return Range.atMost(rightMargin);
					} else

					if(rightMargin == null){
						return Range.greaterThan(leftMargin);
					}

					return Range.openClosed(leftMargin, rightMargin);
				}
			case CLOSED_OPEN:
				{
					if(leftMargin == null){
						return Range.lessThan(rightMargin);
					} else

					if(rightMargin == null){
						return Range.atLeast(leftMargin);
					}

					return Range.closedOpen(leftMargin, rightMargin);
				}
			case CLOSED_CLOSED:
				{
					if(leftMargin == null){
						return Range.atMost(rightMargin);
					} else

					if(rightMargin == null){
						return Range.atLeast(leftMargin);
					}

					return Range.closed(leftMargin, rightMargin);
				}
			default:
				throw new UnsupportedAttributeException(interval, closure);
		}
	}

	static
	private Map match(InlineTable inlineTable, Map values){
		Map rowFilters = CacheUtil.getValue(inlineTable, DiscretizationUtil.rowFilterCache);

		Set rows = null;

		Collection> entries = values.entrySet();
		for(Map.Entry entry : entries){
			String key = entry.getKey();
			FieldValue value = entry.getValue();

			RowFilter rowFilter = rowFilters.get(key);
			if(rowFilter == null){
				throw new InvalidElementException(inlineTable);
			}

			SetMultimap valueRowsMap = rowFilter.getValueRowsMap(value.getDataType());

			Set valueRows = valueRowsMap.get(FieldValueUtil.getValue(value));

			if(valueRows != null && !valueRows.isEmpty()){

				if(rows == null){
					rows = (entries.size() > 1 ? new HashSet<>(valueRows) : valueRows);
				} else

				{
					rows.retainAll(valueRows);
				} // End if

				if(rows.isEmpty()){
					return null;
				}
			} else

			{
				return null;
			}
		}

		if(rows != null && !rows.isEmpty()){
			Table content = InlineTableUtil.getContent(inlineTable);

			// "It is an error if the table entries used for matching are not unique"
			if(rows.size() != 1){
				throw new InvalidElementException(inlineTable);
			}

			Integer row = Iterables.getOnlyElement(rows);

			return content.row(row);
		}

		return null;
	}

	static
	private RangeMap parseDiscretize(Discretize discretize){
		RangeMap result = TreeRangeMap.create();

		List discretizeBins = discretize.getDiscretizeBins();
		for(DiscretizeBin discretizeBin : discretizeBins){
			Interval interval = discretizeBin.requireInterval();

			Range range = toRange(interval);
			Object binValue = discretizeBin.requireBinValue();

			result.put(range, binValue);
		}

		return result;
	}

	static
	private Map parseInlineTable(InlineTable inlineTable){
		Map result = new LinkedHashMap<>();

		Table table = InlineTableUtil.getContent(inlineTable);

		Set columns = table.columnKeySet();
		for(String column : columns){
			Map columnValues = table.column(column);

			RowFilter rowFilter = new RowFilter(columnValues);

			result.put(column, rowFilter);
		}

		return result;
	}

	static
	private class RowFilter {

		private Map columnValues = null;

		private Map> valueRowsMap = new EnumMap<>(DataType.class);


		private RowFilter(Map columnValues){
			setColumnValues(columnValues);
		}

		public SetMultimap getValueRowsMap(DataType dataType){
			SetMultimap result = this.valueRowsMap.get(dataType);

			if(result == null){
				result = ImmutableSetMultimap.copyOf(parseColumnValues(dataType));

				this.valueRowsMap.put(dataType, result);
			}

			return result;
		}

		private SetMultimap parseColumnValues(DataType dataType){
			Map columnValues = getColumnValues();

			SetMultimap result = HashMultimap.create();

			Collection> entries = columnValues.entrySet();
			for(Map.Entry entry : entries){
				Object value = TypeUtil.parseOrCast(dataType, entry.getValue());
				Integer row = entry.getKey();

				result.put(value, row);
			}

			return result;
		}

		public Map getColumnValues(){
			return this.columnValues;
		}

		private void setColumnValues(Map columnValues){
			this.columnValues = columnValues;
		}
	}

	private static final LoadingCache> binRangeCache = CacheUtil.buildLoadingCache(new CacheLoader>(){

		@Override
		public RangeMap load(Discretize discretize){
			return ImmutableRangeMap.copyOf(parseDiscretize(discretize));
		}
	});

	private static final LoadingCache> rowFilterCache = CacheUtil.buildLoadingCache(new CacheLoader>(){

		@Override
		public Map load(InlineTable inlineTable){
			return ImmutableMap.copyOf(parseInlineTable(inlineTable));
		}
	});
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy