All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.jpmml.evaluator.MeasureUtil Maven / Gradle / Ivy

/*
 * Copyright (c) 2013 Villu Ruusmann
 *
 * This file is part of JPMML-Evaluator
 *
 * JPMML-Evaluator is free software: you can redistribute it and/or modify
 * it under the terms of the GNU Affero General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * JPMML-Evaluator is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU Affero General Public License for more details.
 *
 * You should have received a copy of the GNU Affero General Public License
 * along with JPMML-Evaluator.  If not, see <http://www.gnu.org/licenses/>.
 */
package org.jpmml.evaluator;

import java.util.BitSet;
import java.util.List;

import org.dmg.pmml.BinarySimilarity;
import org.dmg.pmml.Chebychev;
import org.dmg.pmml.CityBlock;
import org.dmg.pmml.CompareFunction;
import org.dmg.pmml.ComparisonField;
import org.dmg.pmml.ComparisonMeasure;
import org.dmg.pmml.Distance;
import org.dmg.pmml.Euclidean;
import org.dmg.pmml.Jaccard;
import org.dmg.pmml.Minkowski;
import org.dmg.pmml.PMMLAttributes;
import org.dmg.pmml.Similarity;
import org.dmg.pmml.SimpleMatching;
import org.dmg.pmml.SquaredEuclidean;
import org.dmg.pmml.Tanimoto;
import org.jpmml.model.InvalidAttributeException;
import org.jpmml.model.InvalidElementException;
import org.jpmml.model.UnsupportedAttributeException;
import org.jpmml.model.UnsupportedElementException;

public class MeasureUtil {

	private MeasureUtil(){
	}

	static
	public  Value evaluateSimilarity(ValueFactory valueFactory, ComparisonMeasure comparisonMeasure, List> comparisonFields, BitSet flags, BitSet referenceFlags){
		Similarity measure = TypeUtil.cast(Similarity.class, comparisonMeasure.requireMeasure());

		int a11 = 0;
		int a10 = 0;
		int a01 = 0;
		int a00 = 0;

		for(int i = 0, max = comparisonFields.size(); i < max; i++){

			if(flags.get(i)){

				if(referenceFlags.get(i)){
					a11 += 1;
				} else

				{
					a10 += 1;
				}
			} else

			{
				if(referenceFlags.get(i)){
					a01 += 1;
				} else

				{
					a00 += 1;
				}
			}
		}

		Value numerator = valueFactory.newValue();
		Value denominator = valueFactory.newValue();

		if(measure instanceof SimpleMatching){
			numerator.add(a11 + a00);
			denominator.add(a11 + a10 + a01 + a00);
		} else

		if(measure instanceof Jaccard){
			numerator.add(a11);
			denominator.add(a11 + a10 + a01);
		} else

		if(measure instanceof Tanimoto){
			numerator.add(a11 + a00);
			denominator
				.add(a11)
				.add(Numbers.DOUBLE_TWO, (a10 + a01))
				.add(a00);
		} else

		if(measure instanceof BinarySimilarity){
			BinarySimilarity binarySimilarity = (BinarySimilarity)measure;

			Number c00 = binarySimilarity.requireC00Parameter();
			Number c01 = binarySimilarity.requireC01Parameter();
			Number c10 = binarySimilarity.requireC10Parameter();
			Number c11 = binarySimilarity.requireC11Parameter();

			numerator
				.add(c11, a11)
				.add(c10, a10)
				.add(c01, a01)
				.add(c00, a00);

			Number d00 = binarySimilarity.requireD00Parameter();
			Number d01 = binarySimilarity.requireD01Parameter();
			Number d10 = binarySimilarity.requireD10Parameter();
			Number d11 = binarySimilarity.requireD11Parameter();

			denominator
				.add(d11, a11)
				.add(d10, a10)
				.add(d01, a01)
				.add(d00, a00);
		} else

		{
			throw new UnsupportedElementException(measure);
		} // End if

		if(denominator.isZero()){
			throw new UndefinedResultException();
		}

		return numerator.divide(denominator);
	}

	static
	public BitSet toBitSet(List values){
		BitSet result = new BitSet(values.size());

		for(int i = 0, max = values.size(); i < max; i++){
			FieldValue value = values.get(i);

			if(value.equalsValue(Boolean.FALSE)){
				result.set(i, false);
			} else

			if(value.equalsValue(Boolean.TRUE)){
				result.set(i, true);
			} else

			{
				throw new EvaluationException("Expected " + EvaluationException.formatValue(Boolean.FALSE) + " or " + EvaluationException.formatValue(Boolean.TRUE) + ", got " + EvaluationException.formatValue(value));
			}
		}

		return result;
	}

	static
	public  Value evaluateDistance(ValueFactory valueFactory, ComparisonMeasure comparisonMeasure, List> comparisonFields, List values, List referenceValues, Value adjustment){
		Distance measure = TypeUtil.cast(Distance.class, comparisonMeasure.requireMeasure());

		Number innerPower;
		Number outerPower;

		if(measure instanceof Euclidean){
			innerPower = outerPower = Numbers.DOUBLE_TWO;
		} else

		if(measure instanceof SquaredEuclidean){
			innerPower = Numbers.DOUBLE_TWO;
			outerPower = Numbers.DOUBLE_ONE;
		} else

		if(measure instanceof Chebychev || measure instanceof CityBlock){
			innerPower = outerPower = Numbers.DOUBLE_ONE;
		} else

		if(measure instanceof Minkowski){
			Minkowski minkowski = (Minkowski)measure;

			Number p = minkowski.requirePParameter();
			if(p.doubleValue() < 0d){
				throw new InvalidAttributeException(minkowski, PMMLAttributes.MINKOWSKI_PPARAMETER, p);
			}

			innerPower = outerPower = p;
		} else

		{
			throw new UnsupportedElementException(measure);
		}

		Vector distances = valueFactory.newVector(0);

		for(int i = 0, max = comparisonFields.size(); i < max; i++){
			ComparisonField comparisonField = comparisonFields.get(i);

			FieldValue value = values.get(i);
			if(FieldValueUtil.isMissing(value)){
				continue;
			}

			FieldValue referenceValue = referenceValues.get(i);

			Value distance = evaluateInnerFunction(valueFactory, comparisonMeasure, comparisonField, value, referenceValue, innerPower);

			distances.add(distance);
		}

		if(measure instanceof Euclidean || measure instanceof SquaredEuclidean || measure instanceof CityBlock || measure instanceof Minkowski){
			Value result = distances.sum()
				.multiply(adjustment)
				.inversePower(outerPower);

			return result;
		} else

		if(measure instanceof Chebychev){
			Value result = distances.max()
				.multiply(adjustment);

			return result;
		} else

		{
			throw new UnsupportedElementException(measure);
		}
	}

	static
	private  Value evaluateInnerFunction(ValueFactory valueFactory, ComparisonMeasure comparisonMeasure, ComparisonField comparisonField, FieldValue value, FieldValue referenceValue, Number power){
		CompareFunction compareFunction = comparisonField.getCompareFunction();

		if(compareFunction == null){
			compareFunction = comparisonMeasure.getCompareFunction();

			// The ComparisonMeasure element is limited to "attribute-less" comparison functions
			switch(compareFunction){
				case ABS_DIFF:
				case DELTA:
				case EQUAL:
					break;
				case GAUSS_SIM:
				case TABLE:
					throw new InvalidAttributeException(comparisonMeasure, compareFunction);
				default:
					throw new UnsupportedAttributeException(comparisonMeasure, compareFunction);
			}
		}

		Value distance;

		switch(compareFunction){
			case ABS_DIFF:
				{
					distance = valueFactory.newValue(value.asNumber())
						.subtract(referenceValue.asNumber())
						.abs();
				}
				break;
			case GAUSS_SIM:
				{
					Number similarityScale = comparisonField.getSimilarityScale();
					if(similarityScale == null){
						throw new InvalidElementException(comparisonField);
					}

					distance = valueFactory.newValue(value.asNumber())
						.subtract(referenceValue.asNumber())
						.gaussSim(similarityScale);
				}
				break;
			case DELTA:
				{
					boolean equals = (value).equalsValue(referenceValue);

					distance = valueFactory.newValue(equals ? Numbers.DOUBLE_ZERO : Numbers.DOUBLE_ONE);
				}
				break;
			case EQUAL:
				{
					boolean equals = (value).equalsValue(referenceValue);

					distance = valueFactory.newValue(equals ? Numbers.DOUBLE_ONE : Numbers.DOUBLE_ZERO);
				}
				break;
			case TABLE:
				throw new UnsupportedAttributeException(comparisonField, compareFunction);
			default:
				throw new UnsupportedAttributeException(comparisonField, compareFunction);
		}

		distance.power(power);

		Number fieldWeight = comparisonField.getFieldWeight();
		if(fieldWeight != null){
			distance.multiply(fieldWeight);
		}

		return distance;
	}

	static
	public  Value calculateAdjustment(ValueFactory valueFactory, List values){
		return calculateAdjustment(valueFactory, values, null);
	}

	static
	public  Value calculateAdjustment(ValueFactory valueFactory, List values, List adjustmentValues){
		Value sum = valueFactory.newValue();
		Value nonmissingSum = valueFactory.newValue();

		for(int i = 0, max = values.size(); i < max; i++){
			FieldValue value = values.get(i);
			Number adjustmentValue = (adjustmentValues != null ? adjustmentValues.get(i) : Numbers.DOUBLE_ONE);

			sum.add(adjustmentValue);

			if(!FieldValueUtil.isMissing(value)){
				nonmissingSum.add(adjustmentValue);
			}
		}

		if(nonmissingSum.isZero()){
			throw new UndefinedResultException();
		}

		return sum.divide(nonmissingSum);
	}
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy