All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.jpmml.evaluator.ArgumentUtil Maven / Gradle / Ivy

There is a newer version: 1.7.2
Show newest version
/*
 * Copyright (c) 2013 Villu Ruusmann
 *
 * This file is part of JPMML-Evaluator
 *
 * JPMML-Evaluator is free software: you can redistribute it and/or modify
 * it under the terms of the GNU Affero General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * JPMML-Evaluator is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU Affero General Public License for more details.
 *
 * You should have received a copy of the GNU Affero General Public License
 * along with JPMML-Evaluator.  If not, see <http://www.gnu.org/licenses/>.
 */
package org.jpmml.evaluator;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

import com.google.common.base.Function;
import com.google.common.cache.CacheBuilder;
import com.google.common.cache.CacheLoader;
import com.google.common.cache.LoadingCache;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableRangeSet;
import com.google.common.collect.Iterables;
import com.google.common.collect.Range;
import com.google.common.collect.RangeSet;
import com.google.common.collect.TreeRangeSet;
import org.dmg.pmml.DataField;
import org.dmg.pmml.DataType;
import org.dmg.pmml.Interval;
import org.dmg.pmml.InvalidValueTreatmentMethodType;
import org.dmg.pmml.MiningField;
import org.dmg.pmml.OpType;
import org.dmg.pmml.OutlierTreatmentMethodType;
import org.dmg.pmml.TypeDefinitionField;
import org.dmg.pmml.Value;

public class ArgumentUtil {

	private ArgumentUtil(){
	}

	@SuppressWarnings (
		value = {"unused"}
	)
	static
	public FieldValue prepare(DataField dataField, MiningField miningField, Object value){

		if(value != null){
			DataType dataType = dataField.getDataType();

			try {
				value = TypeUtil.parseOrCast(dataType, value);
			} catch(IllegalArgumentException iae){
				// Ignored
			}
		}

		outlierTreatment:
		if(isOutlier(dataField, miningField, value)){
			OutlierTreatmentMethodType outlierTreatmentMethod = miningField.getOutlierTreatment();

			switch(outlierTreatmentMethod){
				case AS_IS:
					break;
				case AS_MISSING_VALUES:
					value = null;
					break;
				case AS_EXTREME_VALUES:
					{
						Double lowValue = miningField.getLowValue();
						Double highValue = miningField.getHighValue();

						if(lowValue == null || highValue == null){
							throw new InvalidFeatureException(miningField);
						} // End if

						if((lowValue).compareTo(highValue) > 0){
							throw new InvalidFeatureException(miningField);
						}

						Double doubleValue = (Double)TypeUtil.parseOrCast(DataType.DOUBLE, value);

						if(TypeUtil.compare(DataType.DOUBLE, doubleValue, lowValue) < 0){
							value = lowValue;
						} else

						if(TypeUtil.compare(DataType.DOUBLE, doubleValue, highValue) > 0){
							value = highValue;
						}
					}
					break;
				default:
					throw new UnsupportedFeatureException(miningField, outlierTreatmentMethod);
			}
		} // End if

		missingValueTreatment:
		if(isMissing(dataField, value)){
			value = miningField.getMissingValueReplacement();

			if(value != null){
				break missingValueTreatment;
			}

			return null;
		} // End if

		invalidValueTreatment:
		if(isInvalid(dataField, miningField, value)){
			InvalidValueTreatmentMethodType invalidValueTreatmentMethod = miningField.getInvalidValueTreatment();

			switch(invalidValueTreatmentMethod){
				case RETURN_INVALID:
					throw new InvalidResultException(miningField);
				case AS_IS:
					break invalidValueTreatment;
				case AS_MISSING:
					{
						value = miningField.getMissingValueReplacement();
						if(value != null){
							break invalidValueTreatment;
						}

						return null;
					}
				default:
					throw new UnsupportedFeatureException(miningField, invalidValueTreatmentMethod);
			}
		}

		return FieldValueUtil.create(dataField, miningField, value);
	}

	static
	public boolean isOutlier(DataField dataField, MiningField miningField, Object value){

		if(value == null){
			return false;
		}

		List intervals = dataField.getIntervals();

		OpType opType = miningField.getOpType();
		if(opType == null){
			opType = dataField.getOpType();
		}

		switch(opType){
			case CONTINUOUS:
				{
					if(intervals.size() > 0){
						RangeSet validRange = CacheUtil.getValue(dataField, ArgumentUtil.validRangeCache);

						Range validRangeSpan = validRange.span();

						Double doubleValue = (Double)TypeUtil.parseOrCast(DataType.DOUBLE, value);

						return !validRangeSpan.contains(doubleValue);
					}
				}
				break;
			case CATEGORICAL:
			case ORDINAL:
				break;
			default:
				throw new UnsupportedFeatureException(miningField, opType);
		}

		return false;
	}

	static
	public boolean isMissing(DataField dataField, Object value){

		if(value == null){
			return true;
		}

		DataType dataType = dataField.getDataType();

		List fieldValues = dataField.getValues();
		for(Value fieldValue : fieldValues){
			Value.Property property = fieldValue.getProperty();

			switch(property){
				case MISSING:
					{
						boolean equals = equals(dataType, value, fieldValue.getValue());
						if(equals){
							return true;
						}
					}
					break;
				default:
					break;
			}
		}

		return false;
	}

	static
	public boolean isInvalid(DataField dataField, MiningField miningField, Object value){

		if(value == null){
			return false;
		}

		return !isValid(dataField, miningField, value);
	}

	@SuppressWarnings (
		value = "fallthrough"
	)
	static
	public boolean isValid(DataField dataField, MiningField miningField, Object value){

		if(value == null){
			return false;
		}

		DataType dataType = dataField.getDataType();

		List intervals = dataField.getIntervals();

		OpType opType = miningField.getOpType();
		if(opType == null){
			opType = dataField.getOpType();
		}

		switch(opType){
			case CONTINUOUS:
				{
					// "If intervals are present, then a value that is outside the intervals is considered invalid"
					if(intervals.size() > 0){
						RangeSet validRanges = CacheUtil.getValue(dataField, ArgumentUtil.validRangeCache);

						Double doubleValue = (Double)TypeUtil.parseOrCast(DataType.DOUBLE, value);

						return validRanges.contains(doubleValue);
					}
				}
				// Falls through
			case CATEGORICAL:
			case ORDINAL:
				{
					// "Intervals are not allowed for non-continuous fields"
					if(intervals.size() > 0){
						throw new InvalidFeatureException(dataField);
					}

					int validValueCount = 0;

					List fieldValues = dataField.getValues();
					for(Value fieldValue : fieldValues){
						Value.Property property = fieldValue.getProperty();

						switch(property){
							case VALID:
								{
									validValueCount += 1;

									boolean equals = equals(dataType, value, fieldValue.getValue());
									if(equals){
										return true;
									}
								}
								break;
							case INVALID:
							case MISSING:
								{
									boolean equals = equals(dataType, value, fieldValue.getValue());
									if(equals){
										return false;
									}
								}
								break;
							default:
								throw new UnsupportedFeatureException(fieldValue, property);
						}
					}

					// "If a field contains at least one Value element where the value of property is valid, then the set of Value elements completely defines the set of valid values"
					if(validValueCount > 0){
						return false;
					}

					// "Any value is valid by default"
					return true;
				}
			default:
				throw new UnsupportedFeatureException(miningField, opType);
		}
	}

	static
	public Value getValidValue(TypeDefinitionField field, Object value){
		DataType dataType = field.getDataType();

		List fieldValues = field.getValues();
		for(Value fieldValue : fieldValues){
			Value.Property property = fieldValue.getProperty();

			switch(property){
				case VALID:
					{
						boolean equals = equals(dataType, value, fieldValue.getValue());
						if(equals){
							return fieldValue;
						}
					}
					break;
				default:
					break;
			}
		}

		return null;
	}

	static
	public List getValidValues(TypeDefinitionField field){
		List fieldValues = field.getValues();
		if(fieldValues.isEmpty()){
			return Collections.emptyList();
		}

		List result = new ArrayList<>();

		for(Value fieldValue : fieldValues){
			Value.Property property = fieldValue.getProperty();

			switch(property){
				case VALID:
					result.add(fieldValue);
					break;
				default:
					break;
			}
		}

		return result;
	}

	static
	private boolean equals(DataType dataType, Object value, String referenceValue){

		try {
			return TypeUtil.equals(dataType, value, TypeUtil.parseOrCast(dataType, referenceValue));
		} catch(IllegalArgumentException iae){

			// The String representation of invalid or missing values (eg. "N/A") may not be parseable to the requested representation
			try {
				return TypeUtil.equals(DataType.STRING, value, referenceValue);
			} catch(TypeCheckException tce){
				// Ignored
			}

			throw iae;
		}
	}

	static
	public List getTargetCategories(TypeDefinitionField field){
		return CacheUtil.getValue(field, ArgumentUtil.targetCategoryCache);
	}

	static
	private RangeSet parseValidRanges(DataField dataField){
		RangeSet result = TreeRangeSet.create();

		List intervals = dataField.getIntervals();
		for(Interval interval : intervals){
			Range range = DiscretizationUtil.toRange(interval);

			result.add(range);
		}

		return result;
	}

	private static final LoadingCache> targetCategoryCache = CacheBuilder.newBuilder()
		.weakKeys()
		.build(new CacheLoader>(){

			@Override
			public List load(TypeDefinitionField field){
				List values = getValidValues(field);

				Function function = new Function(){

					@Override
					public String apply(Value value){
						String result = value.getValue();
						if(result == null){
							throw new InvalidFeatureException(value);
						}

						return result;
					}
				};

				return ImmutableList.copyOf(Iterables.transform(values, function));
			}
		});

	private static final LoadingCache> validRangeCache = CacheBuilder.newBuilder()
		.weakKeys()
		.build(new CacheLoader>(){

			@Override
			public RangeSet load(DataField dataField){
				return ImmutableRangeSet.copyOf(parseValidRanges(dataField));
			}
		});
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy