All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.jpmml.evaluator.GeneralRegressionModelEvaluator Maven / Gradle / Ivy

There is a newer version: 1.6.11
Show newest version
/*
 * Copyright (c) 2013 Villu Ruusmann
 *
 * This file is part of JPMML-Evaluator
 *
 * JPMML-Evaluator is free software: you can redistribute it and/or modify
 * it under the terms of the GNU Affero General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * JPMML-Evaluator is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU Affero General Public License for more details.
 *
 * You should have received a copy of the GNU Affero General Public License
 * along with JPMML-Evaluator.  If not, see <http://www.gnu.org/licenses/>.
 */
package org.jpmml.evaluator;

import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Comparator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;

import com.google.common.base.Function;
import com.google.common.base.Predicate;
import com.google.common.cache.CacheLoader;
import com.google.common.cache.LoadingCache;
import com.google.common.collect.ArrayListMultimap;
import com.google.common.collect.BiMap;
import com.google.common.collect.HashBiMap;
import com.google.common.collect.ImmutableBiMap;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Iterables;
import com.google.common.collect.ListMultimap;
import com.google.common.collect.Lists;
import com.google.common.collect.Ordering;
import com.google.common.collect.Sets;
import org.dmg.pmml.BaseCumHazardTables;
import org.dmg.pmml.BaselineCell;
import org.dmg.pmml.BaselineStratum;
import org.dmg.pmml.Categories;
import org.dmg.pmml.Category;
import org.dmg.pmml.CumulativeLinkFunctionType;
import org.dmg.pmml.DataField;
import org.dmg.pmml.DataType;
import org.dmg.pmml.FieldName;
import org.dmg.pmml.GeneralRegressionModel;
import org.dmg.pmml.LinkFunctionType;
import org.dmg.pmml.Matrix;
import org.dmg.pmml.MiningFunctionType;
import org.dmg.pmml.OpType;
import org.dmg.pmml.PCell;
import org.dmg.pmml.PMML;
import org.dmg.pmml.PPCell;
import org.dmg.pmml.PPMatrix;
import org.dmg.pmml.ParamMatrix;
import org.dmg.pmml.Parameter;
import org.dmg.pmml.ParameterCell;
import org.dmg.pmml.ParameterList;
import org.dmg.pmml.Predictor;
import org.dmg.pmml.PredictorList;

public class GeneralRegressionModelEvaluator extends ModelEvaluator {

	transient
	private BiMap parameterRegistry = null;

	transient
	private Map> ppMatrixMap = null;

	transient
	private Map> paramMatrixMap = null;

	transient
	private List targetCategories = null;


	/**
	 * Creates an evaluator from a PMML document, delegating to the superclass constructor
	 * with {@link GeneralRegressionModel} as the requested model class.
	 *
	 * @param pmml The PMML document.
	 */
	public GeneralRegressionModelEvaluator(PMML pmml){
		super(pmml, GeneralRegressionModel.class);
	}

	/**
	 * Creates an evaluator for an explicitly supplied model, delegating to the superclass constructor.
	 *
	 * @param pmml The PMML document.
	 * @param generalRegressionModel The model to evaluate.
	 */
	public GeneralRegressionModelEvaluator(PMML pmml, GeneralRegressionModel generalRegressionModel){
		super(pmml, generalRegressionModel);
	}

	/**
	 * Returns a short human-readable label for this model.
	 *
	 * @return "Cox regression" when the model type is COX_REGRESSION,
	 * "General regression" for every other model type.
	 */
	@Override
	public String getSummary(){
		GeneralRegressionModel model = getModel();

		String summary;

		switch(model.getModelType()){
			case COX_REGRESSION:
				summary = "Cox regression";
				break;
			default:
				summary = "General regression";
				break;
		}

		return summary;
	}

	/**
	 * Evaluates the model in the given context.
	 *
	 * Dispatches on the mining function of the model (regression or classification)
	 * and passes the raw predictions through {@link OutputUtil#evaluate}.
	 *
	 * @param context The evaluation context that supplies field values.
	 *
	 * @return The value returned by {@link OutputUtil#evaluate} for the predictions.
	 *
	 * @throws InvalidResultException If the model is not scorable.
	 * @throws UnsupportedFeatureException If the mining function is neither regression nor classification.
	 */
	@Override
	public Map<FieldName, ?> evaluate(ModelEvaluationContext context){
		GeneralRegressionModel generalRegressionModel = getModel();
		if(!generalRegressionModel.isScorable()){
			throw new InvalidResultException(generalRegressionModel);
		}

		Map<FieldName, ?> predictions;

		MiningFunctionType miningFunction = generalRegressionModel.getFunctionName();
		switch(miningFunction){
			case REGRESSION:
				predictions = evaluateRegression(context);
				break;
			case CLASSIFICATION:
				predictions = evaluateClassification(context);
				break;
			default:
				throw new UnsupportedFeatureException(generalRegressionModel, miningFunction);
		}

		return OutputUtil.evaluate(predictions, context);
	}

	/**
	 * Evaluates a regression-type model.
	 *
	 * Cox regression models are handled by {@link #evaluateCoxRegression},
	 * all other model types by {@link #evaluateGeneralRegression}.
	 *
	 * @param context The evaluation context that supplies field values.
	 *
	 * @return The predictions of the selected evaluation method.
	 */
	private Map<FieldName, ?> evaluateRegression(ModelEvaluationContext context){
		GeneralRegressionModel generalRegressionModel = getModel();

		GeneralRegressionModel.ModelType modelType = generalRegressionModel.getModelType();
		switch(modelType){
			case COX_REGRESSION:
				return evaluateCoxRegression(context);
			default:
				return evaluateGeneralRegression(context);
		}
	}

	/**
	 * Computes the cumulative hazard of a Cox regression model.
	 *
	 * The baseline cumulative hazard is looked up from the BaseCumHazardTables element
	 * (optionally from the stratum selected by the baseline strata variable), and is then
	 * multiplied by <code>exp(r - s)</code>, where <code>r</code> is the dot product of the
	 * active record and <code>s</code> is the dot product of the reference point.
	 *
	 * @param context The evaluation context that supplies field values.
	 *
	 * @return A singleton map from the target field to the cumulative hazard,
	 * or <code>null</code> if the result is a missing value.
	 *
	 * @throws InvalidFeatureException If the model has no BaseCumHazardTables element,
	 * if the non-stratified table has no maximum time, or if the end time variable is not declared.
	 */
	private Map<FieldName, ?> evaluateCoxRegression(ModelEvaluationContext context){
		GeneralRegressionModel generalRegressionModel = getModel();

		BaseCumHazardTables baseCumHazardTables = generalRegressionModel.getBaseCumHazardTables();
		if(baseCumHazardTables == null){
			throw new InvalidFeatureException(generalRegressionModel);
		}

		FieldName targetField = getTargetField();

		List<BaselineCell> baselineCells;

		Double maxTime;

		FieldName baselineStrataVariable = generalRegressionModel.getBaselineStrataVariable();

		if(baselineStrataVariable != null){
			FieldValue value = getVariable(baselineStrataVariable, context);

			BaselineStratum baselineStratum = getBaselineStratum(baseCumHazardTables, value);

			// "If the value does not have a corresponding BaselineStratum element, then the result is a missing value"
			if(baselineStratum == null){
				return null;
			}

			baselineCells = baselineStratum.getBaselineCells();

			maxTime = baselineStratum.getMaxTime();
		} else {
			baselineCells = baseCumHazardTables.getBaselineCells();

			maxTime = baseCumHazardTables.getMaxTime();
			if(maxTime == null){
				throw new InvalidFeatureException(baseCumHazardTables);
			}
		}

		// Orders baseline cells by their time attribute
		Comparator<BaselineCell> comparator = new Comparator<BaselineCell>(){

			@Override
			public int compare(BaselineCell left, BaselineCell right){
				return Double.compare(left.getTime(), right.getTime());
			}
		};

		Ordering<BaselineCell> ordering = Ordering.from(comparator);

		// Only the end time variable is consulted; the start time variable is not used by this method
		FieldName endTimeVariable = generalRegressionModel.getEndTimeVariable();
		if(endTimeVariable == null){
			throw new InvalidFeatureException(generalRegressionModel);
		}

		BaselineCell minBaselineCell = ordering.min(baselineCells);

		Double minTime = minBaselineCell.getTime();

		// Must be final, because it is captured by the anonymous predicate below
		final
		FieldValue value = getVariable(endTimeVariable, context);

		FieldValue minTimeValue = FieldValueUtil.create(DataType.DOUBLE, OpType.CONTINUOUS, minTime);

		// "If the value is less than the minimum time, then cumulative hazard is 0 and predicted survival is 1"
		if(value.compareToValue(minTimeValue) < 0){
			return Collections.singletonMap(targetField, Values.DOUBLE_ZERO);
		}

		FieldValue maxTimeValue = FieldValueUtil.create(DataType.DOUBLE, OpType.CONTINUOUS, maxTime);

		// "If the value is greater than the maximum time, then the result is a missing value"
		if(value.compareToValue(maxTimeValue) > 0){
			return null;
		}

		Predicate<BaselineCell> predicate = new Predicate<BaselineCell>(){

			private double time = (value.asNumber()).doubleValue();


			@Override
			public boolean apply(BaselineCell baselineCell){
				return (baselineCell.getTime() <= this.time);
			}
		};

		// "Select the BaselineCell element that has the largest time attribute value that is not greater than the value"
		BaselineCell baselineCell = ordering.max(Iterables.filter(baselineCells, predicate));

		double baselineCumHazard = baselineCell.getCumHazard();

		Double r = computeDotProduct(context);

		Double s = computeReferencePoint();

		// Either dot product being unavailable makes the result a missing value
		if(r == null || s == null){
			return null;
		}

		Double cumHazard = baselineCumHazard * Math.exp(r - s);

		return Collections.singletonMap(targetField, cumHazard);
	}

	/**
	 * Evaluates a non-Cox regression model.
	 *
	 * The dot product is used as-is for the "regression" and "general linear" model types,
	 * and is passed through the link function for the "generalized linear" model type.
	 *
	 * @param context The evaluation context that supplies field values.
	 *
	 * @return The result of {@link TargetUtil#evaluateRegression} for the computed value,
	 * or of {@link TargetUtil#evaluateRegressionDefault} if the dot product is not available.
	 *
	 * @throws UnsupportedFeatureException For all other model types.
	 */
	private Map<FieldName, ?> evaluateGeneralRegression(ModelEvaluationContext context){
		GeneralRegressionModel generalRegressionModel = getModel();

		Double result = computeDotProduct(context);
		if(result == null){
			return TargetUtil.evaluateRegressionDefault(context);
		}

		GeneralRegressionModel.ModelType modelType = generalRegressionModel.getModelType();
		switch(modelType){
			case REGRESSION:
			case GENERAL_LINEAR:
				break;
			case GENERALIZED_LINEAR:
				result = computeLink(result, context);
				break;
			default:
				throw new UnsupportedFeatureException(generalRegressionModel, modelType);
		}

		return TargetUtil.evaluateRegression(result, context);
	}

	/**
	 * Evaluates a classification-type model.
	 *
	 * Target categories are processed in the order returned by {@link #getTargetCategories()}.
	 * For every category except the last one, a dot product is computed and transformed
	 * according to the model type (link function, exponentiation or cumulative link function).
	 * The last category receives a fixed value that depends on the model type.
	 * Multinomial logistic results are normalized at the end.
	 *
	 * @param context The evaluation context that supplies field values.
	 *
	 * @return The result of {@link TargetUtil#evaluateClassification} for the probability distribution,
	 * or of {@link TargetUtil#evaluateClassificationDefault} if any dot product is not available.
	 *
	 * @throws InvalidFeatureException If the PPMatrix or ParamMatrix elements lack the required cells.
	 * @throws UnsupportedFeatureException If the model type is not generalized linear,
	 * multinomial logistic or ordinal multinomial.
	 */
	private Map<FieldName, ?> evaluateClassification(ModelEvaluationContext context){
		GeneralRegressionModel generalRegressionModel = getModel();

		List<String> targetCategories = getTargetCategories();

		Map<String, Map<String, Row>> ppMatrixMap = getPPMatrixMap();

		Map<String, List<PCell>> paramMatrixMap = getParamMatrixMap();

		GeneralRegressionModel.ModelType modelType = generalRegressionModel.getModelType();

		ProbabilityDistribution result = new ProbabilityDistribution();

		// The value of the previous category; used for differencing cumulative values and for the binomial complement
		double previousValue = 0d;

		for(int i = 0; i < targetCategories.size(); i++){
			String targetCategory = targetCategories.get(i);

			double value;

			// Categories from the first category to the second-to-last category
			if(i < (targetCategories.size() - 1)){
				Map<String, Row> parameterPredictorRows;

				if(ppMatrixMap.isEmpty()){
					parameterPredictorRows = Collections.emptyMap();
				} else {
					// Prefer the category-specific matrix, fall back to the default matrix
					parameterPredictorRows = ppMatrixMap.get(targetCategory);
					if(parameterPredictorRows == null){
						parameterPredictorRows = ppMatrixMap.get(null);
					} // End if

					if(parameterPredictorRows == null){
						throw new InvalidFeatureException(generalRegressionModel.getPPMatrix());
					}
				}

				Iterable<PCell> parameterCells;

				switch(modelType){
					case GENERALIZED_LINEAR:
					case MULTINOMIAL_LOGISTIC:
						// PCell elements must have non-null targetCategory attribute in case of multinomial categories, but can do without in case of binomial categories
						parameterCells = paramMatrixMap.get(targetCategory);
						if(parameterCells == null && targetCategories.size() == 2){
							parameterCells = paramMatrixMap.get(null);
						} // End if

						if(parameterCells == null){
							throw new InvalidFeatureException(generalRegressionModel.getParamMatrix());
						}
						break;
					case ORDINAL_MULTINOMIAL:
						// "ParamMatrix specifies different values for the intercept parameter: one for each target category except one"
						List<PCell> interceptCells = paramMatrixMap.get(targetCategory);
						if(interceptCells == null || interceptCells.size() != 1){
							throw new InvalidFeatureException(generalRegressionModel.getParamMatrix());
						}

						// "Values for all other parameters are constant across all target variable values"
						parameterCells = paramMatrixMap.get(null);
						if(parameterCells == null){
							throw new InvalidFeatureException(generalRegressionModel.getParamMatrix());
						}

						parameterCells = Iterables.concat(interceptCells, parameterCells);
						break;
					default:
						throw new UnsupportedFeatureException(generalRegressionModel, modelType);
				}

				Double dotProduct = computeDotProduct(parameterCells, parameterPredictorRows, context);
				if(dotProduct == null){
					return TargetUtil.evaluateClassificationDefault(context);
				}

				value = dotProduct;

				switch(modelType){
					case GENERALIZED_LINEAR:
						value = computeLink(value, context);
						break;
					case MULTINOMIAL_LOGISTIC:
						value = Math.exp(value);
						break;
					case ORDINAL_MULTINOMIAL:
						value = computeCumulativeLink(value, context);
						break;
					default:
						throw new UnsupportedFeatureException(generalRegressionModel, modelType);
				}
			} else {
				// The last category
				switch(modelType){
					case GENERALIZED_LINEAR:
						value = (1d - previousValue);
						break;
					case MULTINOMIAL_LOGISTIC:
						// "By convention, the vector of Parameter estimates for the last category is 0"
						value = Math.exp(0d);
						break;
					case ORDINAL_MULTINOMIAL:
						value = 1d;
						break;
					default:
						throw new UnsupportedFeatureException(generalRegressionModel, modelType);
				}
			}

			switch(modelType){
				case GENERALIZED_LINEAR:
				case MULTINOMIAL_LOGISTIC:
					result.put(targetCategory, value);
					break;
				case ORDINAL_MULTINOMIAL:
					// Cumulative values are converted to per-category values by differencing
					result.put(targetCategory, (value - previousValue));
					break;
				default:
					throw new UnsupportedFeatureException(generalRegressionModel, modelType);
			}

			previousValue = value;
		}

		switch(modelType){
			case GENERALIZED_LINEAR:
				break;
			case MULTINOMIAL_LOGISTIC:
				result.normalizeValues();
				break;
			case ORDINAL_MULTINOMIAL:
				break;
			default:
				throw new UnsupportedFeatureException(generalRegressionModel, modelType);
		}

		return TargetUtil.evaluateClassification(result, context);
	}

	/**
	 * Computes the dot product using the default (<code>null</code> key) PPMatrix and ParamMatrix entries.
	 *
	 * @param context The evaluation context that supplies field values.
	 *
	 * @return The dot product, or <code>null</code> if it cannot be computed.
	 *
	 * @throws InvalidFeatureException If a non-empty PPMatrix map has no default matrix,
	 * or if the ParamMatrix map contains anything other than the single default entry.
	 */
	private Double computeDotProduct(EvaluationContext context){
		GeneralRegressionModel generalRegressionModel = getModel();

		Map<String, Map<String, Row>> ppMatrixMap = getPPMatrixMap();

		Map<String, Row> parameterPredictorRows;

		if(ppMatrixMap.isEmpty()){
			parameterPredictorRows = Collections.emptyMap();
		} else {
			parameterPredictorRows = ppMatrixMap.get(null);
			if(parameterPredictorRows == null){
				throw new InvalidFeatureException(generalRegressionModel.getPPMatrix());
			}
		}

		Map<String, List<PCell>> paramMatrixMap = getParamMatrixMap();
		if(paramMatrixMap.size() != 1 || !paramMatrixMap.containsKey(null)){
			throw new InvalidFeatureException(generalRegressionModel.getParamMatrix());
		}

		List<PCell> parameterCells = paramMatrixMap.get(null);

		return computeDotProduct(parameterCells, parameterPredictorRows, context);
	}

	/**
	 * Sums, over all parameter cells, the product of the cell's beta coefficient
	 * and the value of the matching PPMatrix row.
	 *
	 * A parameter that has no PPMatrix row contributes its beta coefficient unchanged.
	 *
	 * @param parameterCells The parameter cells to sum over.
	 * @param parameterPredictorRows PPMatrix rows keyed by parameter name.
	 * @param context The evaluation context that supplies field values.
	 *
	 * @return The sum, or <code>null</code> if any row evaluates to <code>null</code>
	 * or if there are no parameter cells at all.
	 */
	private Double computeDotProduct(Iterable<PCell> parameterCells, Map<String, Row> parameterPredictorRows, EvaluationContext context){
		double sum = 0d;

		int count = 0;

		for(PCell parameterCell : parameterCells){
			double value;

			Row parameterPredictorRow = parameterPredictorRows.get(parameterCell.getParameterName());
			if(parameterPredictorRow != null){
				Double x = parameterPredictorRow.evaluate(context);
				if(x == null){
					return null;
				}

				value = (x * parameterCell.getBeta());
			} else {
				// The row is empty
				value = parameterCell.getBeta();
			}

			sum += value;

			count++;
		}

		if(count == 0){
			return null;
		}

		return sum;
	}

	/**
	 * Computes the dot product of the parameters' reference points and the beta coefficients
	 * of the default (<code>null</code> key) ParamMatrix entry.
	 *
	 * @return The sum, or <code>null</code> if a parameter cell refers to an unregistered parameter
	 * or if there are no parameter cells.
	 *
	 * @throws InvalidFeatureException If the ParamMatrix map contains anything other than the single default entry.
	 */
	private Double computeReferencePoint(){
		GeneralRegressionModel generalRegressionModel = getModel();

		BiMap<String, Parameter> parameters = getParameterRegistry();

		Map<String, List<PCell>> paramMatrixMap = getParamMatrixMap();
		if(paramMatrixMap.size() != 1 || !paramMatrixMap.containsKey(null)){
			throw new InvalidFeatureException(generalRegressionModel.getParamMatrix());
		}

		Iterable<PCell> parameterCells = paramMatrixMap.get(null);

		// Stays null until the first cell has been processed
		Double sum = null;

		for(PCell parameterCell : parameterCells){
			Parameter parameter = parameters.get(parameterCell.getParameterName());
			if(parameter == null){
				return null;
			}

			double value = (parameter.getReferencePoint() * parameterCell.getBeta());

			sum = (sum != null ? (sum + value) : value);
		}

		return sum;
	}

	/**
	 * Applies the inverse of the model's link function to a linear predictor value.
	 *
	 * The offset is added to the value before the transformation,
	 * and the transformed value is multiplied by the number of trials.
	 *
	 * @param value The linear predictor value.
	 * @param context The evaluation context that supplies the offset and trials variables.
	 *
	 * @return The transformed value.
	 *
	 * @throws InvalidFeatureException If the link function is not declared, or if the
	 * distribution parameter (negbin) or the link parameter (oddspower, power) is not declared.
	 * @throws UnsupportedFeatureException If the link function is not handled by this method.
	 */
	private double computeLink(double value, EvaluationContext context){
		GeneralRegressionModel model = getModel();

		LinkFunctionType linkFunction = model.getLinkFunction();
		if(linkFunction == null){
			throw new InvalidFeatureException(model);
		}

		Double offset = getOffset(model, context);
		Integer trials = getTrials(model, context);

		Double distParameter = model.getDistParameter();
		Double linkParameter = model.getLinkParameter();

		switch(linkFunction){
			case CLOGLOG:
				return (1d - Math.exp(-Math.exp(value + offset))) * trials;
			case IDENTITY:
				return (value + offset) * trials;
			case LOG:
				return Math.exp(value + offset) * trials;
			case LOGC:
				return (1d - Math.exp(value + offset)) * trials;
			case LOGIT:
				return (1d / (1d + Math.exp(-(value + offset)))) * trials;
			case LOGLOG:
				return Math.exp(-Math.exp(-(value + offset))) * trials;
			case NEGBIN:
				if(distParameter == null){
					throw new InvalidFeatureException(model);
				}
				return (1d / (distParameter * (Math.exp(-(value + offset)) - 1d))) * trials;
			case ODDSPOWER:
				if(linkParameter == null){
					throw new InvalidFeatureException(model);
				} // End if

				// A link parameter that is neither negative nor positive selects the logit form
				if(!(linkParameter < 0d || linkParameter > 0d)){
					return (1d / (1d + Math.exp(-(value + offset)))) * trials;
				}
				return (1d / (1d + Math.pow(1d + linkParameter * (value + offset), -(1d / linkParameter)))) * trials;
			case POWER:
				if(linkParameter == null){
					throw new InvalidFeatureException(model);
				} // End if

				// A link parameter that is neither negative nor positive selects the log form
				if(!(linkParameter < 0d || linkParameter > 0d)){
					return Math.exp(value + offset) * trials;
				}
				return Math.pow(value + offset, 1d / linkParameter) * trials;
			case PROBIT:
				return NormalDistributionUtil.cumulativeProbability(value + offset) * trials;
			default:
				throw new UnsupportedFeatureException(model, linkFunction);
		}
	}

	/**
	 * Applies the inverse of the model's cumulative link function to a linear predictor value.
	 *
	 * The offset is added to the value before the transformation.
	 *
	 * @param value The linear predictor value.
	 * @param context The evaluation context that supplies the offset variable.
	 *
	 * @return The transformed value.
	 *
	 * @throws InvalidFeatureException If the cumulative link function is not declared.
	 * @throws UnsupportedFeatureException If the cumulative link function is not handled by this method.
	 */
	private double computeCumulativeLink(double value, EvaluationContext context){
		GeneralRegressionModel model = getModel();

		CumulativeLinkFunctionType cumulativeLinkFunction = model.getCumulativeLinkFunction();
		if(cumulativeLinkFunction == null){
			throw new InvalidFeatureException(model);
		}

		Double offset = getOffset(model, context);

		double result;

		switch(cumulativeLinkFunction){
			case LOGIT:
				result = 1d / (1d + Math.exp(-(value + offset)));
				break;
			case PROBIT:
				result = NormalDistributionUtil.cumulativeProbability(value + offset);
				break;
			case CLOGLOG:
				result = 1d - Math.exp(-Math.exp(value + offset));
				break;
			case LOGLOG:
				result = Math.exp(-Math.exp(-(value + offset)));
				break;
			case CAUCHIT:
				result = 0.5d + (1d / Math.PI) * Math.atan(value + offset);
				break;
			default:
				throw new UnsupportedFeatureException(model, cumulativeLinkFunction);
		}

		return result;
	}

	/**
	 * Returns the parameter registry, loading it from the shared parameter cache on first access.
	 *
	 * @return A bidirectional map from parameter name to {@link Parameter} element.
	 */
	public BiMap<String, Parameter> getParameterRegistry(){

		if(this.parameterRegistry == null){
			this.parameterRegistry = getValue(GeneralRegressionModelEvaluator.parameterCache);
		}

		return this.parameterRegistry;
	}

	/**
	 * 

* A PPMatrix element may encode zero or more matrices. * Regression models return a singleton map, whereas classification models * may return a singleton map or a multi-valued map, which overrides the default * matrix for one or more target categories. *

* *

* The default matrix is mapped to the null key. *

* * @return A map of predictor-to-parameter correlation matrices. */ private Map> getPPMatrixMap(){ if(this.ppMatrixMap == null){ this.ppMatrixMap = getValue(GeneralRegressionModelEvaluator.ppMatrixCache); } return this.ppMatrixMap; } /** * @return A map of parameter matrices. */ private Map> getParamMatrixMap(){ if(this.paramMatrixMap == null){ this.paramMatrixMap = getValue(GeneralRegressionModelEvaluator.paramMatrixCache); } return this.paramMatrixMap; } private List getTargetCategories(){ if(this.targetCategories == null){ this.targetCategories = ImmutableList.copyOf(parseTargetCategories()); } return this.targetCategories; } private List parseTargetCategories(){ GeneralRegressionModel generalRegressionModel = getModel(); FieldName targetField = getTargetField(); DataField dataField = getDataField(targetField); if(dataField == null){ throw new MissingFieldException(targetField, generalRegressionModel); } OpType opType = dataField.getOpType(); switch(opType){ case CONTINUOUS: throw new InvalidFeatureException(dataField); case CATEGORICAL: case ORDINAL: break; default: throw new UnsupportedFeatureException(dataField, opType); } List targetCategories = FieldValueUtil.getTargetCategories(dataField); if(targetCategories.size() > 0 && targetCategories.size() < 2){ throw new InvalidFeatureException(dataField); } String targetReferenceCategory = generalRegressionModel.getTargetReferenceCategory(); GeneralRegressionModel.ModelType modelType = generalRegressionModel.getModelType(); switch(modelType){ case GENERALIZED_LINEAR: case MULTINOMIAL_LOGISTIC: if(targetReferenceCategory == null){ Predicate filter = new Predicate(){ private Map> paramMatrixMap = getParamMatrixMap(); @Override public boolean apply(String string){ return !this.paramMatrixMap.containsKey(string); } }; // "The reference category is the one from DataDictionary that does not appear in the ParamMatrix" Set targetReferenceCategories = Sets.newLinkedHashSet(Iterables.filter(targetCategories, filter)); 
if(targetReferenceCategories.size() != 1){ throw new InvalidFeatureException(generalRegressionModel.getParamMatrix()); } targetReferenceCategory = Iterables.getOnlyElement(targetReferenceCategories); } break; case ORDINAL_MULTINOMIAL: break; default: throw new UnsupportedFeatureException(generalRegressionModel, modelType); } if(targetReferenceCategory != null){ targetCategories = new ArrayList<>(targetCategories); // Move the element from any position to the last position if(targetCategories.remove(targetReferenceCategory)){ targetCategories.add(targetReferenceCategory); } } return targetCategories; } static private Double getOffset(GeneralRegressionModel generalRegressionModel, EvaluationContext context){ FieldName offsetVariable = generalRegressionModel.getOffsetVariable(); if(offsetVariable != null){ FieldValue value = getVariable(offsetVariable, context); return value.asDouble(); } Double offsetValue = generalRegressionModel.getOffsetValue(); if(offsetValue != null){ return offsetValue; } return Values.DOUBLE_ZERO; } static private Integer getTrials(GeneralRegressionModel generalRegressionModel, EvaluationContext context){ FieldName trialsVariable = generalRegressionModel.getTrialsVariable(); if(trialsVariable != null){ FieldValue value = getVariable(trialsVariable, context); return value.asInteger(); } Integer trialsValue = generalRegressionModel.getTrialsValue(); if(trialsValue != null){ return trialsValue; } return Values.INTEGER_ONE; } static private FieldValue getVariable(FieldName name, EvaluationContext context){ FieldValue value = context.evaluate(name); if(value == null){ throw new MissingValueException(name); } return value; } static private BaselineStratum getBaselineStratum(BaseCumHazardTables baseCumHazardTables, FieldValue value){ if(baseCumHazardTables instanceof HasParsedValueMapping){ HasParsedValueMapping hasParsedValueMapping = (HasParsedValueMapping)baseCumHazardTables; return (BaselineStratum)value.getMapping(hasParsedValueMapping); } List 
baselineStrata = baseCumHazardTables.getBaselineStrata(); for(BaselineStratum baselineStratum : baselineStrata){ if(value.equalsString(baselineStratum.getValue())){ return baselineStratum; } } return null; } static private BiMap parseParameterRegistry(ParameterList parameterList){ BiMap result = HashBiMap.create(); if(!parameterList.hasParameters()){ return result; } List parameters = parameterList.getParameters(); for(Parameter parameter : parameters){ result.put(parameter.getName(), parameter); } return result; } static private BiMap parsePredictorRegistry(PredictorList predictorList){ BiMap result = HashBiMap.create(); if(predictorList == null || !predictorList.hasPredictors()){ return result; } List predictors = predictorList.getPredictors(); for(Predictor predictor : predictors){ result.put(predictor.getName(), predictor); } return result; } static private Map> parsePPMatrix(final GeneralRegressionModel generalRegressionModel){ Function, Row> function = new Function, Row>(){ private BiMap factors = CacheUtil.getValue(generalRegressionModel, GeneralRegressionModelEvaluator.factorCache); private BiMap covariates = CacheUtil.getValue(generalRegressionModel, GeneralRegressionModelEvaluator.covariateCache); @Override public Row apply(List ppCells){ Row result = new Row(); ppCells: for(PPCell ppCell : ppCells){ FieldName name = ppCell.getPredictorName(); Predictor factor = this.factors.get(name); if(factor != null){ result.addFactor(ppCell, factor); continue ppCells; } Predictor covariate = this.covariates.get(name); if(covariate != null){ result.addCovariate(ppCell); continue ppCells; } throw new InvalidFeatureException(ppCell); } return result; } }; PPMatrix ppMatrix = generalRegressionModel.getPPMatrix(); ListMultimap targetCategoryMap = groupByTargetCategory(ppMatrix.getPPCells()); Map> result = new LinkedHashMap<>(); Collection>> targetCategoryEntries = (asMap(targetCategoryMap)).entrySet(); for(Map.Entry> targetCategoryEntry : targetCategoryEntries){ Map 
predictorMap = new LinkedHashMap<>(); ListMultimap parameterNameMap = groupByParameterName(targetCategoryEntry.getValue()); Collection>> parameterNameEntries = (asMap(parameterNameMap)).entrySet(); for(Map.Entry> parameterNameEntry : parameterNameEntries){ predictorMap.put(parameterNameEntry.getKey(), function.apply(parameterNameEntry.getValue())); } result.put(targetCategoryEntry.getKey(), predictorMap); } return result; } static private Map> parseParamMatrix(GeneralRegressionModel generalRegressionModel){ ParamMatrix paramMatrix = generalRegressionModel.getParamMatrix(); ListMultimap targetCategoryCells = groupByTargetCategory(paramMatrix.getPCells()); return asMap(targetCategoryCells); } @SuppressWarnings ( value = {"rawtypes", "unchecked"} ) static private Map> asMap(ListMultimap multimap){ return (Map)multimap.asMap(); } static private ListMultimap groupByParameterName(List cells){ Function function = new Function(){ @Override public String apply(C cell){ return cell.getParameterName(); } }; return groupCells(cells, function); } static private ListMultimap groupByTargetCategory(List cells){ Function function = new Function(){ @Override public String apply(C cell){ return cell.getTargetCategory(); } }; return groupCells(cells, function); } static private ListMultimap groupCells(List cells, Function function){ ListMultimap result = ArrayListMultimap.create(); for(C cell : cells){ result.put(function.apply(cell), cell); } return result; } static private class Row { private List factorHandlers = new ArrayList<>(); private List covariateHandlers = new ArrayList<>(); public Double evaluate(EvaluationContext context){ List factorHandlers = getFactorHandlers(); List covariateHandlers = getCovariateHandlers(); // The row is empty if(factorHandlers.isEmpty() && covariateHandlers.isEmpty()){ return Values.DOUBLE_ONE; } Double factorProduct = computeProduct(factorHandlers, context); Double covariateProduct = computeProduct(covariateHandlers, context); 
if(covariateHandlers.isEmpty()){ return factorProduct; } else if(factorHandlers.isEmpty()){ return covariateProduct; } else { if(factorProduct != null && covariateProduct != null){ return (factorProduct * covariateProduct); } return null; } } public void addFactor(PPCell ppCell, Predictor predictor){ List factorHandlers = getFactorHandlers(); Matrix matrix = predictor.getMatrix(); if(matrix != null){ Categories categories = predictor.getCategories(); if(categories == null){ throw new UnsupportedFeatureException(predictor); } Function function = new Function(){ @Override public String apply(Category category){ return category.getValue(); } }; List values = Lists.transform(categories.getCategories(), function); factorHandlers.add(new ContrastMatrixHandler(ppCell, matrix, values)); } else { factorHandlers.add(new FactorHandler(ppCell)); } } private void addCovariate(PPCell ppCell){ List covariateHandlers = getCovariateHandlers(); covariateHandlers.add(new CovariateHandler(ppCell)); } public List getFactorHandlers(){ return this.factorHandlers; } public List getCovariateHandlers(){ return this.covariateHandlers; } static private Double computeProduct(List predictorHandlers, EvaluationContext context){ if(predictorHandlers.isEmpty()){ return null; } double product = 1d; for(int i = 0, max = predictorHandlers.size(); i < max; i++){ PredictorHandler predictorHandler = predictorHandlers.get(i); FieldValue value = context.evaluate(predictorHandler.getPredictorName()); if(value == null){ return null; } // End if if(max == 1){ return predictorHandler.evaluate(value); } product = product * predictorHandler.evaluate(value); } return product; } abstract private class PredictorHandler { private PPCell ppCell = null; private PredictorHandler(PPCell ppCell){ setPPCell(ppCell); } abstract public Double evaluate(FieldValue value); public FieldName getPredictorName(){ PPCell ppCell = getPPCell(); return ppCell.getPredictorName(); } public PPCell getPPCell(){ return this.ppCell; } 
private void setPPCell(PPCell ppCell){ this.ppCell = ppCell; } } private class FactorHandler extends PredictorHandler { private FactorHandler(PPCell ppCell){ super(ppCell); } @Override public Double evaluate(FieldValue value){ PPCell ppCell = getPPCell(); boolean equals = value.equals(ppCell); return (equals ? Values.DOUBLE_ONE : Values.DOUBLE_ZERO); } public String getCategory(){ PPCell ppCell = getPPCell(); return ppCell.getValue(); } } private class ContrastMatrixHandler extends FactorHandler { private Matrix matrix = null; private List categories = null; private List parsedValueList = null; private ContrastMatrixHandler(PPCell ppCell, Matrix matrix, List categories){ super(ppCell); setMatrix(matrix); setCategories(categories); } @Override public Double evaluate(FieldValue value){ Matrix matrix = getMatrix(); int row = getIndex(value); int column = getIndex(getCategory()); if(row < 0 || column < 0){ throw new EvaluationException(); } Number result = MatrixUtil.getElementAt(matrix, row + 1, column + 1); if(result == null){ throw new EvaluationException(); } // End if if(result instanceof Double){ return (Double)result; } return result.doubleValue(); } public int getIndex(FieldValue value){ if(this.parsedValueList == null){ this.parsedValueList = ImmutableList.copyOf(parseCategories(value.getDataType(), value.getOpType())); } return this.parsedValueList.indexOf(value); } public int getIndex(String category){ List categories = getCategories(); return categories.indexOf(category); } private List parseCategories(final DataType dataType, final OpType opType){ List categories = getCategories(); Function function = new Function(){ @Override public FieldValue apply(String value){ return FieldValueUtil.create(dataType, opType, value); } }; return Lists.transform(categories, function); } public Matrix getMatrix(){ return this.matrix; } private void setMatrix(Matrix matrix){ this.matrix = matrix; } public List getCategories(){ return this.categories; } private void 
setCategories(List categories){ this.categories = categories; } } private class CovariateHandler extends PredictorHandler { private CovariateHandler(PPCell ppCell){ super(ppCell); } @Override public Double evaluate(FieldValue value){ double multiplicity = getMultiplicity(); if(multiplicity == 1d){ return value.asDouble(); } return Math.pow((value.asNumber()).doubleValue(), multiplicity); } public double getMultiplicity(){ PPCell ppCell = getPPCell(); String value = ppCell.getValue(); if(("1").equals(value) || ("1.0").equals(value)){ return 1d; } return Double.parseDouble(value); } } } private static final LoadingCache> parameterCache = CacheUtil.buildLoadingCache(new CacheLoader>(){ @Override public BiMap load(GeneralRegressionModel generalRegressionModel){ return ImmutableBiMap.copyOf(parseParameterRegistry(generalRegressionModel.getParameterList())); } }); private static final LoadingCache> factorCache = CacheUtil.buildLoadingCache(new CacheLoader>(){ @Override public BiMap load(GeneralRegressionModel generalRegressionModel){ return ImmutableBiMap.copyOf(parsePredictorRegistry(generalRegressionModel.getFactorList())); } }); private static final LoadingCache> covariateCache = CacheUtil.buildLoadingCache(new CacheLoader>(){ @Override public BiMap load(GeneralRegressionModel generalRegressionModel){ return ImmutableBiMap.copyOf(parsePredictorRegistry(generalRegressionModel.getCovariateList())); } }); private static final LoadingCache>> ppMatrixCache = CacheUtil.buildLoadingCache(new CacheLoader>>(){ @Override public Map> load(GeneralRegressionModel generalRegressionModel){ // Cannot use Guava's ImmutableMap, because it is null-hostile return Collections.unmodifiableMap(parsePPMatrix(generalRegressionModel)); } }); private static final LoadingCache>> paramMatrixCache = CacheUtil.buildLoadingCache(new CacheLoader>>(){ @Override public Map> load(GeneralRegressionModel generalRegressionModel){ // Cannot use Guava's ImmutableMap, because it is null-hostile return 
Collections.unmodifiableMap(parseParamMatrix(generalRegressionModel)); } }); }




© 2015 - 2025 Weber Informatics LLC | Privacy Policy