All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.optaplanner.constraint.streams.drools.common.BiLeftHandSide Maven / Gradle / Ivy

/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package org.optaplanner.constraint.streams.drools.common;

import static java.util.Collections.singletonList;
import static org.drools.model.DSL.exists;
import static org.drools.model.DSL.not;
import static org.drools.model.PatternDSL.betaIndexedBy;
import static org.drools.model.PatternDSL.pattern;

import java.math.BigDecimal;
import java.util.List;
import java.util.Objects;
import java.util.function.BiFunction;
import java.util.function.BiPredicate;
import java.util.function.Function;
import java.util.function.IntFunction;
import java.util.function.Predicate;
import java.util.function.ToIntBiFunction;
import java.util.function.ToLongBiFunction;
import java.util.stream.Stream;

import org.drools.model.BetaIndex2;
import org.drools.model.DSL;
import org.drools.model.PatternDSL;
import org.drools.model.Variable;
import org.drools.model.functions.Function2;
import org.drools.model.functions.Predicate3;
import org.drools.model.functions.accumulate.AccumulateFunction;
import org.drools.model.view.ViewItem;
import org.optaplanner.constraint.streams.common.tri.DefaultTriJoiner;
import org.optaplanner.constraint.streams.common.tri.FilteringTriJoiner;
import org.optaplanner.constraint.streams.drools.DroolsVariableFactory;
import org.optaplanner.core.api.function.TriPredicate;
import org.optaplanner.core.api.score.stream.bi.BiConstraintCollector;
import org.optaplanner.core.api.score.stream.tri.TriJoiner;
import org.optaplanner.core.impl.score.stream.JoinerType;

/**
 * Represents the left hand side of a Drools rule, the result of which are two variables.
 * The simplest variant of such rule, with no filters or groupBys applied, would look like this in equivalent DRL:
 *
 * 
 * {@code
 *  rule "Simplest bivariate rule"
 *  when
 *      $a: Something()
 *      $b: SomethingElse()
 *  then
 *      // Do something with the $a and $b variables.
 *  end
 * }
 * 
* * Usually though, there would be a joiner between the two, limiting the cartesian product: * *
 * {@code
 *  rule "Bivariate join rule"
 *  when
 *      $a: Something($leftJoin: someValue)
 *      $b: SomethingElse(someOtherValue == $leftJoin)
 *  then
 *      // Do something with the $a and $b variables.
 *  end
 * }
 * 
* * For more, see {@link UniLeftHandSide}. * * @param generic type of the first resulting variable * @param generic type of the second resulting variable */ public final class BiLeftHandSide extends AbstractLeftHandSide { private final PatternVariable patternVariableA; private final PatternVariable patternVariableB; private final BiRuleContext ruleContext; BiLeftHandSide(Variable left, PatternVariable right, DroolsVariableFactory variableFactory) { this(new DetachedPatternVariable<>(left), right, variableFactory); } BiLeftHandSide(PatternVariable left, PatternVariable right, DroolsVariableFactory variableFactory) { super(variableFactory); this.patternVariableA = Objects.requireNonNull(left); this.patternVariableB = Objects.requireNonNull(right); this.ruleContext = buildRuleContext(); } private BiRuleContext buildRuleContext() { ViewItem[] viewItems = Stream.of(patternVariableA, patternVariableB) .flatMap(variable -> variable.build().stream()) .toArray((IntFunction[]>) ViewItem[]::new); return new BiRuleContext<>(patternVariableA.getPrimaryVariable(), patternVariableB.getPrimaryVariable(), viewItems); } public BiLeftHandSide andFilter(BiPredicate predicate) { return new BiLeftHandSide<>(patternVariableA, patternVariableB.filter(predicate, patternVariableA.getPrimaryVariable()), variableFactory); } private BiLeftHandSide applyJoiners(Class otherFactType, Predicate nullityFilter, DefaultTriJoiner joiner, TriPredicate predicate, boolean shouldExist) { Variable toExist = variableFactory.createVariable(otherFactType, "toExist"); PatternDSL.PatternDef existencePattern = pattern(toExist); if (nullityFilter != null) { existencePattern = existencePattern.expr("Exclude nulls using " + nullityFilter, nullityFilter::test); } if (joiner == null) { return applyFilters(existencePattern, predicate, shouldExist); } int joinerCount = joiner.getJoinerCount(); for (int mappingIndex = 0; mappingIndex < joinerCount; mappingIndex++) { JoinerType joinerType = joiner.getJoinerType(mappingIndex); BiFunction leftMapping = joiner.getLeftMapping(mappingIndex); Function rightMapping = joiner.getRightMapping(mappingIndex); Predicate3 joinPredicate = (c, a, b) -> joinerType.matches(leftMapping.apply(a, b), rightMapping.apply(c)); existencePattern = existencePattern.expr("Join using joiner #" + mappingIndex + " in " + joiner, patternVariableA.getPrimaryVariable(), patternVariableB.getPrimaryVariable(), joinPredicate, createBetaIndex(joiner, mappingIndex)); } return applyFilters(existencePattern, predicate, shouldExist); } private BetaIndex2 createBetaIndex(DefaultTriJoiner joiner, int mappingIndex) { JoinerType joinerType = joiner.getJoinerType(mappingIndex); BiFunction leftMapping = joiner.getLeftMapping(mappingIndex); Function rightMapping = joiner.getRightMapping(mappingIndex); if (joinerType == JoinerType.EQUAL) { return betaIndexedBy(Object.class, getConstraintType(joinerType), mappingIndex, rightMapping::apply, leftMapping::apply, Object.class); } else { // Drools beta index on LT/LTE/GT/GTE requires Comparable. JoinerType reversedJoinerType = joinerType.flip(); return betaIndexedBy(Comparable.class, getConstraintType(reversedJoinerType), mappingIndex, c -> (Comparable) rightMapping.apply(c), leftMapping::apply, Comparable.class); } } private BiLeftHandSide applyFilters(PatternDSL.PatternDef existencePattern, TriPredicate predicate, boolean shouldExist) { PatternDSL.PatternDef possiblyFilteredExistencePattern = predicate == null ? existencePattern : existencePattern.expr("Filter using " + predicate, patternVariableA.getPrimaryVariable(), patternVariableB.getPrimaryVariable(), (c, a, b) -> predicate.test(a, b, c)); ViewItem existenceExpression = exists(possiblyFilteredExistencePattern); if (!shouldExist) { existenceExpression = not(possiblyFilteredExistencePattern); } return new BiLeftHandSide<>(patternVariableA, patternVariableB.addDependentExpression(existenceExpression), variableFactory); } private BiLeftHandSide existsOrNot(Class cClass, TriJoiner[] joiners, Predicate nullityFilter, boolean shouldExist) { int indexOfFirstFilter = -1; // Prepare the joiner and filter that will be used in the pattern DefaultTriJoiner finalJoiner = null; TriPredicate finalFilter = null; for (int i = 0; i < joiners.length; i++) { TriJoiner joiner = joiners[i]; boolean hasAFilter = indexOfFirstFilter >= 0; if (joiner instanceof FilteringTriJoiner) { if (!hasAFilter) { // From now on, we only allow filtering joiners. indexOfFirstFilter = i; } // Merge all filters into one to avoid paying the penalty for lack of indexing more than once. FilteringTriJoiner castJoiner = (FilteringTriJoiner) joiner; finalFilter = finalFilter == null ? castJoiner.getFilter() : finalFilter.and(castJoiner.getFilter()); } else { if (hasAFilter) { throw new IllegalStateException("Indexing joiner (" + joiner + ") must not follow a filtering joiner (" + joiners[indexOfFirstFilter] + ")."); } else { // Merge this Joiner with the existing Joiners. DefaultTriJoiner castJoiner = (DefaultTriJoiner) joiner; finalJoiner = finalJoiner == null ? castJoiner : finalJoiner.and(castJoiner); } } } return applyJoiners(cClass, nullityFilter, finalJoiner, finalFilter, shouldExist); } public BiLeftHandSide andExists(Class cClass, TriJoiner[] joiners, Predicate nullityFilter) { return existsOrNot(cClass, joiners, nullityFilter, true); } public BiLeftHandSide andNotExists(Class cClass, TriJoiner[] joiners, Predicate nullityFilter) { return existsOrNot(cClass, joiners, nullityFilter, false); } public TriLeftHandSide andJoin(UniLeftHandSide right, TriJoiner joiner) { DefaultTriJoiner castJoiner = (DefaultTriJoiner) joiner; PatternVariable newRight = right.getPatternVariableA(); int joinerCount = castJoiner.getJoinerCount(); for (int mappingIndex = 0; mappingIndex < joinerCount; mappingIndex++) { JoinerType joinerType = castJoiner.getJoinerType(mappingIndex); newRight = newRight.filterForJoin(patternVariableA.getPrimaryVariable(), patternVariableB.getPrimaryVariable(), castJoiner, joinerType, mappingIndex); } return new TriLeftHandSide<>(patternVariableA, patternVariableB, newRight, variableFactory); } public UniLeftHandSide andGroupBy(BiConstraintCollector collector) { Variable accumulateOutput = variableFactory.createVariable("collected"); ViewItem innerAccumulatePattern = joinViewItemsWithLogicalAnd(patternVariableA, patternVariableB); ViewItem outerAccumulatePattern = buildAccumulate(innerAccumulatePattern, createAccumulateFunction(collector, accumulateOutput)); return new UniLeftHandSide<>(accumulateOutput, singletonList(outerAccumulatePattern), variableFactory); } public BiLeftHandSide andGroupBy(BiConstraintCollector collectorA, BiConstraintCollector collectorB) { Variable accumulateOutputA = variableFactory.createVariable("collectedA"); Variable accumulateOutputB = variableFactory.createVariable("collectedB"); ViewItem innerAccumulatePattern = joinViewItemsWithLogicalAnd(patternVariableA, patternVariableB); ViewItem outerAccumulatePattern = buildAccumulate(innerAccumulatePattern, createAccumulateFunction(collectorA, accumulateOutputA), createAccumulateFunction(collectorB, accumulateOutputB)); return new BiLeftHandSide<>(accumulateOutputA, new DirectPatternVariable<>(accumulateOutputB, outerAccumulatePattern), variableFactory); } public TriLeftHandSide andGroupBy( BiConstraintCollector collectorA, BiConstraintCollector collectorB, BiConstraintCollector collectorC) { Variable accumulateOutputA = variableFactory.createVariable("collectedA"); Variable accumulateOutputB = variableFactory.createVariable("collectedB"); Variable accumulateOutputC = variableFactory.createVariable("collectedC"); ViewItem innerAccumulatePattern = joinViewItemsWithLogicalAnd(patternVariableA, patternVariableB); ViewItem outerAccumulatePattern = buildAccumulate(innerAccumulatePattern, createAccumulateFunction(collectorA, accumulateOutputA), createAccumulateFunction(collectorB, accumulateOutputB), createAccumulateFunction(collectorC, accumulateOutputC)); return new TriLeftHandSide<>(accumulateOutputA, accumulateOutputB, new DirectPatternVariable<>(accumulateOutputC, outerAccumulatePattern), variableFactory); } public QuadLeftHandSide andGroupBy( BiConstraintCollector collectorA, BiConstraintCollector collectorB, BiConstraintCollector collectorC, BiConstraintCollector collectorD) { Variable accumulateOutputA = variableFactory.createVariable("collectedA"); Variable accumulateOutputB = variableFactory.createVariable("collectedB"); Variable accumulateOutputC = variableFactory.createVariable("collectedC"); Variable accumulateOutputD = variableFactory.createVariable("collectedD"); ViewItem innerAccumulatePattern = joinViewItemsWithLogicalAnd(patternVariableA, patternVariableB); ViewItem outerAccumulatePattern = buildAccumulate(innerAccumulatePattern, createAccumulateFunction(collectorA, accumulateOutputA), createAccumulateFunction(collectorB, accumulateOutputB), createAccumulateFunction(collectorC, accumulateOutputC), createAccumulateFunction(collectorD, accumulateOutputD)); return new QuadLeftHandSide<>(accumulateOutputA, accumulateOutputB, accumulateOutputC, new DirectPatternVariable<>(accumulateOutputD, outerAccumulatePattern), variableFactory); } /** * Creates a Drools accumulate function based on a given collector. The accumulate function will take the pattern * variables as input and return its result into another {@link Variable}. * * @param type of the accumulate result * @param collector collector to use in the accumulate function * @param out variable in which to store accumulate result * @return Drools accumulate function */ private AccumulateFunction createAccumulateFunction(BiConstraintCollector collector, Variable out) { Variable variableA = patternVariableA.getPrimaryVariable(); Variable variableB = patternVariableB.getPrimaryVariable(); return new AccumulateFunction(null, () -> new BiAccumulator<>(variableA, variableB, collector)) .with(variableA, variableB) .as(out); } public UniLeftHandSide andGroupBy(BiFunction keyMapping) { Variable groupKey = variableFactory.createVariable("groupKey"); ViewItem groupByPattern = buildGroupBy(groupKey, keyMapping::apply); return new UniLeftHandSide<>(groupKey, singletonList(groupByPattern), variableFactory); } public BiLeftHandSide andGroupBy(BiFunction keyMappingA, BiConstraintCollector collectorB) { Variable groupKey = variableFactory.createVariable("groupKey"); Variable accumulateOutput = variableFactory.createVariable("output"); ViewItem groupByPattern = buildGroupBy(groupKey, keyMappingA::apply, createAccumulateFunction(collectorB, accumulateOutput)); return new BiLeftHandSide<>(groupKey, new DirectPatternVariable<>(accumulateOutput, groupByPattern), variableFactory); } public TriLeftHandSide andGroupBy(BiFunction keyMappingA, BiConstraintCollector collectorB, BiConstraintCollector collectorC) { Variable groupKey = variableFactory.createVariable("groupKey"); Variable accumulateOutputB = variableFactory.createVariable("outputB"); Variable accumulateOutputC = variableFactory.createVariable("outputC"); ViewItem groupByPattern = buildGroupBy(groupKey, keyMappingA::apply, createAccumulateFunction(collectorB, accumulateOutputB), createAccumulateFunction(collectorC, accumulateOutputC)); return new TriLeftHandSide<>(groupKey, accumulateOutputB, new DirectPatternVariable<>(accumulateOutputC, groupByPattern), variableFactory); } public QuadLeftHandSide andGroupBy( BiFunction keyMappingA, BiConstraintCollector collectorB, BiConstraintCollector collectorC, BiConstraintCollector collectorD) { Variable groupKey = variableFactory.createVariable("groupKey"); Variable accumulateOutputB = variableFactory.createVariable("outputB"); Variable accumulateOutputC = variableFactory.createVariable("outputC"); Variable accumulateOutputD = variableFactory.createVariable("outputD"); ViewItem groupByPattern = buildGroupBy(groupKey, keyMappingA::apply, createAccumulateFunction(collectorB, accumulateOutputB), createAccumulateFunction(collectorC, accumulateOutputC), createAccumulateFunction(collectorD, accumulateOutputD)); return new QuadLeftHandSide<>(groupKey, accumulateOutputB, accumulateOutputC, new DirectPatternVariable<>(accumulateOutputD, groupByPattern), variableFactory); } /** * Takes group key mappings and merges them in such a way that the result is a single composite key. * This is necessary because Drools groupBy can only take a single key - therefore multiple variables need to be * converted into a singular composite variable. * * @param keyMappingA mapping for the first variable * @param keyMappingB mapping for the second variable * @param generic type of the first variable * @param generic type of the second variable * @return never null, Drools function to convert the keys to a singular composite key */ private Function2> createCompositeBiGroupKey( BiFunction keyMappingA, BiFunction keyMappingB) { return (a, b) -> new BiTuple<>(keyMappingA.apply(a, b), keyMappingB.apply(a, b)); } public BiLeftHandSide andGroupBy(BiFunction keyMappingA, BiFunction keyMappingB) { Variable> groupKey = variableFactory.createVariable(BiTuple.class, "groupKey"); ViewItem groupByPattern = buildGroupBy(groupKey, createCompositeBiGroupKey(keyMappingA, keyMappingB)); Variable newA = variableFactory.createVariable("newA"); Variable newB = variableFactory.createVariable("newB"); IndirectPatternVariable> bPatternVar = decompose(groupKey, groupByPattern, newA, newB); return new BiLeftHandSide<>(newA, bPatternVar, variableFactory); } public TriLeftHandSide andGroupBy(BiFunction keyMappingA, BiFunction keyMappingB, BiConstraintCollector collectorC) { Variable> groupKey = variableFactory.createVariable(BiTuple.class, "groupKey"); Variable accumulateOutput = variableFactory.createVariable("output"); ViewItem groupByPattern = buildGroupBy(groupKey, createCompositeBiGroupKey(keyMappingA, keyMappingB), createAccumulateFunction(collectorC, accumulateOutput)); Variable newA = variableFactory.createVariable("newA"); Variable newB = variableFactory.createVariable("newB"); DirectPatternVariable cPatternVar = decomposeWithAccumulate(groupKey, groupByPattern, newA, newB, accumulateOutput); return new TriLeftHandSide<>(newA, newB, cPatternVar, variableFactory); } public QuadLeftHandSide andGroupBy(BiFunction keyMappingA, BiFunction keyMappingB, BiConstraintCollector collectorC, BiConstraintCollector collectorD) { Variable> groupKey = variableFactory.createVariable(BiTuple.class, "groupKey"); Variable accumulateOutputC = variableFactory.createVariable("outputC"); Variable accumulateOutputD = variableFactory.createVariable("outputD"); ViewItem groupByPattern = buildGroupBy(groupKey, createCompositeBiGroupKey(keyMappingA, keyMappingB), createAccumulateFunction(collectorC, accumulateOutputC), createAccumulateFunction(collectorD, accumulateOutputD)); Variable newA = variableFactory.createVariable("newA"); Variable newB = variableFactory.createVariable("newB"); DirectPatternVariable dPatternVar = decomposeWithAccumulate(groupKey, groupByPattern, newA, newB, accumulateOutputD); return new QuadLeftHandSide<>(newA, newB, accumulateOutputC, dPatternVar, variableFactory); } /** * Takes group key mappings and merges them in such a way that the result is a single composite key. * This is necessary because Drools groupBy can only take a single key - therefore multiple variables need to be * converted into a singular composite variable. * * @param keyMappingA mapping for the first variable * @param keyMappingB mapping for the second variable * @param keyMappingC mapping for the third variable * @param generic type of the first variable * @param generic type of the second variable * @param generic type of the third variable * @return never null, Drools function to convert the keys to a singular composite key */ private Function2> createCompositeTriGroupKey( BiFunction keyMappingA, BiFunction keyMappingB, BiFunction keyMappingC) { return (a, b) -> new TriTuple<>(keyMappingA.apply(a, b), keyMappingB.apply(a, b), keyMappingC.apply(a, b)); } public TriLeftHandSide andGroupBy(BiFunction keyMappingA, BiFunction keyMappingB, BiFunction keyMappingC) { Variable> groupKey = variableFactory.createVariable(TriTuple.class, "groupKey"); ViewItem groupByPattern = buildGroupBy(groupKey, createCompositeTriGroupKey(keyMappingA, keyMappingB, keyMappingC)); Variable newA = variableFactory.createVariable("newA"); Variable newB = variableFactory.createVariable("newB"); Variable newC = variableFactory.createVariable("newC"); IndirectPatternVariable> cPatternVar = decompose(groupKey, groupByPattern, newA, newB, newC); return new TriLeftHandSide<>(newA, newB, cPatternVar, variableFactory); } public QuadLeftHandSide andGroupBy( BiFunction keyMappingA, BiFunction keyMappingB, BiFunction keyMappingC, BiConstraintCollector collectorD) { Variable> groupKey = variableFactory.createVariable(TriTuple.class, "groupKey"); Variable accumulateOutputD = variableFactory.createVariable("outputD"); ViewItem groupByPattern = buildGroupBy(groupKey, createCompositeTriGroupKey(keyMappingA, keyMappingB, keyMappingC), createAccumulateFunction(collectorD, accumulateOutputD)); Variable newA = variableFactory.createVariable("newA"); Variable newB = variableFactory.createVariable("newB"); Variable newC = variableFactory.createVariable("newC"); DirectPatternVariable dPatternVar = decomposeWithAccumulate(groupKey, groupByPattern, newA, newB, newC, accumulateOutputD); return new QuadLeftHandSide<>(newA, newB, newC, dPatternVar, variableFactory); } /** * Takes group key mappings and merges them in such a way that the result is a single composite key. * This is necessary because Drools groupBy can only take a single key - therefore multiple variables need to be * converted into a singular composite variable. * * @param keyMappingA mapping for the first variable * @param keyMappingB mapping for the second variable * @param keyMappingC mapping for the third variable * @param generic type of the first variable * @param generic type of the second variable * @param generic type of the third variable * @return never null, Drools function to convert the keys to a singular composite key */ private Function2> createCompositeQuadGroupKey(BiFunction keyMappingA, BiFunction keyMappingB, BiFunction keyMappingC, BiFunction keyMappingD) { return (a, b) -> new QuadTuple<>(keyMappingA.apply(a, b), keyMappingB.apply(a, b), keyMappingC.apply(a, b), keyMappingD.apply(a, b)); } public QuadLeftHandSide andGroupBy( BiFunction keyMappingA, BiFunction keyMappingB, BiFunction keyMappingC, BiFunction keyMappingD) { Variable> groupKey = variableFactory.createVariable(QuadTuple.class, "groupKey"); ViewItem groupByPattern = buildGroupBy(groupKey, createCompositeQuadGroupKey(keyMappingA, keyMappingB, keyMappingC, keyMappingD)); Variable newA = variableFactory.createVariable("newA"); Variable newB = variableFactory.createVariable("newB"); Variable newC = variableFactory.createVariable("newC"); Variable newD = variableFactory.createVariable("newD"); IndirectPatternVariable> dPatternVar = decompose(groupKey, groupByPattern, newA, newB, newC, newD); return new QuadLeftHandSide<>(newA, newB, newC, dPatternVar, variableFactory); } public UniLeftHandSide andMap(BiFunction mapping) { Variable newA = variableFactory.createVariable("mapped", patternVariableA.getPrimaryVariable(), patternVariableB.getPrimaryVariable(), mapping); List> allPrerequisites = mergeViewItems(patternVariableA, patternVariableB); DirectPatternVariable newPatternVariableA = new DirectPatternVariable<>(newA, allPrerequisites); return new UniLeftHandSide<>(newPatternVariableA, variableFactory); } public BiLeftHandSide andFlattenLast(Function> mapping) { Variable source = patternVariableB.getPrimaryVariable(); Variable newB = variableFactory.createFlattenedVariable("flattened", source, mapping); List> allPrerequisites = mergeViewItems(patternVariableA, patternVariableB); PatternVariable newPatternVariableB = new DirectPatternVariable<>(newB, allPrerequisites); return new BiLeftHandSide<>(patternVariableA.getPrimaryVariable(), newPatternVariableB, variableFactory); } public RuleBuilder andTerminate(ToIntBiFunction matchWeigher) { return ruleContext.newRuleBuilder(matchWeigher); } public RuleBuilder andTerminate(ToLongBiFunction matchWeigher) { return ruleContext.newRuleBuilder(matchWeigher); } public RuleBuilder andTerminate(BiFunction matchWeigher) { return ruleContext.newRuleBuilder(matchWeigher); } private ViewItem buildGroupBy(Variable groupKey, Function2 groupKeyExtractor, AccumulateFunction... accFunctions) { Variable inputA = patternVariableA.getPrimaryVariable(); Variable inputB = patternVariableB.getPrimaryVariable(); ViewItem innerGroupByPattern = joinViewItemsWithLogicalAnd(patternVariableA, patternVariableB); return DSL.groupBy(innerGroupByPattern, inputA, inputB, groupKey, groupKeyExtractor, accFunctions); } }