All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.jpmml.evaluator.DiscretizationUtil Maven / Gradle / Ivy

/*
 * Copyright (c) 2013 Villu Ruusmann
 *
 * This file is part of JPMML-Evaluator
 *
 * JPMML-Evaluator is free software: you can redistribute it and/or modify
 * it under the terms of the GNU Affero General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * JPMML-Evaluator is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU Affero General Public License for more details.
 *
 * You should have received a copy of the GNU Affero General Public License
 * along with JPMML-Evaluator.  If not, see <http://www.gnu.org/licenses/>.
 */
package org.jpmml.evaluator;

import java.util.Collection;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;

import com.google.common.cache.CacheLoader;
import com.google.common.cache.LoadingCache;
import com.google.common.collect.HashMultimap;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableRangeMap;
import com.google.common.collect.ImmutableSetMultimap;
import com.google.common.collect.Iterables;
import com.google.common.collect.Range;
import com.google.common.collect.RangeMap;
import com.google.common.collect.SetMultimap;
import com.google.common.collect.Table;
import com.google.common.collect.TreeRangeMap;
import org.dmg.pmml.DataType;
import org.dmg.pmml.Discretize;
import org.dmg.pmml.DiscretizeBin;
import org.dmg.pmml.InlineTable;
import org.dmg.pmml.Interval;
import org.dmg.pmml.MapValues;
import org.dmg.pmml.OpType;

public class DiscretizationUtil {

	private DiscretizationUtil(){
	}

	static
	public FieldValue discretize(Discretize discretize, FieldValue value){
		String result = discretize(discretize, value.asDouble());

		return FieldValueUtil.create(ExpressionUtil.getDataType(discretize, DataType.STRING), OpType.CATEGORICAL, result);
	}

	static
	public String discretize(Discretize discretize, Double value){
		RangeMap binRanges = CacheUtil.getValue(discretize, DiscretizationUtil.binRangeCache);

		Map.Entry, String> entry = binRanges.getEntry(value);
		if(entry != null){
			return entry.getValue();
		}

		return discretize.getDefaultValue();
	}

	static
	public FieldValue mapValue(MapValues mapValues, Map values){
		String outputColumn = mapValues.getOutputColumn();
		if(outputColumn == null){
			throw new MissingAttributeException(mapValues, PMMLAttributes.MAPVALUES_OUTPUTCOLUMN);
		}

		DataType dataType = ExpressionUtil.getDataType(mapValues, DataType.STRING);

		InlineTable inlineTable = InlineTableUtil.getInlineTable(mapValues);
		if(inlineTable != null){
			Map row = match(inlineTable, values);

			if(row != null){
				String result = row.get(outputColumn);

				if(result == null){
					throw new InvalidElementException(inlineTable);
				}

				return FieldValueUtil.create(dataType, OpType.CATEGORICAL, result);
			}
		}

		return FieldValueUtil.create(dataType, OpType.CATEGORICAL, mapValues.getDefaultValue());
	}

	static
	public Range toRange(Interval interval){
		Double leftMargin = interval.getLeftMargin();
		Double rightMargin = interval.getRightMargin();

		// "The leftMargin and rightMargin attributes are optional, but at least one value must be defined"
		if(leftMargin == null && rightMargin == null){
			throw new MissingAttributeException(interval, PMMLAttributes.INTERVAL_LEFTMARGIN);
		} // End if

		if(leftMargin != null && rightMargin != null && (leftMargin).compareTo(rightMargin) > 0){
			throw new InvalidElementException(interval);
		}

		Interval.Closure closure = interval.getClosure();
		if(closure == null){
			throw new MissingAttributeException(interval, PMMLAttributes.INTERVAL_CLOSURE);
		}

		switch(closure){
			case OPEN_OPEN:
				{
					if(leftMargin == null){
						return Range.lessThan(rightMargin);
					} else

					if(rightMargin == null){
						return Range.greaterThan(leftMargin);
					}

					return Range.open(leftMargin, rightMargin);
				}
			case OPEN_CLOSED:
				{
					if(leftMargin == null){
						return Range.atMost(rightMargin);
					} else

					if(rightMargin == null){
						return Range.greaterThan(leftMargin);
					}

					return Range.openClosed(leftMargin, rightMargin);
				}
			case CLOSED_OPEN:
				{
					if(leftMargin == null){
						return Range.lessThan(rightMargin);
					} else

					if(rightMargin == null){
						return Range.atLeast(leftMargin);
					}

					return Range.closedOpen(leftMargin, rightMargin);
				}
			case CLOSED_CLOSED:
				{
					if(leftMargin == null){
						return Range.atMost(rightMargin);
					} else

					if(rightMargin == null){
						return Range.atLeast(leftMargin);
					}

					return Range.closed(leftMargin, rightMargin);
				}
			default:
				throw new UnsupportedAttributeException(interval, closure);
		}
	}

	static
	private Map match(InlineTable inlineTable, Map values){
		Map rowFilters = CacheUtil.getValue(inlineTable, DiscretizationUtil.rowFilterCache);

		Set rows = null;

		Collection> entries = values.entrySet();
		for(Map.Entry entry : entries){
			String key = entry.getKey();
			FieldValue value = entry.getValue();

			RowFilter rowFilter = rowFilters.get(key);
			if(rowFilter == null){
				throw new InvalidElementException(inlineTable);
			}

			Map> columnRowMap = rowFilter.getValueMapping(value.getDataType(), value.getOpType());

			Set columnRows = columnRowMap.get(value);

			if(columnRows != null && columnRows.size() > 0){

				if(rows == null){
					rows = (entries.size() > 1 ? new HashSet<>(columnRows) : columnRows);
				} else

				{
					rows.retainAll(columnRows);
				} // End if

				if(rows.isEmpty()){
					return null;
				}
			} else

			{
				return null;
			}
		}

		if(rows != null && rows.size() > 0){
			Table content = InlineTableUtil.getContent(inlineTable);

			// "It is an error if the table entries used for matching are not unique"
			if(rows.size() != 1){
				throw new InvalidElementException(inlineTable);
			}

			Integer row = Iterables.getOnlyElement(rows);

			return content.row(row);
		}

		return null;
	}

	static
	private RangeMap parseDiscretize(Discretize discretize){
		RangeMap result = TreeRangeMap.create();

		List discretizeBins = discretize.getDiscretizeBins();
		for(DiscretizeBin discretizeBin : discretizeBins){
			Interval interval = discretizeBin.getInterval();
			if(interval == null){
				throw new MissingAttributeException(discretizeBin, PMMLAttributes.DISCRETIZEBIN_INTERVAL);
			}

			Range range = toRange(interval);

			String binValue = discretizeBin.getBinValue();
			if(binValue == null){
				throw new MissingAttributeException(discretizeBin, PMMLAttributes.DISCRETIZEBIN_BINVALUE);
			}

			result.put(range, binValue);
		}

		return result;
	}

	static
	private Map parseInlineTable(InlineTable inlineTable){
		Map result = new LinkedHashMap<>();

		Table table = InlineTableUtil.getContent(inlineTable);

		Set columns = table.columnKeySet();
		for(String column : columns){
			Map columnValues = table.column(column);

			RowFilter rowFilter = new RowFilter(columnValues);

			result.put(column, rowFilter);
		}

		return result;
	}

	static
	private class RowFilter implements HasParsedValueMapping> {

		private Map columnValues = null;

		private SetMultimap parsedColumnValues = null;


		private RowFilter(Map columnValues){
			setColumnValues(columnValues);
		}

		@SuppressWarnings (
			value = {"rawtypes", "unchecked"}
		)
		@Override
		public Map> getValueMapping(DataType dataType, OpType opType){

			if(this.parsedColumnValues == null){
				this.parsedColumnValues = ImmutableSetMultimap.copyOf(parseColumnValues(dataType, opType));
			}

			return (Map)this.parsedColumnValues.asMap();
		}

		private SetMultimap parseColumnValues(DataType dataType, OpType opType){
			SetMultimap result = HashMultimap.create();

			Map columnValues = getColumnValues();

			Collection> entries = columnValues.entrySet();
			for(Map.Entry entry : entries){
				FieldValue value = parse(dataType, opType, entry.getValue());
				Integer row = entry.getKey();

				result.put(value, row);
			}

			return result;
		}

		public Map getColumnValues(){
			return this.columnValues;
		}

		private void setColumnValues(Map columnValues){
			this.columnValues = columnValues;
		}
	}

	private static final LoadingCache> binRangeCache = CacheUtil.buildLoadingCache(new CacheLoader>(){

		@Override
		public RangeMap load(Discretize discretize){
			return ImmutableRangeMap.copyOf(parseDiscretize(discretize));
		}
	});

	private static final LoadingCache> rowFilterCache = CacheUtil.buildLoadingCache(new CacheLoader>(){

		@Override
		public Map load(InlineTable inlineTable){
			return ImmutableMap.copyOf(parseInlineTable(inlineTable));
		}
	});
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy