All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.jpmml.manager.ModelManager Maven / Gradle / Ivy

/*
 * Copyright (c) 2009 University of Tartu
 */
package org.jpmml.manager;

import java.util.*;

import org.dmg.pmml.*;

import com.google.common.collect.*;

abstract
public class ModelManager<M extends Model> extends PMMLManager implements Consumer {

	public ModelManager(){
	}

	public ModelManager(PMML pmml){
		super(pmml);
	}

	/**
	 * Returns the model element that this manager operates on.
	 * All schema, output, transformation and target accessors in this class read from (or write to) this object.
	 */
	abstract
	public M getModel();

	/**
	 * Convenience method for adding a field declaration to {@link DataDictionary} and {@link MiningSchema}.
	 *
	 * @see #addDataField(FieldName, String, OpType, DataType)
	 * @see #addMiningField(FieldName, FieldUsageType)
	 */
	public void addField(FieldName name, String displayName, OpType opType, DataType dataType, FieldUsageType fieldUsageType){
		addDataField(name, displayName, opType, dataType);
		addMiningField(name, fieldUsageType);
	}

	/**
	 * @return Names of mining fields whose usage type is {@link FieldUsageType#ACTIVE}, in mining schema order.
	 */
	@Override
	public List<FieldName> getActiveFields(){
		return getMiningFields(FieldUsageType.ACTIVE);
	}

	/**
	 * @return Names of mining fields whose usage type is {@link FieldUsageType#GROUP}, in mining schema order.
	 */
	@Override
	public List<FieldName> getGroupFields(){
		return getMiningFields(FieldUsageType.GROUP);
	}

	/**
	 * Returns the single predicted field of this model.
	 *
	 * @return The predicted field name, or <code>null</code> if the mining schema declares no predicted field.
	 *
	 * @throws InvalidFeatureException If the mining schema declares more than one predicted field.
	 */
	@Override
	public FieldName getTargetField(){
		List<FieldName> predictedFields = getPredictedFields();

		// "The definition of predicted fields in the MiningSchema is not required"
		if(predictedFields.isEmpty()){
			return null;
		}

		if(predictedFields.size() > 1){
			throw new InvalidFeatureException("Too many predicted fields", getMiningSchema());
		}

		return predictedFields.get(0);
	}

	/**
	 * @return Names of mining fields whose usage type is {@link FieldUsageType#PREDICTED}, in mining schema order.
	 */
	@Override
	public List<FieldName> getPredictedFields(){
		return getMiningFields(FieldUsageType.PREDICTED);
	}

	/**
	 * Collects the names of the mining fields that have the given usage type.
	 *
	 * @param fieldUsageType The usage type to match.
	 *
	 * @return A new, mutable list of field names in mining schema order. Empty if nothing matches.
	 */
	public List<FieldName> getMiningFields(FieldUsageType fieldUsageType){
		List<FieldName> result = Lists.newArrayList();

		List<MiningField> miningFields = getMiningSchema().getMiningFields();
		for(MiningField miningField : miningFields){

			// Enum constants are singletons, so identity comparison is exact and does not throw on null
			if(miningField.getUsageType() == fieldUsageType){
				result.add(miningField.getName());
			}
		}

		return result;
	}

	/**
	 * Looks up a mining field by name.
	 *
	 * @return The matching mining field, or <code>null</code> if the mining schema has no such field.
	 */
	@Override
	public MiningField getMiningField(FieldName name){
		List<MiningField> miningFields = getMiningSchema().getMiningFields();

		return find(miningFields, name);
	}

	/**
	 * Appends a new mining field to the {@link MiningSchema}.
	 * No check is made for an existing mining field with the same name.
	 *
	 * @return The newly created mining field.
	 */
	public MiningField addMiningField(FieldName name, FieldUsageType usageType){
		MiningField miningField = new MiningField(name);
		miningField.setUsageType(usageType);

		List<MiningField> miningFields = getMiningSchema().getMiningFields();
		miningFields.add(miningField);

		return miningField;
	}

	/**
	 * Lists the names of the model's output fields.
	 * This is a read-only query: unlike {@link #getOrCreateOutput()}, it never attaches an {@link Output} element to the model.
	 *
	 * @return A new, mutable list of output field names. Empty if the model has no {@link Output} element.
	 */
	@Override
	public List<FieldName> getOutputFields(){
		List<FieldName> result = Lists.newArrayList();

		Output output = getModel().getOutput();
		if(output == null){
			return result;
		}

		List<OutputField> outputFields = output.getOutputFields();
		for(OutputField outputField : outputFields){
			result.add(outputField.getName());
		}

		return result;
	}

	/**
	 * Looks up an output field by name without modifying the model.
	 *
	 * @return The matching output field, or <code>null</code> if the model has no {@link Output} element or no such field.
	 */
	@Override
	public OutputField getOutputField(FieldName name){
		Output output = getModel().getOutput();
		if(output == null){
			return null;
		}

		List<OutputField> outputFields = output.getOutputFields();

		return find(outputFields, name);
	}

	/**
	 * Resolves a derived field, searching the model's {@link LocalTransformations} first.
	 * The model is not modified if it has no {@link LocalTransformations} element.
	 *
	 * @return The derived field from the local transformations, otherwise whatever the superclass lookup returns.
	 */
	@Override
	public DerivedField resolveField(FieldName name){
		LocalTransformations localTransformations = getModel().getLocalTransformations();

		if(localTransformations != null){
			List<DerivedField> derivedFields = localTransformations.getDerivedFields();

			DerivedField derivedField = find(derivedFields, name);
			if(derivedField != null){
				return derivedField;
			}
		}

		// Not defined locally; delegate to the superclass lookup
		return super.resolveField(name);
	}

	/**
	 * Looks up the {@link Target} declaration for the given field without modifying the model.
	 *
	 * @return The first target whose field equals <code>name</code>, or <code>null</code> if there is none
	 * (including when the model has no {@link Targets} element).
	 */
	public Target getTarget(FieldName name){
		Targets targetDictionary = getModel().getTargets();
		if(targetDictionary == null){
			return null;
		}

		List<Target> targets = targetDictionary.getTargets();
		for(Target target : targets){
			FieldName field = target.getField();

			// Null-safe: a Target element may omit its field attribute
			if(field != null && field.equals(name)){
				return target;
			}
		}

		return null;
	}

	/**
	 * @return The mining schema of the model returned by {@link #getModel()}.
	 */
	public MiningSchema getMiningSchema(){
		M model = getModel();

		return model.getMiningSchema();
	}

	/**
	 * Returns the model's {@link LocalTransformations}, first creating and attaching an empty one if absent.
	 * Intended for callers that are about to add derived fields.
	 */
	public LocalTransformations getOrCreateLocalTransformations(){
		M model = getModel();

		LocalTransformations localTransformations = model.getLocalTransformations();
		if(localTransformations == null){
			localTransformations = new LocalTransformations();

			model.setLocalTransformations(localTransformations);
		}

		return localTransformations;
	}

	/**
	 * Returns the model's {@link Output}, first creating and attaching an empty one if absent.
	 * Intended for callers that are about to add output fields.
	 */
	public Output getOrCreateOutput(){
		M model = getModel();

		Output output = model.getOutput();
		if(output == null){
			output = new Output();

			model.setOutput(output);
		}

		return output;
	}

	/**
	 * Returns the model's {@link Targets}, first creating and attaching an empty one if absent.
	 * Intended for callers that are about to add target declarations.
	 */
	public Targets getOrCreateTargets(){
		M model = getModel();

		Targets targets = model.getTargets();
		if(targets == null){
			targets = new Targets();

			model.setTargets(targets);
		}

		return targets;
	}
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy