All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.kie.kogito.trusty.service.common.CounterfactualParameterValidation Maven / Gradle / Ivy

There is a newer version: 2.44.0.Alpha
Show newest version
/*
 * Copyright 2021 Red Hat, Inc. and/or its affiliates.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *       http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.kie.kogito.trusty.service.common;

import java.util.Collection;
import java.util.Collections;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;

import org.kie.kogito.explainability.api.CounterfactualSearchDomain;
import org.kie.kogito.explainability.api.CounterfactualSearchDomainValue;
import org.kie.kogito.explainability.api.NamedTypedValue;
import org.kie.kogito.tracing.typedvalue.BaseTypedValue;
import org.kie.kogito.tracing.typedvalue.TypedValue;
import org.kie.kogito.trusty.storage.api.model.decision.DecisionInput;
import org.kie.kogito.trusty.storage.api.model.decision.DecisionOutcome;

public class CounterfactualParameterValidation {

    private CounterfactualParameterValidation() {
        //Prevent instantiation of this utility class
    }

    private interface Check {

        boolean check(Collection structure1,
                Collection structure2);
    }

    /**
     * Checks if two structured parameters are consistent.
     * What constitutes "consistent" is determined by the concrete sub-classes.
     *
     * @param  Type C of one parameter
     * @param  Converter to convert Type C to internal representation
     * @param  Type D of another parameter
     * @param  Converter to convert Type D to internal representation
     */
    private static abstract class BaseCheck implements Check {

        @Override
        public boolean check(Collection structure1, Collection structure2) {
            Collection> normalisedStructure1 = normaliseStructure1(structure1);
            Collection> normalisedStructure2 = normaliseStructure2(structure2);

            return doCheck(normalisedStructure1, normalisedStructure2);
        }

        protected boolean doCheck(Collection> normalisedStructure1,
                Collection> normalisedStructure2) {
            if (Objects.isNull(normalisedStructure1) && Objects.isNull(normalisedStructure2)) {
                return true;
            }
            if (Objects.isNull(normalisedStructure1)) {
                return false;
            }
            if (Objects.isNull(normalisedStructure2)) {
                return false;
            }
            if (normalisedStructure1.isEmpty() && normalisedStructure2.isEmpty()) {
                return true;
            }

            Map> structure1Map = normalisedStructure1.stream().collect(Collectors.toMap(ih -> ih.name, ih -> ih));
            Map> structure2Map = normalisedStructure2.stream().collect(Collectors.toMap(ih -> ih.name, ih -> ih));
            if (!checkMembership(structure1Map, structure2Map)) {
                return false;
            }

            //Check direct descendents
            Collection> structure1ChildStructures =
                    structure1Map.values()
                            .stream()
                            .map(ih -> getChildrenOfStructure1(ih.original))
                            .flatMap(Collection::stream)
                            .collect(Collectors.toList());
            Collection> structure2ChildStructures =
                    structure2Map.values()
                            .stream()
                            .map(ih -> getChildrenOfStructure2(ih.original))
                            .flatMap(Collection::stream)
                            .collect(Collectors.toList());

            return doCheck(structure1ChildStructures, structure2ChildStructures);
        }

        protected Collection> normaliseStructure1(Collection structure1) {
            if (Objects.isNull(structure1)) {
                return null;
            }
            return structure1.stream().map(this::convertStructure1toHolder).collect(Collectors.toList());
        }

        protected Collection> normaliseStructure2(Collection structure2) {
            if (Objects.isNull(structure2)) {
                return null;
            }
            return structure2.stream().map(this::convertStructure2toHolder).collect(Collectors.toList());
        }

        protected abstract StructureHolder convertStructure1toHolder(C value);

        protected abstract StructureHolder convertStructure2toHolder(D value);

        protected abstract Collection> getChildrenOfStructure1(CV value);

        protected abstract Collection> getChildrenOfStructure2(DV value);

        protected abstract boolean checkMembership(Map> structure1Map,
                Map> structure2Map);
    }

    private static class IdenticalCheck extends BaseCheck {

        @Override
        @SuppressWarnings("EqualsBetweenInconvertibleTypes")
        protected boolean checkMembership(Map> structure1Map,
                Map> structure2Map) {
            //Are the maps equal in size?
            boolean validSize = structure2Map.size() == structure1Map.size();
            //Do all members of Structure 1 exist in Structure 2?
            boolean validEntries = structure1Map.entrySet().stream().allMatch(e -> Objects.equals(e.getValue(), structure2Map.get(e.getKey())));
            //If they're equal size and the members are identical the structures must be equal.
            return validSize && validEntries;
        }

        @Override
        protected StructureHolder convertStructure1toHolder(DecisionInput value) {
            return new StructureHolder<>(value.getValue().getKind(),
                    value.getName(),
                    value.getValue().getType(),
                    value.getValue());
        }

        @Override
        protected StructureHolder convertStructure2toHolder(CounterfactualSearchDomain value) {
            return new StructureHolder<>(value.getValue().getKind(),
                    value.getName(),
                    value.getValue().getType(),
                    value.getValue());
        }

        @Override
        protected Collection> getChildrenOfStructure1(TypedValue value) {
            if (value.getKind() != BaseTypedValue.Kind.STRUCTURE) {
                return Collections.emptyList();
            }
            return value.toStructure().getValue()
                    .entrySet()
                    .stream()
                    .map(e -> new StructureHolder(e.getValue().getKind(),
                            e.getKey(),
                            e.getValue().getType(),
                            e.getValue()))
                    .collect(Collectors.toList());
        }

        @Override
        protected Collection> getChildrenOfStructure2(CounterfactualSearchDomainValue value) {
            if (value.getKind() != BaseTypedValue.Kind.STRUCTURE) {
                return Collections.emptyList();
            }
            return value.toStructure().getValue()
                    .entrySet()
                    .stream()
                    .map(e -> new StructureHolder(e.getValue().getKind(),
                            e.getKey(),
                            e.getValue().getType(),
                            e.getValue()))
                    .collect(Collectors.toList());
        }
    }

    private static class SubsetCheck extends BaseCheck {

        @Override
        protected boolean checkMembership(Map> structure1Map,
                Map> structure2Map) {
            //Is the second map at least the size of the first?
            boolean validSize = structure2Map.size() <= structure1Map.size();
            //Do all members of Structure 2 exist in Structure 1?
            boolean validEntries = structure2Map.entrySet().stream().allMatch(e -> Objects.equals(e.getValue(), structure1Map.get(e.getKey())));
            //If Structure 2's size is less than of equal to that of Structure 1 and all members of Structure 2 exist in Structure 1 then Structure 2 must be s subset of Structure 1.
            return validSize && validEntries;
        }

        @Override
        protected StructureHolder convertStructure1toHolder(DecisionOutcome value) {
            return new StructureHolder<>(value.getOutcomeResult().getKind(),
                    value.getOutcomeName(),
                    value.getOutcomeResult().getType(),
                    value.getOutcomeResult());
        }

        @Override
        protected StructureHolder convertStructure2toHolder(NamedTypedValue value) {
            return new StructureHolder<>(value.getValue().getKind(),
                    value.getName(),
                    value.getValue().getType(),
                    value.getValue());
        }

        @Override
        protected Collection> getChildrenOfStructure1(TypedValue value) {
            if (value.getKind() != BaseTypedValue.Kind.STRUCTURE) {
                return Collections.emptyList();
            }
            return value.toStructure().getValue()
                    .entrySet()
                    .stream()
                    .map(e -> new StructureHolder(e.getValue().getKind(),
                            e.getKey(),
                            e.getValue().getType(),
                            e.getValue()))
                    .collect(Collectors.toList());
        }

        @Override
        protected Collection> getChildrenOfStructure2(TypedValue value) {
            return getChildrenOfStructure1(value);
        }
    }

    private static final IdenticalCheck IDENTICAL = new IdenticalCheck();
    private static final SubsetCheck SUBSET = new SubsetCheck();

    /**
     * Checks whether the two structures are identical; irrespective of values.
     *
     * @param inputs Inputs for a Decision
     * @param searchDomains Search Domains for a Counterfactual Explanation
     * @return True if they are identical
     */
    public static boolean isStructureIdentical(Collection inputs, Collection searchDomains) {
        return IDENTICAL.check(inputs, searchDomains);
    }

    /**
     * Checks whether the structure of the Goals is a subset of the structure of the Outcomes; irrespective of values.
     *
     * @param outcomes Outcomes for a Decision
     * @param goals Goals for a Counterfactual Explanation
     * @return True if Goals is a subset of Outcomes
     */
    public static boolean isStructureSubset(Collection outcomes, Collection goals) {
        return SUBSET.check(outcomes, goals);
    }

    private static class StructureHolder {

        private final TypedValue.Kind kind;
        private final String name;
        private final String typeRef;
        private final T original;

        public StructureHolder(BaseTypedValue.Kind kind, String name, String typeRef, T original) {
            this.kind = kind;
            this.name = name;
            this.typeRef = typeRef;
            this.original = original;
        }

        @Override
        public boolean equals(Object o) {
            if (this == o) {
                return true;
            }
            if (o == null || getClass() != o.getClass()) {
                return false;
            }
            StructureHolder that = (StructureHolder) o;
            return kind == that.kind && Objects.equals(name, that.name) && Objects.equals(typeRef, that.typeRef);
        }

        @Override
        public int hashCode() {
            return Objects.hash(kind, name, typeRef);
        }
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy