All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.jpmml.evaluator.ClusteringModelEvaluator Maven / Gradle / Ivy

/*
 * Copyright (c) 2013 Villu Ruusmann
 *
 * This file is part of JPMML-Evaluator
 *
 * JPMML-Evaluator is free software: you can redistribute it and/or modify
 * it under the terms of the GNU Affero General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * JPMML-Evaluator is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU Affero General Public License for more details.
 *
 * You should have received a copy of the GNU Affero General Public License
 * along with JPMML-Evaluator.  If not, see <http://www.gnu.org/licenses/>.
 */
package org.jpmml.evaluator;

import java.util.ArrayList;
import java.util.BitSet;
import java.util.Collections;
import java.util.List;
import java.util.Map;

import com.google.common.cache.CacheLoader;
import com.google.common.cache.LoadingCache;
import com.google.common.collect.BiMap;
import com.google.common.collect.ImmutableList;
import org.dmg.pmml.Array;
import org.dmg.pmml.CenterFields;
import org.dmg.pmml.Cluster;
import org.dmg.pmml.ClusteringField;
import org.dmg.pmml.ClusteringModel;
import org.dmg.pmml.ComparisonMeasure;
import org.dmg.pmml.DataType;
import org.dmg.pmml.FieldName;
import org.dmg.pmml.Measure;
import org.dmg.pmml.MiningFunctionType;
import org.dmg.pmml.MissingValueWeights;
import org.dmg.pmml.PMML;
import org.dmg.pmml.Target;

public class ClusteringModelEvaluator extends ModelEvaluator implements HasEntityRegistry {

	public ClusteringModelEvaluator(PMML pmml){
		super(pmml, ClusteringModel.class);
	}

	public ClusteringModelEvaluator(PMML pmml, ClusteringModel clusteringModel){
		super(pmml, clusteringModel);
	}

	@Override
	public String getSummary(){
		return "Clustering model";
	}

	/**
	 * @return null Always.
	 */
	@Override
	public Target getTarget(FieldName name){
		return null;
	}

	@Override
	public BiMap getEntityRegistry(){
		return getValue(ClusteringModelEvaluator.entityCache);
	}

	@Override
	public Map evaluate(ModelEvaluationContext context){
		ClusteringModel clusteringModel = getModel();
		if(!clusteringModel.isScorable()){
			throw new InvalidResultException(clusteringModel);
		}

		Map predictions;

		MiningFunctionType miningFunction = clusteringModel.getFunctionName();
		switch(miningFunction){
			case CLUSTERING:
				predictions = evaluateClustering(context);
				break;
			default:
				throw new UnsupportedFeatureException(clusteringModel, miningFunction);
		}

		return OutputUtil.evaluate(predictions, context);
	}

	private Map evaluateClustering(EvaluationContext context){
		ClusteringModel clusteringModel = getModel();

		ClusteringModel.ModelClass modelClass = clusteringModel.getModelClass();
		switch(modelClass){
			case CENTER_BASED:
				break;
			default:
				throw new UnsupportedFeatureException(clusteringModel, modelClass);
		}

		CenterFields centerFields = clusteringModel.getCenterFields();
		if(centerFields != null){
			throw new UnsupportedFeatureException(centerFields);
		}

		List values = new ArrayList<>();

		List clusteringFields = getCenterClusteringFields();
		for(ClusteringField clusteringField : clusteringFields){
			FieldValue value = context.evaluate(clusteringField.getField());

			values.add(value);
		}

		ClusterAffinityDistribution result;

		ComparisonMeasure comparisonMeasure = clusteringModel.getComparisonMeasure();

		Measure measure = comparisonMeasure.getMeasure();

		if(MeasureUtil.isSimilarity(measure)){
			result = evaluateSimilarity(comparisonMeasure, clusteringFields, values);
		} else

		if(MeasureUtil.isDistance(measure)){
			result = evaluateDistance(comparisonMeasure, clusteringFields, values);
		} else

		{
			throw new UnsupportedFeatureException(measure);
		}

		// "For clustering models, the identifier of the winning cluster is returned as the predictedValue"
		result.computeResult(DataType.STRING);

		return Collections.singletonMap(getTargetField(), result);
	}

