All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.jpmml.converter.visitors.AbstractTreeModelTransformer Maven / Gradle / Ivy

There is a newer version: 1.5.12
Show newest version
/*
 * Copyright (c) 2018 Villu Ruusmann
 *
 * This file is part of JPMML-Converter
 *
 * JPMML-Converter is free software: you can redistribute it and/or modify
 * it under the terms of the GNU Affero General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * JPMML-Converter is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU Affero General Public License for more details.
 *
 * You should have received a copy of the GNU Affero General Public License
 * along with JPMML-Converter.  If not, see <http://www.gnu.org/licenses/>.
 */
package org.jpmml.converter.visitors;

import java.util.Deque;
import java.util.Iterator;
import java.util.List;
import java.util.Objects;

import org.dmg.pmml.Array;
import org.dmg.pmml.HasFieldReference;
import org.dmg.pmml.HasValue;
import org.dmg.pmml.HasValueSet;
import org.dmg.pmml.PMMLObject;
import org.dmg.pmml.Predicate;
import org.dmg.pmml.ScoreDistribution;
import org.dmg.pmml.SimplePredicate;
import org.dmg.pmml.SimpleSetPredicate;
import org.dmg.pmml.tree.Node;
import org.dmg.pmml.tree.TreeModel;
import org.jpmml.model.UnsupportedElementException;
import org.jpmml.model.visitors.AbstractVisitor;

/**
 * Base class for visitors that restructure {@link TreeModel} elements.
 *
 * <p>
 * The {@link #pushParent(PMMLObject)} and {@link #popParent()} overrides dispatch to the
 * {@link #enterNode(Node)}/{@link #exitNode(Node)} and
 * {@link #enterTreeModel(TreeModel)}/{@link #exitTreeModel(TreeModel)} hooks, which are no-ops here.
 * The static helpers implement common node-merging operations and predicate consistency checks.
 * </p>
 */
abstract
public class AbstractTreeModelTransformer extends AbstractVisitor {

	/**
	 * Delegates to the superclass, then invokes {@link #enterNode(Node)} or {@link #enterTreeModel(TreeModel)}
	 * when the pushed object is of the corresponding type.
	 */
	@Override
	public void pushParent(PMMLObject object){
		super.pushParent(object);

		if(object instanceof Node){
			enterNode((Node)object);
		} else

		if(object instanceof TreeModel){
			enterTreeModel((TreeModel)object);
		}
	}

	/**
	 * Delegates to the superclass, then invokes {@link #exitNode(Node)} or {@link #exitTreeModel(TreeModel)}
	 * when the popped object is of the corresponding type.
	 *
	 * @return the object returned by the superclass.
	 */
	@Override
	public PMMLObject popParent(){
		PMMLObject object = super.popParent();

		if(object instanceof Node){
			exitNode((Node)object);
		} else

		if(object instanceof TreeModel){
			exitTreeModel((TreeModel)object);
		}

		return object;
	}

	/**
	 * Hook invoked by {@link #pushParent(PMMLObject)} after a {@link Node} has been pushed. No-op by default.
	 */
	public void enterNode(Node node){
	}

	/**
	 * Hook invoked by {@link #popParent()} after a {@link Node} has been popped. No-op by default.
	 */
	public void exitNode(Node node){
	}

	/**
	 * Hook invoked by {@link #pushParent(PMMLObject)} after a {@link TreeModel} has been pushed. No-op by default.
	 */
	public void enterTreeModel(TreeModel treeModel){
	}

	/**
	 * Hook invoked by {@link #popParent()} after a {@link TreeModel} has been popped. No-op by default.
	 */
	public void exitTreeModel(TreeModel treeModel){
	}

	/**
	 * Returns the head element of {@link #getParents()} as a {@link Node}.
	 *
	 * @return the head element if it is a {@link Node}, or {@code null} if it is a {@link TreeModel}.
	 * @throws IllegalStateException if the parents deque is empty, or its head element is of any other type.
	 */
	protected Node getParentNode(){
		Deque<PMMLObject> parents = getParents();

		// peekFirst() returns null for an empty deque, which falls through to the exception branch
		PMMLObject parent = parents.peekFirst();

		if(parent instanceof Node){
			return (Node)parent;
		} else

		if(parent instanceof TreeModel){
			return null;
		} else

		{
			throw new IllegalStateException("Expected a Node or TreeModel parent, got " + formatType(parent));
		}
	}

	/**
	 * Walks {@link #getParents()} from the head and returns the first {@link Node} accepted by the predicate.
	 *
	 * <p>
	 * The walk stops at the first {@link TreeModel} element.
	 * Only Node elements that precede it in the deque are considered.
	 * </p>
	 *
	 * @param predicate the acceptance test applied to each Node element.
	 * @return the first accepted Node, or {@code null} if none is found before a TreeModel element or the end of the deque.
	 * @throws IllegalStateException if an element that is neither a Node nor a TreeModel is encountered first.
	 */
	public Node getAncestorNode(java.util.function.Predicate<Node> predicate){
		Deque<PMMLObject> parents = getParents();

		Iterator<PMMLObject> parentIt = parents.iterator();

		while(parentIt.hasNext()){
			PMMLObject parent = parentIt.next();

			if(parent instanceof Node){
				Node node = (Node)parent;

				if(predicate.test(node)){
					return node;
				}
			} else

			if(parent instanceof TreeModel){
				return null;
			} else

			{
				throw new IllegalStateException("Expected a Node or TreeModel parent, got " + formatType(parent));
			}
		}

		return null;
	}

	/**
	 * Walks {@link #getParents()} from the head, skipping {@link Node} elements,
	 * and returns the first {@link TreeModel} element.
	 *
	 * @return the first TreeModel element.
	 * @throws IllegalStateException if an element that is neither a Node nor a TreeModel is encountered first,
	 * or if the deque contains no TreeModel element.
	 */
	protected TreeModel getParentTreeModel(){
		Deque<PMMLObject> parents = getParents();

		Iterator<PMMLObject> parentIt = parents.iterator();

		while(parentIt.hasNext()){
			PMMLObject parent = parentIt.next();

			if(parent instanceof Node){
				continue;
			} else

			if(parent instanceof TreeModel){
				return (TreeModel)parent;
			} else

			{
				throw new IllegalStateException("Expected a Node or TreeModel parent, got " + formatType(parent));
			}
		}

		throw new IllegalStateException("No TreeModel parent found");
	}

	/**
	 * Reverses, in place, the order of the two child nodes of the given node.
	 *
	 * @param node a node that has exactly two children.
	 * @return the node's own (now reordered) children list.
	 * @throws UnsupportedElementException if the node does not have exactly two children.
	 */
	static
	protected List<Node> swapChildren(Node node){
		List<Node> children = node.getNodes();

		if(children.size() != 2){
			throw new UnsupportedElementException(node);
		}

		// [first, second] -> remove(0) -> [second] -> add(1, first) -> [second, first]
		Node firstChild = children.remove(0);

		children.add(1, firstChild);

		return children;
	}

	/**
	 * Copies the score of {@code node} onto {@code parentNode}.
	 *
	 * @throws UnsupportedElementException if {@code parentNode} already has a score.
	 */
	static
	protected void initScore(Node parentNode, Node node){
		Object score = node.getScore();

		if(parentNode.hasScore()){
			throw new UnsupportedElementException(parentNode);
		}

		parentNode.setScore(score);
	}

	/**
	 * Copies the score, record count and score distributions of {@code node} onto {@code parentNode}.
	 *
	 * <p>
	 * The child's ScoreDistribution elements are appended to the parent's list as the same instances (not copies).
	 * </p>
	 *
	 * @throws UnsupportedElementException if {@code parentNode} already has a score, a record count,
	 * or score distributions.
	 */
	static
	protected void initScoreDistribution(Node parentNode, Node node){
		Object score = node.getScore();
		Number recordCount = node.getRecordCount();

		if(parentNode.hasScore()){
			throw new UnsupportedElementException(parentNode);
		} // End if

		Number parentRecordCount = parentNode.getRecordCount();
		if(parentRecordCount != null){
			throw new UnsupportedElementException(parentNode);
		} // End if

		if(parentNode.hasScoreDistributions()){
			throw new UnsupportedElementException(parentNode);
		}

		parentNode
			.setScore(score)
			.setRecordCount(recordCount);

		if(node.hasScoreDistributions()){
			List<ScoreDistribution> scoreDistributions = node.getScoreDistributions();

			List<ScoreDistribution> parentScoreDistributions = parentNode.getScoreDistributions();

			parentScoreDistributions.addAll(scoreDistributions);
		}
	}

	/**
	 * Copies the default child reference of {@code node} onto {@code parentNode}.
	 *
	 * @throws UnsupportedElementException if {@code parentNode} already has a default child.
	 */
	static
	protected void initDefaultChild(Node parentNode, Node node){
		Object defaultChild = node.getDefaultChild();

		Object parentDefaultChild = parentNode.getDefaultChild();
		if(parentDefaultChild != null){
			throw new UnsupportedElementException(parentNode);
		}

		parentNode.setDefaultChild(defaultChild);
	}

	/**
	 * Removes {@code node} from the children of {@code parentNode}, and inserts the children of {@code node}
	 * (if any) at the same position.
	 *
	 * @throws UnsupportedElementException if {@code node} is not the last child of {@code parentNode}.
	 */
	static
	protected void replaceChildWithGrandchildren(Node parentNode, Node node){
		List<Node> parentChildren = parentNode.getNodes();

		int index = parentChildren.indexOf(node);
		// The index < 0 check is required: for an empty list, indexOf() == -1 == size() - 1
		if(index < 0 || index != (parentChildren.size() - 1)){
			throw new UnsupportedElementException(parentNode);
		}

		parentChildren.remove(index);

		if(node.hasNodes()){
			List<Node> children = node.getNodes();

			parentChildren.addAll(index, children);
		}
	}

	/**
	 * Tests whether a default child reference designates the given node.
	 *
	 * @param defaultChild either a {@link Node} instance (compared with {@code equals}) or an id value
	 * (compared with {@link Node#getId()}).
	 * @return {@code true} if the reference designates {@code node}; always {@code false} for a {@code null} reference.
	 */
	static
	protected boolean equalsNode(Object defaultChild, Node node){

		// An absent default child designates no node.
		// Without this guard, a null reference would "match" any node whose id is also null.
		if(defaultChild == null){
			return false;
		} // End if

		if(defaultChild instanceof Node){
			return Objects.equals(defaultChild, node);
		}

		return Objects.equals(defaultChild, node.getId());
	}

	/**
	 * @return {@code true} if the predicate is a {@link HasFieldReference} whose required field equals {@code fieldName}.
	 */
	static
	protected boolean hasFieldReference(Predicate predicate, String fieldName){

		if(predicate instanceof HasFieldReference){
			HasFieldReference<?> hasFieldReference = (HasFieldReference<?>)predicate;

			return Objects.equals(hasFieldReference.requireField(), fieldName);
		}

		return false;
	}

	/**
	 * @return {@code true} if the predicate is a {@link HasValue} whose value equals {@code value}.
	 */
	static
	protected boolean hasValue(Predicate predicate, String value){

		if(predicate instanceof HasValue){
			HasValue<?> hasValue = (HasValue<?>)predicate;

			return Objects.equals(hasValue.getValue(), value);
		}

		return false;
	}

	/**
	 * @return {@code true} if the predicate is a {@link SimplePredicate} with the given operator.
	 */
	static
	protected boolean hasOperator(Predicate predicate, SimplePredicate.Operator operator){

		if(predicate instanceof SimplePredicate){
			SimplePredicate simplePredicate = (SimplePredicate)predicate;

			return (operator == simplePredicate.requireOperator());
		}

		return false;
	}

	/**
	 * @return {@code true} if the predicate is a {@link SimpleSetPredicate} with the given boolean operator.
	 */
	static
	protected boolean hasBooleanOperator(Predicate predicate, SimpleSetPredicate.BooleanOperator booleanOperator){

		if(predicate instanceof SimpleSetPredicate){
			SimpleSetPredicate simpleSetPredicate = (SimpleSetPredicate)predicate;

			return (booleanOperator == simpleSetPredicate.requireBooleanOperator());
		}

		return false;
	}

	/**
	 * Casts both predicates to {@link HasFieldReference} and checks that they reference the same field.
	 *
	 * @throws ClassCastException if either predicate does not implement {@link HasFieldReference}.
	 * @throws IllegalArgumentException if the field names differ.
	 */
	static
	protected void checkFieldReference(Predicate left, Predicate right){
		checkFieldReference((HasFieldReference<?>)left, (HasFieldReference<?>)right);
	}

	/**
	 * Checks that both elements reference the same field.
	 *
	 * @throws IllegalArgumentException if the field names differ.
	 */
	static
	protected void checkFieldReference(HasFieldReference<?> left, HasFieldReference<?> right){
		String leftFieldName = left.requireField();
		String rightFieldName = right.requireField();

		if(!Objects.equals(leftFieldName, rightFieldName)){
			throw new IllegalArgumentException("Field names " + leftFieldName + " and " + rightFieldName + " are not the same");
		}
	}

	/**
	 * Casts both predicates to {@link HasValue} and checks that their values are equal.
	 *
	 * @throws ClassCastException if either predicate does not implement {@link HasValue}.
	 * @throws IllegalArgumentException if the values differ.
	 */
	static
	protected void checkValue(Predicate left, Predicate right){
		checkValue((HasValue<?>)left, (HasValue<?>)right);
	}

	/**
	 * Checks that both elements hold equal values.
	 *
	 * @throws IllegalArgumentException if the values differ.
	 */
	static
	protected void checkValue(HasValue<?> left, HasValue<?> right){
		Object leftValue = left.getValue();
		Object rightValue = right.getValue();

		if(!Objects.equals(leftValue, rightValue)){
			throw new IllegalArgumentException("Field values " + leftValue + " and " + rightValue + " are not the same");
		}
	}

	/**
	 * Casts both predicates to {@link HasValueSet} and checks that their arrays hold equal values.
	 *
	 * @throws ClassCastException if either predicate does not implement {@link HasValueSet}.
	 * @throws IllegalArgumentException if the array values differ.
	 */
	static
	protected void checkValueSet(Predicate left, Predicate right){
		checkValueSet((HasValueSet<?>)left, (HasValueSet<?>)right);
	}

	/**
	 * Checks that both elements hold arrays with equal values.
	 *
	 * @throws IllegalArgumentException if the array values differ.
	 */
	static
	protected void checkValueSet(HasValueSet<?> left, HasValueSet<?> right){
		Array leftArray = left.requireArray();
		Array rightArray = right.requireArray();

		if(!Objects.equals(leftArray.getValue(), rightArray.getValue())){
			throw new IllegalArgumentException("Field value sets " + leftArray.getValue() + " and " + rightArray.getValue() + " are not the same");
		}
	}

	/**
	 * @return the runtime class name of the object, or {@code "null"}; used in exception messages
	 * instead of {@code toString()}, which may render an entire PMML subtree.
	 */
	static
	private String formatType(Object object){
		return (object != null) ? object.getClass().getName() : "null";
	}
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy