/*
 * Copyright (c) 2016 Villu Ruusmann
 *
 * This file is part of JPMML-Evaluator
 *
 * JPMML-Evaluator is free software: you can redistribute it and/or modify
 * it under the terms of the GNU Affero General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * JPMML-Evaluator is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU Affero General Public License for more details.
 *
 * You should have received a copy of the GNU Affero General Public License
 * along with JPMML-Evaluator.  If not, see <http://www.gnu.org/licenses/>.
 */
package org.jpmml.evaluator;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.concurrent.Callable;

import com.google.common.cache.Cache;
import com.google.common.cache.CacheLoader;
import com.google.common.cache.LoadingCache;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Iterables;
import com.google.common.collect.Sets;
import com.google.common.collect.Sets.SetView;
import com.google.common.collect.Table;
import org.dmg.pmml.DataDictionary;
import org.dmg.pmml.DataField;
import org.dmg.pmml.DataType;
import org.dmg.pmml.DefineFunction;
import org.dmg.pmml.DerivedField;
import org.dmg.pmml.Field;
import org.dmg.pmml.FieldName;
import org.dmg.pmml.InlineTable;
import org.dmg.pmml.LocalTransformations;
import org.dmg.pmml.MiningField;
import org.dmg.pmml.MiningFunction;
import org.dmg.pmml.MiningSchema;
import org.dmg.pmml.Model;
import org.dmg.pmml.ModelVerification;
import org.dmg.pmml.OpType;
import org.dmg.pmml.Output;
import org.dmg.pmml.PMML;
import org.dmg.pmml.ResultFeature;
import org.dmg.pmml.Target;
import org.dmg.pmml.Targets;
import org.dmg.pmml.TransformationDictionary;
import org.dmg.pmml.TypeDefinitionField;
import org.dmg.pmml.VerificationField;
import org.dmg.pmml.VerificationFields;

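/**
 * <p>Common base class for PMML model evaluators.</p>
 *
 * <p>An illustrative usage sketch. The {@code userRecord} map (raw input values keyed by
 * field name) and the way the concrete {@code evaluator} instance is obtained are
 * application-specific assumptions, not part of this class:</p>
 *
 * <pre>{@code
 * ModelEvaluator<?> evaluator = ...; // e.g. obtained from a factory for the PMML document
 * evaluator.verify();
 *
 * Map<FieldName, FieldValue> arguments = new LinkedHashMap<>();
 * for(InputField inputField : evaluator.getInputFields()){
 *     FieldName name = inputField.getName();
 *     arguments.put(name, EvaluatorUtil.prepare(inputField, userRecord.get(name.getValue())));
 * }
 *
 * Map<FieldName, ?> results = evaluator.evaluate(arguments);
 * Object targetValue = EvaluatorUtil.decode(results.get(evaluator.getTargetFieldName()));
 * }</pre>
 */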
abstract
public class ModelEvaluator<M extends Model> implements Evaluator, Serializable {

	private PMML pmml = null;

	private M model = null;

	private Map<FieldName, DataField> dataFields = Collections.emptyMap();

	private Map<FieldName, DerivedField> derivedFields = Collections.emptyMap();

	private Map<String, DefineFunction> defineFunctions = Collections.emptyMap();

	private Map<FieldName, MiningField> miningFields = Collections.emptyMap();

	transient
	private List<InputField> inputFields = null;

	transient
	private List<InputField> activeInputFields = null;

	private Map<FieldName, DerivedField> localDerivedFields = Collections.emptyMap();

	private Map<FieldName, Target> targets = Collections.emptyMap();

	transient
	private List<TargetField> targetResultFields = null;

	private Map<FieldName, org.dmg.pmml.OutputField> outputFields = Collections.emptyMap();

	transient
	private List<OutputField> outputResultFields = null;


	protected ModelEvaluator(PMML pmml, M model){
		setPMML(Objects.requireNonNull(pmml));
		setModel(Objects.requireNonNull(model));

		DataDictionary dataDictionary = pmml.getDataDictionary();
		if(dataDictionary == null){
			throw new InvalidFeatureException(pmml);
		} // End if

		if(dataDictionary.hasDataFields()){
			this.dataFields = CacheUtil.getValue(dataDictionary, ModelEvaluator.dataFieldCache);
		}

		TransformationDictionary transformationDictionary = pmml.getTransformationDictionary();
		if(transformationDictionary != null && transformationDictionary.hasDerivedFields()){
			this.derivedFields = CacheUtil.getValue(transformationDictionary, ModelEvaluator.derivedFieldCache);
		} // End if

		if(transformationDictionary != null && transformationDictionary.hasDefineFunctions()){
			this.defineFunctions = CacheUtil.getValue(transformationDictionary, ModelEvaluator.defineFunctionCache);
		}

		MiningSchema miningSchema = model.getMiningSchema();
		if(miningSchema == null){
			throw new InvalidFeatureException(model);
		} // End if

		if(miningSchema.hasMiningFields()){
			this.miningFields = CacheUtil.getValue(miningSchema, ModelEvaluator.miningFieldCache);
		}

		LocalTransformations localTransformations = model.getLocalTransformations();
		if(localTransformations != null && localTransformations.hasDerivedFields()){
			this.localDerivedFields = CacheUtil.getValue(localTransformations, ModelEvaluator.localDerivedFieldCache);
		}

		Targets targets = model.getTargets();
		if(targets != null && targets.hasTargets()){
			this.targets = CacheUtil.getValue(targets, ModelEvaluator.targetCache);
		}

		Output output = model.getOutput();
		if(output != null && output.hasOutputFields()){
			this.outputFields = CacheUtil.getValue(output, ModelEvaluator.outputFieldCache);
		}
	}

	abstract
	public Map<FieldName, ?> evaluate(ModelEvaluationContext context);

	@Override
	public MiningFunction getMiningFunction(){
		M model = getModel();

		return model.getMiningFunction();
	}

	public DataField getDataField(FieldName name){

		if(Objects.equals(Evaluator.DEFAULT_TARGET_NAME, name)){
			return getDataField();
		}

		return this.dataFields.get(name);
	}

	/**
	 * @return A synthetic {@link DataField} element describing the default target field.
	 */
	protected DataField getDataField(){
		MiningFunction miningFunction = getMiningFunction();

		switch(miningFunction){
			case REGRESSION:
				return ModelEvaluator.DEFAULT_REGRESSION_TARGET;
			case CLASSIFICATION:
				return ModelEvaluator.DEFAULT_CLASSIFICATION_TARGET;
			case CLUSTERING:
				return ModelEvaluator.DEFAULT_CLUSTERING_TARGET;
			default:
				break;
		}

		return null;
	}

	public DerivedField getDerivedField(FieldName name){
		return this.derivedFields.get(name);
	}

	public DefineFunction getDefineFunction(String name){
		return this.defineFunctions.get(name);
	}

	public MiningField getMiningField(FieldName name){

		if(Objects.equals(Evaluator.DEFAULT_TARGET_NAME, name)){
			return null;
		}

		return this.miningFields.get(name);
	}

	@Override
	public List<InputField> getInputFields(){

		if(this.inputFields == null){
			this.inputFields = createInputFields();
		}

		return this.inputFields;
	}

	@Override
	public List<InputField> getActiveFields(){

		if(this.activeInputFields == null){
			this.activeInputFields = createInputFields(MiningField.UsageType.ACTIVE);
		}

		return this.activeInputFields;
	}

	public DerivedField getLocalDerivedField(FieldName name){
		return this.localDerivedFields.get(name);
	}

	public Target getTarget(FieldName name){
		return this.targets.get(name);
	}

	@Override
	public List<TargetField> getTargetFields(){

		if(this.targetResultFields == null){
			this.targetResultFields = createTargetFields();
		}

		return this.targetResultFields;
	}

	public TargetField getTargetField(){
		List<TargetField> targetFields = getTargetFields();

		if(targetFields.size() != 1){
			throw new EvaluationException();
		}

		TargetField targetField = targetFields.get(0);

		return targetField;
	}

	public FieldName getTargetFieldName(){
		TargetField targetField = getTargetField();

		return targetField.getName();
	}

	public org.dmg.pmml.OutputField getOutputField(FieldName name){
		return this.outputFields.get(name);
	}

	@Override
	public List<OutputField> getOutputFields(){

		if(this.outputResultFields == null){
			this.outputResultFields = createOutputFields();
		}

		return this.outputResultFields;
	}

	@Override
	public void verify(){
		M model = getModel();

		ModelVerification modelVerification = model.getModelVerification();
		if(modelVerification == null){
			return;
		}

		VerificationBatch batch = CacheUtil.getValue(modelVerification, ModelEvaluator.batchCache);

		List<Map<FieldName, Object>> records = batch.getRecords();

		List<InputField> inputFields = getInputFields();

		if(this instanceof HasGroupFields){
			HasGroupFields hasGroupFields = (HasGroupFields)this;

			records = EvaluatorUtil.groupRows(hasGroupFields, records);
		}

		List<TargetField> targetFields = getTargetFields();
		List<OutputField> outputFields = getOutputFields();

		SetView<FieldName> intersection = Sets.intersection(batch.keySet(), new LinkedHashSet<>(EvaluatorUtil.getNames(outputFields)));

		for(Map<FieldName, Object> record : records){
			Map<FieldName, FieldValue> arguments = new LinkedHashMap<>();

			for(InputField inputField : inputFields){
				FieldName name = inputField.getName();

				FieldValue value = EvaluatorUtil.prepare(inputField, record.get(name));

				arguments.put(name, value);
			}

			Map<FieldName, ?> result = evaluate(arguments);

			// "If there exist VerificationField elements that refer to OutputField elements,
			// then any VerificationField element that refers to a MiningField element whose "usageType=target" should be ignored,
			// because they are considered to represent a dependent variable from the training data set, not an expected output"
			if(intersection.size() > 0){

				for(OutputField outputField : outputFields){
					FieldName name = outputField.getName();

					VerificationField verificationField = batch.get(name);
					if(verificationField == null){
						continue;
					}

					verify(record.get(name), result.get(name), verificationField.getPrecision(), verificationField.getZeroThreshold());
				}
			} else

			// "If there are no such VerificationField elements,
			// then any VerificationField element that refers to a MiningField element whose "usageType=target" should be considered to represent an expected output"
			{
				for(TargetField targetField : targetFields){
					FieldName name = targetField.getName();

					VerificationField verificationField = batch.get(name);
					if(verificationField == null){
						continue;
					}

					verify(record.get(name), EvaluatorUtil.decode(result.get(name)), verificationField.getPrecision(), verificationField.getZeroThreshold());
				}
			}
		}
	}

	/**
	 * @param expected A string or a collection of strings representing the expected value
	 * @param actual The actual value
	 */
	private void verify(Object expected, Object actual, double precision, double zeroThreshold){

		if(expected == null){
			return;
		} // End if

		if(!(actual instanceof Collection)){
			DataType dataType = TypeUtil.getDataType(actual);

			expected = TypeUtil.parseOrCast(dataType, expected);
		}

		boolean acceptable = VerificationUtil.acceptable(expected, actual, precision, zeroThreshold);
		if(!acceptable){
			throw new EvaluationException();
		}
	}
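	// Illustrative sketch of the acceptance criterion that the PMML specification defines for
	// model verification; VerificationUtil.acceptable(expected, actual, precision, zeroThreshold)
	// is assumed to apply this rule to numeric values (hypothetical inline form, for reference):
	//
	//   if(Math.abs(expected) <= zeroThreshold){
	//       // The expected value counts as zero; the actual value must also count as zero
	//       return Math.abs(actual) <= zeroThreshold;
	//   }
	//   // Otherwise the relative error must stay within the declared precision
	//   return Math.abs(expected - actual) <= precision * Math.abs(expected);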

	@Override
	public Map<FieldName, ?> evaluate(Map<FieldName, ?> arguments){
		ModelEvaluationContext context = new ModelEvaluationContext(this);
		context.setArguments(arguments);

		return evaluate(context);
	}

	protected TypeDefinitionField resolveField(FieldName name){
		TypeDefinitionField result = getDataField(name);

		if(result == null){
			result = resolveDerivedField(name);
		}

		return result;
	}

	protected DerivedField resolveDerivedField(FieldName name){
		DerivedField result = getDerivedField(name);

		if(result == null){
			result = getLocalDerivedField(name);
		}

		return result;
	}

	protected List<InputField> createInputFields(){
		List<InputField> inputFields = getActiveFields();

		List<OutputField> outputFields = getOutputFields();
		if(outputFields.size() > 0){
			List<InputField> targetReferenceFields = null;

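			// Residual-type output fields compare the predicted value against the observed target
			// value, so the referenced target must be supplied as an additional input field.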
			for(OutputField outputField : outputFields){
				org.dmg.pmml.OutputField pmmlOutputField = outputField.getOutputField();

				if(!(pmmlOutputField.getResultFeature()).equals(ResultFeature.RESIDUAL)){
					continue;
				}

				int depth = outputField.getDepth();
				if(depth > 0){
					throw new UnsupportedFeatureException(pmmlOutputField);
				}

				FieldName targetFieldName = pmmlOutputField.getTargetField();
				if(targetFieldName == null){
					targetFieldName = getTargetFieldName();
				}

				DataField dataField = getDataField(targetFieldName);
				if(dataField == null){
					throw new MissingFieldException(targetFieldName, pmmlOutputField);
				}

				MiningField miningField = getMiningField(targetFieldName);
				if(miningField == null){
					throw new EvaluationException();
				}

				Target target = getTarget(targetFieldName);

				TargetReferenceField targetReferenceField = new TargetReferenceField(dataField, miningField, target);

				if(targetReferenceFields == null){
					targetReferenceFields = new ArrayList<>();
				}

				targetReferenceFields.add(targetReferenceField);
			}

			if(targetReferenceFields != null && targetReferenceFields.size() > 0){
				inputFields = ImmutableList.copyOf(Iterables.concat(inputFields, targetReferenceFields));
			}
		}

		return inputFields;
	}

	protected List<InputField> createInputFields(MiningField.UsageType usageType){
		M model = getModel();

		MiningSchema miningSchema = model.getMiningSchema();

		List<InputField> inputFields = new ArrayList<>();

		if(miningSchema.hasMiningFields()){
			List<MiningField> miningFields = miningSchema.getMiningFields();

			for(MiningField miningField : miningFields){
				FieldName name = miningField.getName();

				if(!(miningField.getUsageType()).equals(usageType)){
					continue;
				}

				Field field = getDataField(name);
				if(field == null){
					field = new VariableField(name);
				}

				InputField inputField = new InputField(field, miningField);

				inputFields.add(inputField);
			}
		}

		return ImmutableList.copyOf(inputFields);
	}

	protected List<TargetField> createTargetFields(){
		M model = getModel();

		MiningSchema miningSchema = model.getMiningSchema();

		List<TargetField> targetFields = new ArrayList<>();

		if(miningSchema.hasMiningFields()){
			List<MiningField> miningFields = miningSchema.getMiningFields();

			for(MiningField miningField : miningFields){
				FieldName name = miningField.getName();

				MiningField.UsageType usageType = miningField.getUsageType();
				switch(usageType){
					case TARGET:
					case PREDICTED:
						break;
					default:
						continue;
				}

				DataField dataField = getDataField(name);
				if(dataField == null){
					throw new MissingFieldException(name, miningField);
				}

				Target target = getTarget(name);

				TargetField targetField = new TargetField(dataField, miningField, target);

				targetFields.add(targetField);
			}
		}

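		// If the MiningSchema does not declare any target or predicted fields (as is typical
		// for clustering models), fall back to a single synthetic target field backed by the
		// default DataField for the mining function, when such a default exists.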
		synthesis:
		if(targetFields.isEmpty()){
			DataField dataField = getDataField();

			if(dataField == null){
				break synthesis;
			}

			Target target = getTarget(dataField.getName());

			TargetField targetField = new TargetField(dataField, null, target);

			targetFields.add(targetField);
		}

		return ImmutableList.copyOf(targetFields);
	}

	protected List<OutputField> createOutputFields(){
		M model = getModel();

		Output output = model.getOutput();

		List<OutputField> resultFields = new ArrayList<>();

		if(output != null && output.hasOutputFields()){
			List<org.dmg.pmml.OutputField> outputFields = output.getOutputFields();

			for(org.dmg.pmml.OutputField outputField : outputFields){
				OutputField resultField = new OutputField(outputField);

				resultFields.add(resultField);
			}
		}

		return ImmutableList.copyOf(resultFields);
	}

	public <V> V getValue(LoadingCache<M, V> cache){
		M model = getModel();

		return CacheUtil.getValue(model, cache);
	}

	public <V> V getValue(Cache<M, V> cache, Callable<? extends V> loader){
		M model = getModel();

		return CacheUtil.getValue(model, cache, loader);
	}

	public PMML getPMML(){
		return this.pmml;
	}

	private void setPMML(PMML pmml){
		this.pmml = pmml;
	}

	public M getModel(){
		return this.model;
	}

	private void setModel(M model){
		this.model = model;
	}

	static
	protected <M extends Model> M selectModel(PMML pmml, Class<M> clazz){

		if(!pmml.hasModels()){
			throw new InvalidFeatureException(pmml);
		}

		List<Model> models = pmml.getModels();

		Iterable<M> filteredModels = Iterables.filter(models, clazz);

		M model = Iterables.getFirst(filteredModels, null);
		if(model == null){
			throw new InvalidFeatureException(pmml);
		}

		return model;
	}
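	// Illustrative call (hypothetical example): a concrete evaluator subclass would typically
	// locate its model element like this, e.g.
	//
	//   RegressionModel regressionModel = selectModel(pmml, RegressionModel.class);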

	static
	private VerificationBatch parseModelVerification(ModelVerification modelVerification){
		VerificationBatch result = new VerificationBatch();

		VerificationFields verificationFields = modelVerification.getVerificationFields();
		if(verificationFields == null){
			throw new InvalidFeatureException(modelVerification);
		}

		for(VerificationField verificationField : verificationFields){
			result.put(verificationField.getField(), verificationField);
		}

		InlineTable inlineTable = modelVerification.getInlineTable();
		if(inlineTable == null){
			throw new InvalidFeatureException(modelVerification);
		}

		Table<Integer, String, String> table = InlineTableUtil.getContent(inlineTable);

		List<Map<FieldName, Object>> records = new ArrayList<>();

		Set<Integer> rowKeys = table.rowKeySet();
		for(Integer rowKey : rowKeys){
			Map<String, String> row = table.row(rowKey);

			Map<FieldName, Object> record = new LinkedHashMap<>();

			for(VerificationField verificationField : verificationFields){
				FieldName name = verificationField.getField();
				String column = verificationField.getColumn();

				if(column == null){
					column = name.getValue();
				} // End if

				if(!row.containsKey(column)){
					continue;
				}

				record.put(name, row.get(column));
			}

			records.add(record);
		}

		Integer recordCount = modelVerification.getRecordCount();
		if(recordCount != null && recordCount.intValue() != records.size()){
			throw new InvalidFeatureException(inlineTable);
		}

		result.setRecords(records);

		return result;
	}

	private static final DataField DEFAULT_REGRESSION_TARGET = new DataField(Evaluator.DEFAULT_TARGET_NAME, OpType.CONTINUOUS, DataType.DOUBLE);
	private static final DataField DEFAULT_CLASSIFICATION_TARGET = new DataField(Evaluator.DEFAULT_TARGET_NAME, OpType.CATEGORICAL, DataType.STRING);
	private static final DataField DEFAULT_CLUSTERING_TARGET = new DataField(Evaluator.DEFAULT_TARGET_NAME, OpType.CATEGORICAL, DataType.STRING);

	private static final LoadingCache<DataDictionary, Map<FieldName, DataField>> dataFieldCache = CacheUtil.buildLoadingCache(new CacheLoader<DataDictionary, Map<FieldName, DataField>>(){

		@Override
		public Map<FieldName, DataField> load(DataDictionary dataDictionary){
			return IndexableUtil.buildMap(dataDictionary.getDataFields());
		}
	});

	private static final LoadingCache<TransformationDictionary, Map<FieldName, DerivedField>> derivedFieldCache = CacheUtil.buildLoadingCache(new CacheLoader<TransformationDictionary, Map<FieldName, DerivedField>>(){

		@Override
		public Map<FieldName, DerivedField> load(TransformationDictionary transformationDictionary){
			return IndexableUtil.buildMap(transformationDictionary.getDerivedFields());
		}
	});

	private static final LoadingCache<TransformationDictionary, Map<String, DefineFunction>> defineFunctionCache = CacheUtil.buildLoadingCache(new CacheLoader<TransformationDictionary, Map<String, DefineFunction>>(){

		@Override
		public Map<String, DefineFunction> load(TransformationDictionary transformationDictionary){
			return IndexableUtil.buildMap(transformationDictionary.getDefineFunctions());
		}
	});

	private static final LoadingCache<MiningSchema, Map<FieldName, MiningField>> miningFieldCache = CacheUtil.buildLoadingCache(new CacheLoader<MiningSchema, Map<FieldName, MiningField>>(){

		@Override
		public Map<FieldName, MiningField> load(MiningSchema miningSchema){
			return IndexableUtil.buildMap(miningSchema.getMiningFields());
		}
	});

	private static final LoadingCache<LocalTransformations, Map<FieldName, DerivedField>> localDerivedFieldCache = CacheUtil.buildLoadingCache(new CacheLoader<LocalTransformations, Map<FieldName, DerivedField>>(){

		@Override
		public Map<FieldName, DerivedField> load(LocalTransformations localTransformations){
			return IndexableUtil.buildMap(localTransformations.getDerivedFields());
		}
	});

	private static final LoadingCache<Targets, Map<FieldName, Target>> targetCache = CacheUtil.buildLoadingCache(new CacheLoader<Targets, Map<FieldName, Target>>(){

		@Override
		public Map<FieldName, Target> load(Targets targets){
			return IndexableUtil.buildMap(targets.getTargets(), true);
		}
	});

	private static final LoadingCache<Output, Map<FieldName, org.dmg.pmml.OutputField>> outputFieldCache = CacheUtil.buildLoadingCache(new CacheLoader<Output, Map<FieldName, org.dmg.pmml.OutputField>>(){

		@Override
		public Map<FieldName, org.dmg.pmml.OutputField> load(Output output){
			return IndexableUtil.buildMap(output.getOutputFields());
		}
	});

	static
	private class VerificationBatch extends LinkedHashMap<FieldName, VerificationField> {

		private List<Map<FieldName, Object>> records = null;


		public List<Map<FieldName, Object>> getRecords(){
			return this.records;
		}

		private void setRecords(List<Map<FieldName, Object>> records){
			this.records = records;
		}
	}

	private static final LoadingCache<ModelVerification, VerificationBatch> batchCache = CacheUtil.buildLoadingCache(new CacheLoader<ModelVerification, VerificationBatch>(){

		@Override
		public VerificationBatch load(ModelVerification modelVerification){
			return parseModelVerification(modelVerification);
		}
	});
}