	private ClusterAffinityDistribution evaluateSimilarity(ComparisonMeasure comparisonMeasure, List clusteringFields, List values){
		ClusteringModel clusteringModel = getModel();

		BiMap entityRegistry = getEntityRegistry();

		ClusterAffinityDistribution result = new ClusterAffinityDistribution(Classification.Type.SIMILARITY, entityRegistry);

		BitSet flags = MeasureUtil.toBitSet(values);

		List clusters = clusteringModel.getClusters();
		for(Cluster cluster : clusters){
			BitSet clusterFlags = CacheUtil.getValue(cluster, ClusteringModelEvaluator.clusterFlagCache);

			if(flags.size() != clusterFlags.size()){
				throw new InvalidFeatureException(cluster);
			}

			String id = EntityUtil.getId(cluster, entityRegistry);

			Double similarity = MeasureUtil.evaluateSimilarity(comparisonMeasure, clusteringFields, flags, clusterFlags);

			result.put(cluster, id, similarity);
		}

		return result;
	}

	private ClusterAffinityDistribution evaluateDistance(ComparisonMeasure comparisonMeasure, List clusteringFields, List values){
		ClusteringModel clusteringModel = getModel();

		BiMap entityRegistry = getEntityRegistry();

		ClusterAffinityDistribution result = new ClusterAffinityDistribution(Classification.Type.DISTANCE, entityRegistry);

		Double adjustment;

		MissingValueWeights missingValueWeights = clusteringModel.getMissingValueWeights();
		if(missingValueWeights != null){
			Array array = missingValueWeights.getArray();

			List adjustmentValues = ArrayUtil.asNumberList(array);
			if(values.size() != adjustmentValues.size()){
				throw new InvalidFeatureException(missingValueWeights);
			}

			adjustment = MeasureUtil.calculateAdjustment(values, adjustmentValues);
		} else

		{
			adjustment = MeasureUtil.calculateAdjustment(values);
		}

		List clusters = clusteringModel.getClusters();
		for(Cluster cluster : clusters){
			List clusterValues = CacheUtil.getValue(cluster, ClusteringModelEvaluator.clusterValueCache);

			if(values.size() != clusterValues.size()){
				throw new InvalidFeatureException(cluster);
			}

			String id = EntityUtil.getId(cluster, entityRegistry);

			Double distance = MeasureUtil.evaluateDistance(comparisonMeasure, clusteringFields, values, clusterValues, adjustment);

			result.put(cluster, id, distance);
		}

		return result;
	}

	private List getCenterClusteringFields(){
		ClusteringModel clusteringModel = getModel();

		List result = new ArrayList<>();

		List clusteringFields = clusteringModel.getClusteringFields();
		for(ClusteringField clusteringField : clusteringFields){
			ClusteringField.CenterField centerField = clusteringField.getCenterField();

			switch(centerField){
				case TRUE:
					result.add(clusteringField);
					break;
				case FALSE:
					break;
				default:
					throw new UnsupportedFeatureException(clusteringField, centerField);
			}
		}

		return result;
	}

	private static final LoadingCache> clusterValueCache = CacheUtil.buildLoadingCache(new CacheLoader>(){

		@Override
		public List load(Cluster cluster){
			Array array = cluster.getArray();

			List values = ArrayUtil.asNumberList(array);

			return ImmutableList.copyOf(FieldValueUtil.createAll(values));
		}
	});

	private static final LoadingCache clusterFlagCache = CacheUtil.buildLoadingCache(new CacheLoader(){

		@Override
		public BitSet load(Cluster cluster){
			List values = CacheUtil.getValue(cluster, ClusteringModelEvaluator.clusterValueCache);

			return MeasureUtil.toBitSet(values);
		}
	});

	private static final LoadingCache> entityCache = CacheUtil.buildLoadingCache(new CacheLoader>(){

		@Override
		public BiMap load(ClusteringModel clusteringModel){
			return EntityUtil.buildBiMap(clusteringModel.getClusters());
		}
	});
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy