
// org.optaplanner.constraint.streams.drools.common.AbstractPatternVariable (Maven / Gradle / Ivy)
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.optaplanner.constraint.streams.drools.common;
import static org.drools.model.PatternDSL.betaIndexedBy;
import static org.optaplanner.constraint.streams.drools.common.AbstractLeftHandSide.getConstraintType;
import java.util.List;
import java.util.function.BiFunction;
import java.util.function.BiPredicate;
import java.util.function.Function;
import java.util.function.Predicate;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.drools.model.BetaIndex;
import org.drools.model.BetaIndex2;
import org.drools.model.BetaIndex3;
import org.drools.model.PatternDSL;
import org.drools.model.Variable;
import org.drools.model.functions.Function1;
import org.drools.model.functions.Predicate2;
import org.drools.model.functions.Predicate3;
import org.drools.model.functions.Predicate4;
import org.drools.model.view.ViewItem;
import org.optaplanner.constraint.streams.common.bi.DefaultBiJoiner;
import org.optaplanner.constraint.streams.common.quad.DefaultQuadJoiner;
import org.optaplanner.constraint.streams.common.tri.DefaultTriJoiner;
import org.optaplanner.core.api.function.QuadFunction;
import org.optaplanner.core.api.function.QuadPredicate;
import org.optaplanner.core.api.function.TriFunction;
import org.optaplanner.core.api.function.TriPredicate;
import org.optaplanner.core.impl.score.stream.JoinerType;
abstract class AbstractPatternVariable>
implements PatternVariable {
private final Variable primaryVariable;
private final PatternDSL.PatternDef pattern;
private final List> prerequisiteExpressions;
private final List> dependentExpressions;
protected AbstractPatternVariable(Variable aVariable, PatternDSL.PatternDef pattern,
List> prerequisiteExpressions, List> dependentExpressions) {
this.primaryVariable = aVariable;
this.pattern = pattern;
this.prerequisiteExpressions = prerequisiteExpressions;
this.dependentExpressions = dependentExpressions;
}
protected AbstractPatternVariable(AbstractPatternVariable, PatternVar_, ?> patternCreator, Variable boundVariable) {
this.primaryVariable = boundVariable;
this.pattern = patternCreator.getPattern();
this.prerequisiteExpressions = patternCreator.getPrerequisiteExpressions();
this.dependentExpressions = patternCreator.getDependentExpressions();
}
protected AbstractPatternVariable(AbstractPatternVariable patternCreator,
ViewItem> dependentExpression) {
this.primaryVariable = patternCreator.primaryVariable;
this.pattern = patternCreator.pattern;
this.prerequisiteExpressions = patternCreator.prerequisiteExpressions;
this.dependentExpressions = Stream.concat(patternCreator.dependentExpressions.stream(), Stream.of(dependentExpression))
.collect(Collectors.toList());
}
@Override
public Variable getPrimaryVariable() {
return primaryVariable;
}
public PatternDSL.PatternDef getPattern() {
return pattern;
}
@Override
public List> getPrerequisiteExpressions() {
return prerequisiteExpressions;
}
@Override
public List> getDependentExpressions() {
return dependentExpressions;
}
/**
* Variable values can be either read directly from the pattern variable (see {@link DirectPatternVariable}
* or indirectly by applying a mapping function to it (see {@link IndirectPatternVariable}.
* This method abstracts this behavior, so that the surrounding code may be shared between both implementations.
*
* @param patternVar never null, pattern variable to extract the value from
* @return value of the variable
*/
protected abstract A extract(PatternVar_ patternVar);
@Override
public final Child_ filter(Predicate predicate) {
pattern.expr("Filter using " + predicate, a -> predicate.test(extract(a)));
return (Child_) this;
}
@Override
public final Child_ filter(BiPredicate predicate,
Variable leftJoinVariable) {
pattern.expr("Filter using " + predicate, leftJoinVariable,
(a, leftJoinVar) -> predicate.test(leftJoinVar, extract(a)));
return (Child_) this;
}
@Override
public final Child_ filter(
TriPredicate predicate, Variable leftJoinVariableA,
Variable leftJoinVariableB) {
pattern.expr("Filter using " + predicate, leftJoinVariableA, leftJoinVariableB,
(a, leftJoinVarA, leftJoinVarB) -> predicate.test(leftJoinVarA, leftJoinVarB, extract(a)));
return (Child_) this;
}
@Override
public final Child_ filter(
QuadPredicate predicate,
Variable leftJoinVariableA, Variable leftJoinVariableB,
Variable leftJoinVariableC) {
pattern.expr("Filter using " + predicate, leftJoinVariableA, leftJoinVariableB, leftJoinVariableC,
(a, leftJoinVarA, leftJoinVarB, leftJoinVarC) -> predicate.test(leftJoinVarA, leftJoinVarB, leftJoinVarC,
extract(a)));
return (Child_) this;
}
@Override
public final Child_ filterForJoin(Variable leftJoinVar,
DefaultBiJoiner joiner, JoinerType joinerType, int mappingIndex) {
Function leftMapping = joiner.getLeftMapping(mappingIndex);
Function rightMapping = joiner.getRightMapping(mappingIndex);
Function1 rightExtractor = b -> rightMapping.apply(extract(b));
Predicate2 predicate =
(b, a) -> joinerType.matches(leftMapping.apply(a), rightExtractor.apply(b));
BetaIndex index =
createBetaIndex(joinerType, mappingIndex, leftMapping, rightExtractor);
pattern.expr("Join using joiner #" + mappingIndex + " in " + joiner, leftJoinVar, predicate, index);
return (Child_) this;
}
private BetaIndex createBetaIndex(JoinerType joinerType, int mappingIndex,
Function leftMapping, Function1 rightExtractor) {
if (joinerType == JoinerType.EQUAL) {
return betaIndexedBy(Object.class, getConstraintType(joinerType), mappingIndex, rightExtractor, leftMapping::apply,
Object.class);
} else { // Drools beta index on LT/LTE/GT/GTE requires Comparable.
JoinerType reversedJoinerType = joinerType.flip();
return betaIndexedBy(Comparable.class, getConstraintType(reversedJoinerType), mappingIndex,
c -> (Comparable) rightExtractor.apply(c), leftMapping::apply, Comparable.class);
}
}
@Override
public final Child_ filterForJoin(Variable leftJoinVarA,
Variable leftJoinVarB, DefaultTriJoiner joiner,
JoinerType joinerType, int mappingIndex) {
BiFunction leftMapping = joiner.getLeftMapping(mappingIndex);
Function rightMapping = joiner.getRightMapping(mappingIndex);
Function1 rightExtractor = b -> rightMapping.apply(extract(b));
Predicate3 predicate =
(c, a, b) -> joinerType.matches(leftMapping.apply(a, b), rightExtractor.apply(c));
BetaIndex2 index =
createBetaIndex(joinerType, mappingIndex, leftMapping, rightExtractor);
pattern.expr("Join using joiner #" + mappingIndex + " in " + joiner, leftJoinVarA, leftJoinVarB, predicate, index);
return (Child_) this;
}
private BetaIndex2 createBetaIndex(
JoinerType joinerType,
int mappingIndex, BiFunction leftMapping,
Function1 rightExtractor) {
if (joinerType == JoinerType.EQUAL) {
return betaIndexedBy(Object.class, getConstraintType(joinerType), mappingIndex, rightExtractor, leftMapping::apply,
Object.class);
} else { // Drools beta index on LT/LTE/GT/GTE requires Comparable.
JoinerType reversedJoinerType = joinerType.flip();
return betaIndexedBy(Comparable.class, getConstraintType(reversedJoinerType), mappingIndex,
c -> (Comparable) rightExtractor.apply(c), leftMapping::apply, Comparable.class);
}
}
@Override
public final Child_ filterForJoin(Variable leftJoinVarA,
Variable leftJoinVarB, Variable leftJoinVarC,
DefaultQuadJoiner joiner, JoinerType joinerType,
int mappingIndex) {
TriFunction leftMapping =
joiner.getLeftMapping(mappingIndex);
Function rightMapping = joiner.getRightMapping(mappingIndex);
Function1 rightExtractor = b -> rightMapping.apply(extract(b));
Predicate4 predicate =
(d, a, b, c) -> joinerType.matches(leftMapping.apply(a, b, c), rightExtractor.apply(d));
BetaIndex3 index =
createBetaIndex(joinerType, mappingIndex, leftMapping, rightExtractor);
pattern.expr("Join using joiner #" + mappingIndex + " in " + joiner, leftJoinVarA, leftJoinVarB,
leftJoinVarC, predicate, index);
return (Child_) this;
}
private
BetaIndex3 createBetaIndex(JoinerType joinerType,
int mappingIndex, TriFunction leftMapping,
Function1 rightExtractor) {
if (joinerType == JoinerType.EQUAL) {
return betaIndexedBy(Object.class, getConstraintType(joinerType), mappingIndex, rightExtractor, leftMapping::apply,
Object.class);
} else { // Drools beta index on LT/LTE/GT/GTE requires Comparable.
JoinerType reversedJoinerType = joinerType.flip();
return betaIndexedBy(Comparable.class, getConstraintType(reversedJoinerType), mappingIndex,
c -> (Comparable) rightExtractor.apply(c), leftMapping::apply, Comparable.class);
}
}
@Override
public final Child_ bind(Variable boundVariable, Function bindingFunction) {
pattern.bind(boundVariable, a -> bindingFunction.apply(extract(a)));
return (Child_) this;
}
@Override
public final Child_ bind(Variable boundVariable,
Variable leftJoinVariable, BiFunction bindingFunction) {
pattern.bind(boundVariable, leftJoinVariable,
(a, leftJoinVar) -> bindingFunction.apply(extract(a), leftJoinVar));
return (Child_) this;
}
@Override
public final Child_ bind(Variable boundVariable,
Variable leftJoinVariableA, Variable leftJoinVariableB,
TriFunction bindingFunction) {
pattern.bind(boundVariable, leftJoinVariableA, leftJoinVariableB,
(a, leftJoinVarA, leftJoinVarB) -> bindingFunction.apply(extract(a), leftJoinVarA, leftJoinVarB));
return (Child_) this;
}
@Override
public final Child_ bind(Variable boundVariable,
Variable leftJoinVariableA, Variable leftJoinVariableB,
Variable leftJoinVariableC,
QuadFunction bindingFunction) {
pattern.bind(boundVariable, leftJoinVariableA, leftJoinVariableB, leftJoinVariableC,
(a, leftJoinVarA, leftJoinVarB, leftJoinVarC) -> bindingFunction.apply(extract(a), leftJoinVarA,
leftJoinVarB, leftJoinVarC));
return (Child_) this;
}
@Override
public final List> build() {
Stream> prerequisites = prerequisiteExpressions.stream();
Stream> dependents = dependentExpressions.stream();
return Stream.concat(Stream.concat(prerequisites, Stream.of(pattern)), dependents)
.collect(Collectors.toList());
}
}
// © 2015 - 2025 Weber Informatics LLC | Privacy Policy