All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.jpmml.evaluator.regression.RegressionModelUtil Maven / Gradle / Ivy

/*
 * Copyright (c) 2017 Villu Ruusmann
 *
 * This file is part of JPMML-Evaluator
 *
 * JPMML-Evaluator is free software: you can redistribute it and/or modify
 * it under the terms of the GNU Affero General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * JPMML-Evaluator is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU Affero General Public License for more details.
 *
 * You should have received a copy of the GNU Affero General Public License
 * along with JPMML-Evaluator.  If not, see <http://www.gnu.org/licenses/>.
 */
package org.jpmml.evaluator.regression;

import java.util.Iterator;

import org.dmg.pmml.regression.RegressionModel;
import org.jpmml.evaluator.Value;
import org.jpmml.evaluator.ValueMap;
import org.jpmml.evaluator.ValueUtil;

/**
 * Static helpers that turn raw {@link RegressionModel} scores into probabilities
 * (classification) or normalized predictions (regression), as selected by a
 * {@link RegressionModel.NormalizationMethod}.
 *
 * <p>The {@code compute*Probabilities} methods modify the {@link Value} entries of the
 * supplied {@link ValueMap} in place and return that same map instance.
 */
public class RegressionModelUtil {

	private RegressionModelUtil(){
	}

	/**
	 * Converts the two raw scores of a binary classification into probabilities.
	 *
	 * <p>The first entry (in iteration order) is normalized with
	 * {@link #normalizeBinaryLogisticClassificationResult(Value, RegressionModel.NormalizationMethod)}.
	 * The second entry is then overwritten with the residual of the first,
	 * i.e. {@code 1.0} minus the first probability.
	 *
	 * @param values exactly two scores; modified in place.
	 * @param normalizationMethod the inverse link applied to the first score.
	 * @return {@code values}.
	 * @throws IllegalArgumentException if {@code values} does not hold exactly two entries,
	 * or if {@code normalizationMethod} is not supported for binary classification.
	 */
	static
	public <K, V extends Number> ValueMap<K, V> computeBinomialProbabilities(ValueMap<K, V> values, RegressionModel.NormalizationMethod normalizationMethod){

		if(values.size() != 2){
			throw new IllegalArgumentException("Expected 2 values, got " + values.size());
		}

		Iterator<Value<V>> valueIt = values.iterator();

		Value<V> firstValue = valueIt.next();

		// The probability of the first category is calculated
		normalizeBinaryLogisticClassificationResult(firstValue, normalizationMethod);

		Value<V> secondValue = valueIt.next();

		// The probability of the second category is obtained by subtracting the probability of the first category from 1.0
		secondValue.residual(firstValue);

		return values;
	}

	/**
	 * Converts the raw scores of a multi-class classification into probabilities.
	 *
	 * <ul>
	 *   <li>{@code NONE}: the first {@code n - 1} entries are kept as they are; the last entry
	 *   is overwritten with the residual of their sum ({@code 1.0} minus the sum).</li>
	 *   <li>{@code LOGIT}: every entry is inverse-logit transformed, then the entries are
	 *   simple-max normalized (non-standard behaviour).</li>
	 *   <li>{@code SIMPLEMAX}: delegates to {@link ValueUtil#normalizeSimpleMax}.</li>
	 *   <li>{@code SOFTMAX}: delegates to {@link ValueUtil#normalizeSoftMax}.</li>
	 * </ul>
	 *
	 * @param values at least two scores; modified in place.
	 * @param normalizationMethod the normalization to apply.
	 * @return {@code values}.
	 * @throws IllegalArgumentException if {@code values} holds fewer than two entries,
	 * or if {@code normalizationMethod} is not one of the methods listed above.
	 */
	static
	public <K, V extends Number> ValueMap<K, V> computeMultinomialProbabilities(ValueMap<K, V> values, RegressionModel.NormalizationMethod normalizationMethod){

		if(values.size() < 2){
			throw new IllegalArgumentException("Expected 2 or more values, got " + values.size());
		}

		switch(normalizationMethod){
			case NONE:
				{
					// Non-null after the loop, because the size check guarantees at least one iteration
					Value<V> sum = null;

					Iterator<Value<V>> valueIt = values.iterator();
					for(int i = 0, max = values.size() - 1; i < max; i++){
						Value<V> value = valueIt.next();

						if(sum == null){
							// Copy, so that accumulating does not overwrite the first entry
							sum = value.copy();
						} else

						{
							sum.add(value);
						}
					}

					Value<V> lastValue = valueIt.next();

					lastValue.residual(sum);
				}
				break;
			// XXX: Non-standard behaviour
			case LOGIT:
				{
					for(Value<V> value : values){
						value.inverseLogit();
					}
				}
				// Falls through
			case SIMPLEMAX:
				{
					ValueUtil.normalizeSimpleMax(values);
				}
				break;
			case SOFTMAX:
				{
					ValueUtil.normalizeSoftMax(values);
				}
				break;
			default:
				throw new IllegalArgumentException("Normalization method " + normalizationMethod + " is not supported for multinomial classification");
		}

		return values;
	}

	/**
	 * Converts the raw scores of an ordinal classification into per-category probabilities.
	 *
	 * <p>Each of the first {@code n - 1} entries is normalized into a cumulative probability with
	 * {@link #normalizeBinaryLogisticClassificationResult(Value, RegressionModel.NormalizationMethod)}.
	 * Every entry after the first is then replaced by the difference between its cumulative
	 * value and the previous cumulative value. The last entry is overwritten with the residual
	 * of the final cumulative value ({@code 1.0} minus it).
	 *
	 * <p>NOTE(review): this assumes {@code values} iterates categories in ascending ordinal
	 * order. Confirm this against the caller.
	 *
	 * @param values at least two scores; modified in place.
	 * @param normalizationMethod the inverse link used to obtain cumulative probabilities.
	 * @return {@code values}.
	 * @throws IllegalArgumentException if {@code values} holds fewer than two entries,
	 * or if {@code normalizationMethod} is not one of NONE, LOGIT, PROBIT, CLOGLOG, LOGLOG, CAUCHIT.
	 */
	static
	public <K, V extends Number> ValueMap<K, V> computeOrdinalProbabilities(ValueMap<K, V> values, RegressionModel.NormalizationMethod normalizationMethod){

		if(values.size() < 2){
			throw new IllegalArgumentException("Expected 2 or more values, got " + values.size());
		}

		switch(normalizationMethod){
			case NONE:
			case LOGIT:
			case PROBIT:
			case CLOGLOG:
			case LOGLOG:
			case CAUCHIT:
				{
					// Running cumulative probability; non-null after the loop (size check above)
					Value<V> sum = null;

					Iterator<Value<V>> valueIt = values.iterator();
					for(int i = 0, max = values.size() - 1; i < max; i++){
						Value<V> value = valueIt.next();

						normalizeBinaryLogisticClassificationResult(value, normalizationMethod);

						if(sum == null){
							sum = value.copy();
						} else

						{
							// Cumulative -> individual: subtract the previous cumulative value,
							// then advance the running cumulative value by the individual probability
							value.subtract(sum);

							sum.add(value);
						}
					}

					Value<V> lastValue = valueIt.next();

					lastValue.residual(sum);
				}
				break;
			default:
				throw new IllegalArgumentException("Normalization method " + normalizationMethod + " is not supported for ordinal classification");
		}

		return values;
	}

	/**
	 * Applies the inverse link function of {@code normalizationMethod} to a regression result.
	 *
	 * <p>{@code NONE} returns the value unchanged. {@code SOFTMAX} is handled identically to {@code LOGIT}.
	 *
	 * @param value the raw regression result.
	 * @param normalizationMethod the normalization to apply.
	 * @return the (transformed) value, as returned by the corresponding {@link Value} method.
	 * @throws IllegalArgumentException if {@code normalizationMethod} is not supported for regression.
	 */
	static
	public <V extends Number> Value<V> normalizeRegressionResult(Value<V> value, RegressionModel.NormalizationMethod normalizationMethod){

		switch(normalizationMethod){
			case NONE:
				return value;
			case SOFTMAX:
			case LOGIT:
				return value.inverseLogit();
			case EXP:
				return value.exp();
			case PROBIT:
				return value.inverseProbit();
			case CLOGLOG:
				return value.inverseCloglog();
			case LOGLOG:
				return value.inverseLoglog();
			case CAUCHIT:
				return value.inverseCauchit();
			default:
				throw new IllegalArgumentException("Normalization method " + normalizationMethod + " is not supported for regression");
		}
	}

	/**
	 * Converts a single raw score into a probability.
	 *
	 * <p>{@code NONE} clamps the value to {@code [0, 1]}. The other supported methods apply
	 * the corresponding inverse link function.
	 *
	 * @param value the raw score.
	 * @param normalizationMethod one of NONE, LOGIT, PROBIT, CLOGLOG, LOGLOG, CAUCHIT.
	 * @return the (transformed) value, as returned by the corresponding {@link Value} method.
	 * @throws IllegalArgumentException if {@code normalizationMethod} is any other method.
	 */
	static
	public <V extends Number> Value<V> normalizeBinaryLogisticClassificationResult(Value<V> value, RegressionModel.NormalizationMethod normalizationMethod){

		switch(normalizationMethod){
			case NONE:
				return value.restrict(0d, 1d);
			case LOGIT:
				return value.inverseLogit();
			case PROBIT:
				return value.inverseProbit();
			case CLOGLOG:
				return value.inverseCloglog();
			case LOGLOG:
				return value.inverseLoglog();
			case CAUCHIT:
				return value.inverseCauchit();
			default:
				throw new IllegalArgumentException("Normalization method " + normalizationMethod + " is not supported for binary classification");
		}
	}
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy