All downloads are FREE. Search and download functionality uses the official Maven repository.

org.optaplanner.constraint.streams.drools.common.AbstractPatternVariable Maven / Gradle / Ivy

/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package org.optaplanner.constraint.streams.drools.common;

import static org.drools.model.PatternDSL.betaIndexedBy;
import static org.optaplanner.constraint.streams.drools.common.AbstractLeftHandSide.getConstraintType;

import java.util.List;
import java.util.function.BiFunction;
import java.util.function.BiPredicate;
import java.util.function.Function;
import java.util.function.Predicate;
import java.util.stream.Collectors;
import java.util.stream.Stream;

import org.drools.model.BetaIndex;
import org.drools.model.BetaIndex2;
import org.drools.model.BetaIndex3;
import org.drools.model.PatternDSL;
import org.drools.model.Variable;
import org.drools.model.functions.Function1;
import org.drools.model.functions.Predicate2;
import org.drools.model.functions.Predicate3;
import org.drools.model.functions.Predicate4;
import org.drools.model.view.ViewItem;
import org.optaplanner.constraint.streams.common.bi.DefaultBiJoiner;
import org.optaplanner.constraint.streams.common.quad.DefaultQuadJoiner;
import org.optaplanner.constraint.streams.common.tri.DefaultTriJoiner;
import org.optaplanner.core.api.function.QuadFunction;
import org.optaplanner.core.api.function.QuadPredicate;
import org.optaplanner.core.api.function.TriFunction;
import org.optaplanner.core.api.function.TriPredicate;
import org.optaplanner.core.impl.score.stream.JoinerType;

/**
 * Base implementation of {@link PatternVariable} backed by a single Drools pattern definition.
 * <p>
 * Each instance holds:
 * <ul>
 * <li>a primary {@link Variable};</li>
 * <li>the {@link PatternDSL.PatternDef} to which filters, joins and bindings are added;</li>
 * <li>two lists of additional {@link ViewItem}s: prerequisite expressions, emitted before the pattern by
 * {@link #build()}, and dependent expressions, emitted after it.</li>
 * </ul>
 * <p>
 * The fluent methods ({@code filter}, {@code filterForJoin}, {@code bind}) mutate the pattern in place and return
 * {@code this} cast to {@code Child_}. Presumably {@code Child_} is the concrete subclass type; confirm against the
 * declaration of {@link PatternVariable}.
 * <p>
 * Subclasses decide how the value of type {@code A} is obtained from a pattern object by implementing
 * {@link #extract}.
 * <p>
 * NOTE(review): the generic type parameters and type arguments in this copy appear to have been stripped.
 * For example, the class header ends in a stray {@code >}, and {@code List>} is not valid Java.
 * Restore them from the upstream source before compiling.
 */
abstract class AbstractPatternVariable>
        implements PatternVariable {

    // Variable returned by getPrimaryVariable().
    private final Variable primaryVariable;
    // Drools pattern that filter/filterForJoin/bind add constraints and bindings to. The copy constructors share
    // this very instance with the variable they were created from; it is not copied.
    private final PatternDSL.PatternDef pattern;
    // View items placed before the pattern in build().
    private final List> prerequisiteExpressions;
    // View items placed after the pattern in build().
    private final List> dependentExpressions;

    /**
     * Creates a pattern variable from its individual parts. The lists are stored as given, without copying.
     *
     * @param aVariable the primary variable
     * @param pattern the pattern to which filters, joins and bindings will be added
     * @param prerequisiteExpressions expressions to emit before the pattern in {@link #build()}
     * @param dependentExpressions expressions to emit after the pattern in {@link #build()}
     */
    protected AbstractPatternVariable(Variable aVariable, PatternDSL.PatternDef pattern,
            List> prerequisiteExpressions, List> dependentExpressions) {
        this.primaryVariable = aVariable;
        this.pattern = pattern;
        this.prerequisiteExpressions = prerequisiteExpressions;
        this.dependentExpressions = dependentExpressions;
    }

    /**
     * Creates a pattern variable over the same pattern as {@code patternCreator`, but with a different primary
     * variable. The pattern and both expression lists are shared by reference with {@code patternCreator}, so later
     * changes made through either variable are visible to the other.
     *
     * @param patternCreator the pattern variable whose pattern and expression lists are reused
     * @param boundVariable the new primary variable
     */
    protected AbstractPatternVariable(AbstractPatternVariable patternCreator, Variable boundVariable) {
        this.primaryVariable = boundVariable;
        this.pattern = patternCreator.getPattern();
        this.prerequisiteExpressions = patternCreator.getPrerequisiteExpressions();
        this.dependentExpressions = patternCreator.getDependentExpressions();
    }

    /**
     * Creates a copy of {@code patternCreator} with one more dependent expression appended.
     * <p>
     * The primary variable, the pattern and the prerequisite list are shared by reference. The dependent expressions
     * are copied into a new list, so {@code patternCreator}'s own dependent list is left unchanged.
     *
     * @param patternCreator the pattern variable to extend
     * @param dependentExpression the expression to append after the existing dependent expressions
     */
    protected AbstractPatternVariable(AbstractPatternVariable patternCreator,
            ViewItem dependentExpression) {
        this.primaryVariable = patternCreator.primaryVariable;
        this.pattern = patternCreator.pattern;
        this.prerequisiteExpressions = patternCreator.prerequisiteExpressions;
        this.dependentExpressions = Stream.concat(patternCreator.dependentExpressions.stream(), Stream.of(dependentExpression))
                .collect(Collectors.toList());
    }

    @Override
    public Variable getPrimaryVariable() {
        return primaryVariable;
    }

    /**
     * Returns the underlying pattern. This is the same instance that the filter, join and bind methods mutate.
     *
     * @return the pattern definition, as passed to or shared by the constructor
     */
    public PatternDSL.PatternDef getPattern() {
        return pattern;
    }

    // Returns the internal list itself, not a defensive copy.
    @Override
    public List> getPrerequisiteExpressions() {
        return prerequisiteExpressions;
    }

    // Returns the internal list itself, not a defensive copy.
    @Override
    public List> getDependentExpressions() {
        return dependentExpressions;
    }

    /**
     * Variable values can be either read directly from the pattern variable (see {@link DirectPatternVariable})
     * or indirectly by applying a mapping function to it (see {@link IndirectPatternVariable}).
     * This method abstracts this behavior, so that the surrounding code may be shared between both implementations.
     *
     * @param patternVar never null, pattern variable to extract the value from
     * @return value of the variable
     */
    protected abstract A extract(PatternVar_ patternVar);

    /**
     * Adds a constraint to the pattern that keeps only objects whose extracted value passes {@code predicate}.
     *
     * @param predicate tested against the value returned by {@link #extract}
     * @return this, cast to {@code Child_}
     */
    @Override
    public final Child_ filter(Predicate predicate) {
        pattern.expr("Filter using " + predicate, a -> predicate.test(extract(a)));
        return (Child_) this;
    }

    /**
     * Adds a constraint that tests the extracted value against the value of one left-join variable.
     * The predicate receives the left-join value first and the extracted value second.
     *
     * @param predicate the filter to apply
     * @param leftJoinVariable the variable whose value is passed as the predicate's first argument
     * @return this, cast to {@code Child_}
     */
    @Override
    public final  Child_ filter(BiPredicate predicate,
            Variable leftJoinVariable) {
        pattern.expr("Filter using " + predicate, leftJoinVariable,
                (a, leftJoinVar) -> predicate.test(leftJoinVar, extract(a)));
        return (Child_) this;
    }

    /**
     * Adds a constraint that tests the extracted value against the values of two left-join variables.
     * The predicate receives the left-join values first, in order, and the extracted value last.
     *
     * @param predicate the filter to apply
     * @param leftJoinVariableA the variable passed as the predicate's first argument
     * @param leftJoinVariableB the variable passed as the predicate's second argument
     * @return this, cast to {@code Child_}
     */
    @Override
    public final  Child_ filter(
            TriPredicate predicate, Variable leftJoinVariableA,
            Variable leftJoinVariableB) {
        pattern.expr("Filter using " + predicate, leftJoinVariableA, leftJoinVariableB,
                (a, leftJoinVarA, leftJoinVarB) -> predicate.test(leftJoinVarA, leftJoinVarB, extract(a)));
        return (Child_) this;
    }

    /**
     * Adds a constraint that tests the extracted value against the values of three left-join variables.
     * The predicate receives the left-join values first, in order, and the extracted value last.
     *
     * @param predicate the filter to apply
     * @param leftJoinVariableA the variable passed as the predicate's first argument
     * @param leftJoinVariableB the variable passed as the predicate's second argument
     * @param leftJoinVariableC the variable passed as the predicate's third argument
     * @return this, cast to {@code Child_}
     */
    @Override
    public final  Child_ filter(
            QuadPredicate predicate,
            Variable leftJoinVariableA, Variable leftJoinVariableB,
            Variable leftJoinVariableC) {
        pattern.expr("Filter using " + predicate, leftJoinVariableA, leftJoinVariableB, leftJoinVariableC,
                (a, leftJoinVarA, leftJoinVarB, leftJoinVarC) -> predicate.test(leftJoinVarA, leftJoinVarB, leftJoinVarC,
                        extract(a)));
        return (Child_) this;
    }

    /**
     * Adds a join constraint for one mapping of a bi-joiner.
     * <p>
     * A pattern object matches when the following holds:
     * {@code joinerType.matches(leftMapping(leftJoinVar), rightMapping(extract(patternObject)))}.
     * The constraint is registered together with a {@link BetaIndex} built from the same mappings.
     *
     * @param leftJoinVar variable holding the left-hand side of the join
     * @param joiner joiner supplying the left and right mappings
     * @param joinerType comparison applied to the two mapped values
     * @param mappingIndex which of the joiner's mappings to use; also passed to {@code betaIndexedBy}
     * @return this, cast to {@code Child_}
     */
    @Override
    public final  Child_ filterForJoin(Variable leftJoinVar,
            DefaultBiJoiner joiner, JoinerType joinerType, int mappingIndex) {
        Function leftMapping = joiner.getLeftMapping(mappingIndex);
        Function rightMapping = joiner.getRightMapping(mappingIndex);
        // The right side of the join is this pattern's object, so map it through extract() before the joiner mapping.
        Function1 rightExtractor = b -> rightMapping.apply(extract(b));
        // In the predicate, b is the pattern object (right side) and a is the left-join value.
        Predicate2 predicate =
                (b, a) -> joinerType.matches(leftMapping.apply(a), rightExtractor.apply(b));
        BetaIndex index =
                createBetaIndex(joinerType, mappingIndex, leftMapping, rightExtractor);
        pattern.expr("Join using joiner #" + mappingIndex + " in " + joiner, leftJoinVar, predicate, index);
        return (Child_) this;
    }

    /**
     * Builds the beta index for a bi-join mapping.
     * <p>
     * For {@link JoinerType#EQUAL}, values are indexed as {@link Object}. For any other joiner type, they are indexed
     * as {@link Comparable} using {@code joinerType.flip()}.
     * <p>
     * {@code betaIndexedBy} receives the pattern-side extractor first and the left-join mapping second. That is the
     * reverse of the operand order in {@code joinerType.matches(left, right)}, which is presumably why the non-equal
     * comparison is flipped. NOTE(review): confirm against the Drools {@code betaIndexedBy} parameter order.
     */
    private  BetaIndex createBetaIndex(JoinerType joinerType, int mappingIndex,
            Function leftMapping, Function1 rightExtractor) {
        if (joinerType == JoinerType.EQUAL) {
            return betaIndexedBy(Object.class, getConstraintType(joinerType), mappingIndex, rightExtractor, leftMapping::apply,
                    Object.class);
        } else { // Drools beta index on LT/LTE/GT/GTE requires Comparable.
            // Operands are swapped relative to joinerType.matches(left, right); see the method Javadoc.
            JoinerType reversedJoinerType = joinerType.flip();
            // A pattern-side value that is not Comparable would throw ClassCastException at this cast.
            return betaIndexedBy(Comparable.class, getConstraintType(reversedJoinerType), mappingIndex,
                    c -> (Comparable) rightExtractor.apply(c), leftMapping::apply, Comparable.class);
        }
    }

    /**
     * Adds a join constraint for one mapping of a tri-joiner.
     * <p>
     * A pattern object matches when the following holds:
     * {@code joinerType.matches(leftMapping(leftJoinVarA, leftJoinVarB), rightMapping(extract(patternObject)))}.
     * The constraint is registered together with a {@link BetaIndex2} built from the same mappings.
     *
     * @param leftJoinVarA variable holding the first left-hand value of the join
     * @param leftJoinVarB variable holding the second left-hand value of the join
     * @param joiner joiner supplying the left and right mappings
     * @param joinerType comparison applied to the two mapped values
     * @param mappingIndex which of the joiner's mappings to use; also passed to {@code betaIndexedBy}
     * @return this, cast to {@code Child_}
     */
    @Override
    public final  Child_ filterForJoin(Variable leftJoinVarA,
            Variable leftJoinVarB, DefaultTriJoiner joiner,
            JoinerType joinerType, int mappingIndex) {
        BiFunction leftMapping = joiner.getLeftMapping(mappingIndex);
        Function rightMapping = joiner.getRightMapping(mappingIndex);
        Function1 rightExtractor = b -> rightMapping.apply(extract(b));
        // In the predicate, c is the pattern object (right side); a and b are the left-join values.
        Predicate3 predicate =
                (c, a, b) -> joinerType.matches(leftMapping.apply(a, b), rightExtractor.apply(c));
        BetaIndex2 index =
                createBetaIndex(joinerType, mappingIndex, leftMapping, rightExtractor);
        pattern.expr("Join using joiner #" + mappingIndex + " in " + joiner, leftJoinVarA, leftJoinVarB, predicate, index);
        return (Child_) this;
    }

    /**
     * Builds the beta index for a tri-join mapping, where the left mapping takes two left-join values.
     * It has the same structure as the bi-join overload: an {@link Object} index for {@link JoinerType#EQUAL}, and
     * otherwise a {@link Comparable} index using {@code joinerType.flip()}, because the pattern-side extractor is
     * passed first.
     */
    private  BetaIndex2 createBetaIndex(
            JoinerType joinerType,
            int mappingIndex, BiFunction leftMapping,
            Function1 rightExtractor) {
        if (joinerType == JoinerType.EQUAL) {
            return betaIndexedBy(Object.class, getConstraintType(joinerType), mappingIndex, rightExtractor, leftMapping::apply,
                    Object.class);
        } else { // Drools beta index on LT/LTE/GT/GTE requires Comparable.
            JoinerType reversedJoinerType = joinerType.flip();
            return betaIndexedBy(Comparable.class, getConstraintType(reversedJoinerType), mappingIndex,
                    c -> (Comparable) rightExtractor.apply(c), leftMapping::apply, Comparable.class);
        }
    }

    /**
     * Adds a join constraint for one mapping of a quad-joiner.
     * <p>
     * A pattern object matches when the following holds:
     * {@code joinerType.matches(leftMapping(leftJoinVarA, leftJoinVarB, leftJoinVarC),
     * rightMapping(extract(patternObject)))}.
     * The constraint is registered together with a {@link BetaIndex3} built from the same mappings.
     *
     * @param leftJoinVarA variable holding the first left-hand value of the join
     * @param leftJoinVarB variable holding the second left-hand value of the join
     * @param leftJoinVarC variable holding the third left-hand value of the join
     * @param joiner joiner supplying the left and right mappings
     * @param joinerType comparison applied to the two mapped values
     * @param mappingIndex which of the joiner's mappings to use; also passed to {@code betaIndexedBy}
     * @return this, cast to {@code Child_}
     */
    @Override
    public final  Child_ filterForJoin(Variable leftJoinVarA,
            Variable leftJoinVarB, Variable leftJoinVarC,
            DefaultQuadJoiner joiner, JoinerType joinerType,
            int mappingIndex) {
        TriFunction leftMapping =
                joiner.getLeftMapping(mappingIndex);
        Function rightMapping = joiner.getRightMapping(mappingIndex);
        Function1 rightExtractor = b -> rightMapping.apply(extract(b));
        // In the predicate, d is the pattern object (right side); a, b and c are the left-join values.
        Predicate4 predicate =
                (d, a, b, c) -> joinerType.matches(leftMapping.apply(a, b, c), rightExtractor.apply(d));
        BetaIndex3 index =
                createBetaIndex(joinerType, mappingIndex, leftMapping, rightExtractor);
        pattern.expr("Join using joiner #" + mappingIndex + " in " + joiner, leftJoinVarA, leftJoinVarB,
                leftJoinVarC, predicate, index);
        return (Child_) this;
    }

    /**
     * Builds the beta index for a quad-join mapping, where the left mapping takes three left-join values.
     * It has the same structure as the bi-join overload: an {@link Object} index for {@link JoinerType#EQUAL}, and
     * otherwise a {@link Comparable} index using {@code joinerType.flip()}, because the pattern-side extractor is
     * passed first.
     */
    private 
            BetaIndex3 createBetaIndex(JoinerType joinerType,
                    int mappingIndex, TriFunction leftMapping,
                    Function1 rightExtractor) {
        if (joinerType == JoinerType.EQUAL) {
            return betaIndexedBy(Object.class, getConstraintType(joinerType), mappingIndex, rightExtractor, leftMapping::apply,
                    Object.class);
        } else { // Drools beta index on LT/LTE/GT/GTE requires Comparable.
            JoinerType reversedJoinerType = joinerType.flip();
            return betaIndexedBy(Comparable.class, getConstraintType(reversedJoinerType), mappingIndex,
                    c -> (Comparable) rightExtractor.apply(c), leftMapping::apply, Comparable.class);
        }
    }

    /**
     * Binds {@code boundVariable} on the pattern to {@code bindingFunction} applied to the extracted value.
     *
     * @param boundVariable the variable to bind
     * @param bindingFunction computes the bound value from the value returned by {@link #extract}
     * @return this, cast to {@code Child_}
     */
    @Override
    public final  Child_ bind(Variable boundVariable, Function bindingFunction) {
        pattern.bind(boundVariable, a -> bindingFunction.apply(extract(a)));
        return (Child_) this;
    }

    /**
     * Binds {@code boundVariable} to a value computed from the extracted value and one left-join value.
     * Unlike {@code filter}, the binding function receives the extracted value first and the left-join value second.
     *
     * @param boundVariable the variable to bind
     * @param leftJoinVariable the variable whose value is passed as the function's second argument
     * @param bindingFunction computes the bound value
     * @return this, cast to {@code Child_}
     */
    @Override
    public final  Child_ bind(Variable boundVariable,
            Variable leftJoinVariable, BiFunction bindingFunction) {
        pattern.bind(boundVariable, leftJoinVariable,
                (a, leftJoinVar) -> bindingFunction.apply(extract(a), leftJoinVar));
        return (Child_) this;
    }

    /**
     * Binds {@code boundVariable} to a value computed from the extracted value and two left-join values.
     * The binding function receives the extracted value first, followed by the left-join values in order.
     *
     * @param boundVariable the variable to bind
     * @param leftJoinVariableA the variable passed as the function's second argument
     * @param leftJoinVariableB the variable passed as the function's third argument
     * @param bindingFunction computes the bound value
     * @return this, cast to {@code Child_}
     */
    @Override
    public final  Child_ bind(Variable boundVariable,
            Variable leftJoinVariableA, Variable leftJoinVariableB,
            TriFunction bindingFunction) {
        pattern.bind(boundVariable, leftJoinVariableA, leftJoinVariableB,
                (a, leftJoinVarA, leftJoinVarB) -> bindingFunction.apply(extract(a), leftJoinVarA, leftJoinVarB));
        return (Child_) this;
    }

    /**
     * Binds {@code boundVariable} to a value computed from the extracted value and three left-join values.
     * The binding function receives the extracted value first, followed by the left-join values in order.
     *
     * @param boundVariable the variable to bind
     * @param leftJoinVariableA the variable passed as the function's second argument
     * @param leftJoinVariableB the variable passed as the function's third argument
     * @param leftJoinVariableC the variable passed as the function's fourth argument
     * @param bindingFunction computes the bound value
     * @return this, cast to {@code Child_}
     */
    @Override
    public final  Child_ bind(Variable boundVariable,
            Variable leftJoinVariableA, Variable leftJoinVariableB,
            Variable leftJoinVariableC,
            QuadFunction bindingFunction) {
        pattern.bind(boundVariable, leftJoinVariableA, leftJoinVariableB, leftJoinVariableC,
                (a, leftJoinVarA, leftJoinVarB, leftJoinVarC) -> bindingFunction.apply(extract(a), leftJoinVarA,
                        leftJoinVarB, leftJoinVarC));
        return (Child_) this;
    }

    /**
     * Assembles the view items for this variable into a new list, in this order: the prerequisite expressions, then
     * the pattern, then the dependent expressions.
     *
     * @return a newly collected list; the internal lists are not modified
     */
    @Override
    public final List> build() {
        Stream> prerequisites = prerequisiteExpressions.stream();
        Stream> dependents = dependentExpressions.stream();
        return Stream.concat(Stream.concat(prerequisites, Stream.of(pattern)), dependents)
                .collect(Collectors.toList());
    }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy