All downloads are free. Search and download functionality uses the official Maven repository.

org.drools.core.reteoo.ReteooBuilder Maven / Gradle / Ivy

There is a newer version: 9.44.0.Final
Show newest version
/*
 * Copyright 2005 Red Hat, Inc. and/or its affiliates.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.drools.core.reteoo;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.Externalizable;
import java.io.IOException;
import java.io.ObjectInput;
import java.io.ObjectOutput;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Queue;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;

import org.drools.core.common.BaseNode;
import org.drools.core.common.DroolsObjectInputStream;
import org.drools.core.common.DroolsObjectOutputStream;
import org.drools.core.common.InternalWorkingMemory;
import org.drools.core.common.MemoryFactory;
import org.drools.core.common.NetworkNode;
import org.drools.core.definitions.rule.impl.RuleImpl;
import org.drools.core.impl.InternalKnowledgeBase;
import org.drools.core.phreak.AddRemoveRule;
import org.drools.core.rule.InvalidPatternException;
import org.drools.core.rule.WindowDeclaration;
import org.kie.api.definition.rule.Rule;

import static org.drools.core.impl.StatefulKnowledgeSessionImpl.DEFAULT_RULE_UNIT;

/**
 * Builds and maintains the Rete-OO network of a knowledge base.
 * <p>
 * For every rule added it records the terminal nodes produced by the configured
 * {@link RuleBuilder} (keyed by fully qualified rule name, and additionally by simple
 * name for queries), keeps track of named window nodes, and detaches a rule's nodes
 * from the network again when the rule is removed.
 */
public class ReteooBuilder
        implements
        Externalizable {
    // ------------------------------------------------------------
    // Instance members
    // ------------------------------------------------------------

    private static final long           serialVersionUID = 510L;

    /**
     * The knowledge base this builder works on. Transient: it is not restored by
     * {@link #readExternal(ObjectInput)} and must be re-attached with
     * {@link #setRuleBase(InternalKnowledgeBase)}.
     */
    private transient InternalKnowledgeBase  kBase;

    /** Terminal nodes of every rule, keyed by the rule's fully qualified name. */
    private Map<String, BaseNode[]>     rules;
    /** Terminal nodes of every query, keyed by the query's simple name. */
    private Map<String, BaseNode[]>     queries;

    /** Window nodes created for named window declarations, keyed by window name. */
    private Map<String, WindowNode>     namedWindows;

    /** Obtained from the knowledge base configuration; transient, recreated by setRuleBase. */
    private transient RuleBuilder       ruleBuilder;

    private IdGenerator                 idGenerator;

    // ------------------------------------------------------------
    // Constructors
    // ------------------------------------------------------------

    /**
     * No-arg constructor required by {@link Externalizable}. The state is filled in by
     * {@link #readExternal(ObjectInput)}; {@link #setRuleBase(InternalKnowledgeBase)} must be
     * called before rules can be added.
     */
    public ReteooBuilder() {

    }

    /**
     * Construct a Builder against an existing Rete network.
     *
     * @param kBase the knowledge base whose network is built; its configuration also supplies
     *              the {@link RuleBuilder} used to create the nodes
     */
    public ReteooBuilder( final InternalKnowledgeBase  kBase ) {
        this.kBase = kBase;
        this.rules = new HashMap<>();
        this.queries = new HashMap<>();
        this.namedWindows = new HashMap<>();

        // generated ids start at 1, as id 0 is taken by the Rete root node
        this.idGenerator = new IdGenerator();
        this.ruleBuilder = kBase.getConfiguration().getComponentFactory().getRuleBuilderFactory().newRuleBuilder();
    }

    // ------------------------------------------------------------
    // Instance methods
    // ------------------------------------------------------------

    /**
     * Add a Rule to the network.
     *
     * @param rule
     *            The rule to add.
     * @throws InvalidPatternException if the rule builder rejects one of the rule's patterns
     */
    public synchronized void addRule(final RuleImpl rule) throws InvalidPatternException {
        final List<TerminalNode> terminals = this.ruleBuilder.addRule( rule,
                                                                       this.kBase );

        BaseNode[] nodes = terminals.toArray( new BaseNode[terminals.size()] );
        this.rules.put( rule.getFullyQualifiedName(), nodes );
        if ( rule.isQuery() ) {
            this.queries.put( rule.getName(), nodes );
        }
    }

    /**
     * Adds an entry point with the given id to the network by delegating to the rule builder.
     *
     * @param id the entry point id
     */
    public void addEntryPoint( String id ) {
        this.ruleBuilder.addEntryPoint( id,
                                        this.kBase );
    }

    /**
     * Creates the window node for a named window declaration and registers it under the
     * window's name.
     *
     * @param window the window declaration
     */
    public synchronized void addNamedWindow( WindowDeclaration window ) {
        final WindowNode wnode = this.ruleBuilder.addWindowNode( window,
                                                                 this.kBase );

        this.namedWindows.put( window.getName(),
                               wnode );
    }

    /**
     * @param name the window name
     * @return the window node registered under {@code name}, or {@code null} if there is none
     */
    public WindowNode getWindowNode( String name ) {
        return this.namedWindows.get( name );
    }

    /** @return the generator handing out node ids and memory ids for this network */
    public IdGenerator getIdGenerator() {
        return this.idGenerator;
    }

    /**
     * @param rule the rule
     * @return the terminal nodes of the rule, or {@code null} if the rule was not added
     */
    public synchronized BaseNode[] getTerminalNodes(final RuleImpl rule) {
        return getTerminalNodes( rule.getFullyQualifiedName() );
    }

    /**
     * @param ruleName the fully qualified rule name
     * @return the terminal nodes of the rule, or {@code null} if no rule has that name
     */
    public synchronized BaseNode[] getTerminalNodes(final String ruleName) {
        return this.rules.get( ruleName );
    }

    /**
     * Looks the name up as a query name first and falls back to the fully qualified rule name.
     *
     * @param ruleName the query name (or fully qualified rule name)
     * @return the matching terminal nodes, or {@code null} if neither lookup succeeds
     */
    public synchronized BaseNode[] getTerminalNodesForQuery(final String ruleName) {
        BaseNode[] nodes = this.queries.get( ruleName );
        return nodes != null ? nodes : getTerminalNodes( ruleName );
    }

    /**
     * @return the live (not copied) map of fully qualified rule name to terminal nodes
     */
    public synchronized Map<String, BaseNode[]> getTerminalNodes() {
        return this.rules;
    }

    /**
     * Removes the given rules from the network.
     * <p>
     * Each node on a removed rule's paths is asked to remove itself. Nodes that report they are
     * still in use by other rules stay in the network and get their inferred masks reset.
     * The memories of removed nodes are cleared in every working memory of the knowledge base.
     * A rule with children may only be removed together with all of its children. A rule without
     * registered terminal nodes has nothing to detach and is skipped for that step.
     *
     * @param rulesToBeRemoved the rules to remove
     * @throws IllegalArgumentException if a parent rule is included without all of its children
     */
    public synchronized void removeRules(Collection<RuleImpl> rulesToBeRemoved) {
        // node memories in these working memories must be cleaned up while nodes are detached
        InternalWorkingMemory[] workingMemories = this.kBase.getWorkingMemories();

        for ( RuleImpl rule : rulesToBeRemoved ) {
            if ( rule.hasChildren() && !rulesToBeRemoved.containsAll( rule.getChildren() ) ) {
                throw new IllegalArgumentException( "Cannot remove parent rule " + rule + " without having removed all its children" );
            }

            final RuleRemovalContext context = new RuleRemovalContext( rule );
            context.setKnowledgeBase( kBase );

            BaseNode[] terminalNodes = this.rules.remove( rule.getFullyQualifiedName() );
            if ( terminalNodes != null ) {
                for ( BaseNode node : terminalNodes ) {
                    removeTerminalNode( context, (TerminalNode) node, workingMemories );
                }
            }

            if ( rule.isQuery() ) {
                this.queries.remove( rule.getName() );
            }

            if ( rule.getParent() != null && !rulesToBeRemoved.contains( rule.getParent() ) ) {
                rule.getParent().removeChild( rule );
            }
        }
    }

    /**
     * Detaches a single terminal node and the part of the network only it was using.
     *
     * @param context         the removal context of the rule being removed
     * @param tn              the terminal node to remove
     * @param workingMemories the working memories whose node memories must be cleaned up
     */
    public void removeTerminalNode(RuleRemovalContext context, TerminalNode tn, InternalWorkingMemory[] workingMemories)  {
        AddRemoveRule.removeRule( tn, workingMemories, kBase );

        BaseNode node = (BaseNode) tn;
        removeNodeAssociation( node, context.getRule() );

        resetMasks( removeNodes( (AbstractTerminalNode) tn, workingMemories, context ) );
    }

    /**
     * Removes the left-tuple paths ending in {@code terminalNode}, then the object sources
     * (alpha side) that fed the removed nodes.
     *
     * @return the nodes that refused removal because they are still in use
     */
    private Collection<BaseNode> removeNodes(AbstractTerminalNode terminalNode, InternalWorkingMemory[] wms, RuleRemovalContext context) {
        Map<Integer, BaseNode> stillInUse = new HashMap<>();
        Collection<ObjectSource> alphas = new HashSet<>();

        removePath( wms, context, stillInUse, alphas, terminalNode );

        Set<Integer> removedNodes = new HashSet<>();
        for ( ObjectSource alpha : alphas ) {
            removeObjectSource( wms, stillInUse, removedNodes, alpha, context );
        }

        return stillInUse.values();
    }

    /**
     * Paths must be removed starting from the outermost path, iterating towards the innermost path.
     * Each time it reaches a subnetwork beta node, the current path evaluation ends, and instead the subnetwork
     * path continues.
     */
    private void removePath( InternalWorkingMemory[] wms, RuleRemovalContext context, Map<Integer, BaseNode> stillInUse, Collection<ObjectSource> alphas, PathEndNode endNode ) {
        LeftTupleNode[] nodes = endNode.getPathNodes();
        for ( int i = endNode.getPositionInPath(); i >= 0; i-- ) {
            BaseNode node = (BaseNode) nodes[i];

            boolean removed = false;
            if ( NodeTypeEnums.isLeftTupleNode( node ) ) {
                removed = removeLeftTupleNode( wms, context, stillInUse, node );
            }

            if ( removed ) {
                // collect the object source feeding this node: reteoo requires remove to be
                // called on it as well (down to the OTN) so its tuples get cleaned up
                if ( NodeTypeEnums.isBetaNode( node ) && !((BetaNode) node).isRightInputIsRiaNode() ) {
                    alphas.add( ((BetaNode) node).getRightInput() );
                } else if ( node.getType() == NodeTypeEnums.LeftInputAdapterNode ) {
                    alphas.add( ((LeftInputAdapterNode) node).getObjectSource() );
                }
            }

            if ( NodeTypeEnums.isBetaNode( node ) && ((BetaNode) node).isRightInputIsRiaNode() ) {
                PathEndNode subnetworkEnd = (PathEndNode) ((BetaNode) node).getRightInput();
                removePath( wms, context, stillInUse, alphas, subnetworkEnd );
                return;
            }
        }
    }

    /**
     * Asks a left-tuple node to remove itself. On success its memory is cleared in every working
     * memory; otherwise it is recorded as still in use.
     *
     * @return whether the node was removed
     */
    private boolean removeLeftTupleNode(InternalWorkingMemory[] wms, RuleRemovalContext context, Map<Integer, BaseNode> stillInUse, BaseNode node) {
        boolean removed = node.remove( context, this, wms );

        if ( removed ) {
            stillInUse.remove( node.getId() );
            // phreak must clear node memories, although this should ideally be pushed into AddRemoveRule
            for ( InternalWorkingMemory workingMemory : wms ) {
                workingMemory.clearNodeMemory( (MemoryFactory) node );
            }
        } else {
            stillInUse.put( node.getId(), node );
        }

        return removed;
    }

    /**
     * Asks an object source to remove itself and, if it did, continues upwards with its parent
     * until an entry point node is reached. Each node is visited at most once per removal.
     */
    private void removeObjectSource(InternalWorkingMemory[] wms, Map<Integer, BaseNode> stillInUse, Set<Integer> removedNodes, ObjectSource node, RuleRemovalContext context ) {
        if ( removedNodes.contains( node.getId() ) ) {
            return;
        }
        // read the parent before removal, which may detach the node from it
        ObjectSource parent = node.getParentObjectSource();

        boolean removed = node.remove( context, this, wms );

        if ( !removed ) {
            stillInUse.put( node.getId(), node );
        } else {
            stillInUse.remove( node.getId() );
            removedNodes.add( node.getId() );

            if ( node.getType() != NodeTypeEnums.ObjectTypeNode &&
                 node.getType() != NodeTypeEnums.AlphaNode ) {
                // phreak must clear node memories, although this should ideally be pushed into AddRemoveRule
                for ( InternalWorkingMemory workingMemory : wms ) {
                    workingMemory.clearNodeMemory( (MemoryFactory) node );
                }
            }

            if ( parent != null && parent.getType() != NodeTypeEnums.EntryPointNode ) {
                removeObjectSource( wms, stillInUse, removedNodes, parent, context );
            }
        }
    }

    /**
     * Removes the association with {@code rule} from {@code node} and recursively from its
     * inputs. Recursion stops at a node that reports it held no such association.
     */
    private void removeNodeAssociation(BaseNode node, Rule rule) {
        if ( node == null || !node.removeAssociation( rule ) ) {
            return;
        }
        if ( node instanceof LeftTupleNode ) {
            removeNodeAssociation( ((LeftTupleNode) node).getLeftTupleSource(), rule );
        }
        if ( NodeTypeEnums.isBetaNode( node ) ) {
            removeNodeAssociation( ((BetaNode) node).getRightInput(), rule );
        } else if ( node.getType() == NodeTypeEnums.LeftInputAdapterNode ) {
            removeNodeAssociation( ((LeftInputAdapterNode) node).getObjectSource(), rule );
        } else if ( node.getType() == NodeTypeEnums.AlphaNode ) {
            removeNodeAssociation( ((AlphaNode) node).getParentObjectSource(), rule );
        }
    }

    /**
     * Resets and re-initializes the inferred masks around the nodes that survived a removal.
     * Alpha chains get their inferred masks reset up to the first non-alpha parent. The in-use
     * beta and terminal nodes reachable from them (the "leaf set") then re-initialize theirs.
     */
    private void resetMasks(Collection<BaseNode> nodes) {
        NodeSet leafSet = new NodeSet();

        for ( BaseNode node : nodes ) {
            if ( node.getType() == NodeTypeEnums.AlphaNode ) {
                ObjectSource source = (AlphaNode) node;
                while ( true ) {
                    source.resetInferredMask();
                    BaseNode parent = source.getParentObjectSource();
                    if ( parent.getType() != NodeTypeEnums.AlphaNode ) {
                        break;
                    }
                    source = (ObjectSource) parent;
                }
                updateLeafSet( source, leafSet );
            } else if ( NodeTypeEnums.isBetaNode( node ) || NodeTypeEnums.isTerminalNode( node ) ) {
                // isInUse is defined on BaseNode, so no cast is needed; casting every terminal
                // node to RuleTerminalNode would fail for other terminal node types
                if ( node.isInUse() ) {
                    leafSet.add( node );
                }
            }
        }

        for ( BaseNode node : leafSet ) {
            if ( NodeTypeEnums.isTerminalNode( node ) ) {
                ((TerminalNode) node).initInferredMask();
            } else { // else node instanceof BetaNode
                ((BetaNode) node).initInferredMask();
            }
        }
    }

    /**
     * Walks downstream from {@code baseNode} through in-use alpha, left input adapter and eval
     * nodes. It collects the in-use beta nodes it reaches, plus rule terminal nodes sitting
     * directly below a left input adapter.
     */
    private void updateLeafSet(BaseNode baseNode, NodeSet leafSet) {
        if ( baseNode.getType() == NodeTypeEnums.AlphaNode ) {
            for ( ObjectSink sink : ((AlphaNode) baseNode).getObjectSinkPropagator().getSinks() ) {
                if ( ((BaseNode) sink).isInUse() ) {
                    updateLeafSet( (BaseNode) sink, leafSet );
                }
            }
        } else if ( baseNode.getType() == NodeTypeEnums.LeftInputAdapterNode ) {
            for ( LeftTupleSink sink : ((LeftInputAdapterNode) baseNode).getSinkPropagator().getSinks() ) {
                if ( sink.getType() == NodeTypeEnums.RuleTerminalNode ) {
                    leafSet.add( (BaseNode) sink );
                } else if ( ((BaseNode) sink).isInUse() ) {
                    updateLeafSet( (BaseNode) sink, leafSet );
                }
            }
        } else if ( baseNode.getType() == NodeTypeEnums.EvalConditionNode ) {
            for ( LeftTupleSink sink : ((EvalConditionNode) baseNode).getSinkPropagator().getSinks() ) {
                if ( ((BaseNode) sink).isInUse() ) {
                    updateLeafSet( (BaseNode) sink, leafSet );
                }
            }
        } else if ( NodeTypeEnums.isBetaNode( baseNode ) ) {
            if ( baseNode.isInUse() ) {
                leafSet.add( baseNode );
            }
        }
    }

    /**
     * Hands out integer ids per topic. Each topic has its own sequence starting at 1, and ids
     * that were released are handed out again before new ones.
     */
    public static class IdGenerator implements Externalizable {
        private static final String DEFAULT_TOPIC = "DEFAULT_TOPIC";

        private Map<String, InternalIdGenerator> generators = new ConcurrentHashMap<>();

        @Override
        @SuppressWarnings("unchecked")
        public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException {
            // the map is written by writeExternal below, so the cast matches what was stored
            generators = (Map<String, InternalIdGenerator>) in.readObject();
        }

        @Override
        public void writeExternal(ObjectOutput out) throws IOException {
            out.writeObject( generators );
        }

        /** @return the next id of the default topic */
        public int getNextId() {
            return getNextId( DEFAULT_TOPIC );
        }

        /**
         * @param topic the id sequence to draw from; created on first use
         * @return the next id of {@code topic}
         */
        public int getNextId(String topic) {
            return generators.computeIfAbsent( topic, key -> new InternalIdGenerator( 1 ) ).getNextId();
        }

        /**
         * Releases the node's id back to the default topic. If the node has a memory, its memory
         * id is also released, to the topic named after the rule's unit class (or
         * {@code DEFAULT_RULE_UNIT} when there is none).
         *
         * @param rule the rule being removed; may be {@code null}
         * @param node the node whose ids are released
         */
        public synchronized void releaseId( RuleImpl rule, NetworkNode node ) {
            generators.get( DEFAULT_TOPIC ).releaseId( node.getId() );
            if ( node instanceof MemoryFactory ) {
                String unit = rule != null && rule.getRuleUnitClassName() != null ? rule.getRuleUnitClassName() : DEFAULT_RULE_UNIT;
                generators.get( unit ).releaseId( ((MemoryFactory) node).getMemoryId() );
            }
        }

        /** @return the highest id handed out so far for the default topic, or 0 if none */
        public int getLastId() {
            return getLastId( DEFAULT_TOPIC );
        }

        /**
         * @param topic the id sequence
         * @return the highest id handed out so far for {@code topic}, or 0 if the topic is unknown
         */
        public int getLastId(String topic) {
            InternalIdGenerator gen = generators.get( topic );
            return gen != null ? gen.getLastId() : 0;
        }
    }

    /** A single id sequence that reuses released ids (FIFO) before producing new ones. */
    private static class InternalIdGenerator implements Externalizable {

        private static final long serialVersionUID = 510L;

        private Queue<Integer>    recycledIds;
        private int               nextId;

        /** No-arg constructor required by {@link Externalizable}. */
        public InternalIdGenerator() { }

        public InternalIdGenerator(final int firstId) {
            this.nextId = firstId;
            this.recycledIds = new LinkedList<>();
        }

        @Override
        @SuppressWarnings("unchecked")
        public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException {
            recycledIds = (Queue<Integer>) in.readObject();
            nextId = in.readInt();
        }

        @Override
        public void writeExternal(ObjectOutput out) throws IOException {
            out.writeObject( recycledIds );
            out.writeInt( nextId );
        }

        public synchronized int getNextId() {
            Integer id = this.recycledIds.poll();
            return ( id == null ) ? this.nextId++ : id;
        }

        public synchronized void releaseId(int id) {
            this.recycledIds.add( id );
        }

        // synchronized like getNextId, so the read of nextId is not racing its update
        public synchronized int getLastId() {
            return this.nextId - 1;
        }
    }

    /**
     * Serializes rules, queries, named windows and the id generator.
     * <p>
     * When {@code out} is a {@link DroolsObjectOutputStream}, the state is written to it directly.
     * Otherwise the state is written to an in-memory {@link DroolsObjectOutputStream} and emitted
     * as an {@code int} length followed by the byte array.
     */
    @Override
    public void writeExternal(ObjectOutput out) throws IOException {
        if ( out instanceof DroolsObjectOutputStream ) {
            writeState( (DroolsObjectOutputStream) out );
            return;
        }

        ByteArrayOutputStream bytes = new ByteArrayOutputStream();
        try ( DroolsObjectOutputStream droolsStream = new DroolsObjectOutputStream( bytes ) ) {
            writeState( droolsStream );
            droolsStream.flush();
        }
        out.writeInt( bytes.size() );
        out.writeObject( bytes.toByteArray() );
    }

    /** Writes the serialized state in the order {@link #readState} reads it back. */
    private void writeState(DroolsObjectOutputStream droolsStream) throws IOException {
        droolsStream.writeObject( rules );
        droolsStream.writeObject( queries );
        droolsStream.writeObject( namedWindows );
        droolsStream.writeObject( idGenerator );
    }

    /**
     * Restores the state written by {@link #writeExternal(ObjectOutput)}. The knowledge base and
     * rule builder are not restored; call {@link #setRuleBase(InternalKnowledgeBase)} afterwards.
     *
     * @throws IOException if the stream is corrupted or cannot be read
     */
    @Override
    public void readExternal(ObjectInput in) throws IOException,
                                                    ClassNotFoundException {
        if ( in instanceof DroolsObjectInputStream ) {
            readState( (DroolsObjectInputStream) in );
            return;
        }

        // mirror writeExternal: the length prefix precedes the byte array and must be consumed
        // first, otherwise readObject would hit the primitive int data and fail
        int size = in.readInt();
        byte[] data = (byte[]) in.readObject();
        if ( data == null || data.length != size ) {
            throw new IOException( "Corrupted ReteooBuilder stream: expected " + size + " bytes but found "
                                   + ( data == null ? "null" : data.length ) );
        }
        try ( DroolsObjectInputStream droolsStream = new DroolsObjectInputStream( new ByteArrayInputStream( data ) ) ) {
            readState( droolsStream );
        }
    }

    /** Reads the serialized state in the order {@link #writeState} wrote it. */
    @SuppressWarnings("unchecked")
    private void readState(DroolsObjectInputStream droolsStream) throws IOException, ClassNotFoundException {
        this.rules = (Map<String, BaseNode[]>) droolsStream.readObject();
        this.queries = (Map<String, BaseNode[]>) droolsStream.readObject();
        this.namedWindows = (Map<String, WindowNode>) droolsStream.readObject();
        this.idGenerator = (IdGenerator) droolsStream.readObject();
    }

    /**
     * Attaches this builder to a knowledge base (e.g. after deserialization) and recreates the
     * rule builder from that knowledge base's configuration.
     *
     * @param kBase the knowledge base to work on
     */
    public void setRuleBase( InternalKnowledgeBase kBase ) {
        this.kBase = kBase;

        this.ruleBuilder = kBase.getConfiguration().getComponentFactory().getRuleBuilderFactory().newRuleBuilder();
    }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy