All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.nd4j.autodiff.samediff.internal.AbstractDependencyTracker Maven / Gradle / Ivy

There is a newer version: 1.0.0-M2.1
Show newest version
/*
 *  ******************************************************************************
 *  *
 *  *
 *  * This program and the accompanying materials are made available under the
 *  * terms of the Apache License, Version 2.0 which is available at
 *  * https://www.apache.org/licenses/LICENSE-2.0.
 *  *
 *  *  See the NOTICE file distributed with this work for additional
 *  *  information regarding copyright ownership.
 *  * Unless required by applicable law or agreed to in writing, software
 *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 *  * License for the specific language governing permissions and limitations
 *  * under the License.
 *  *
 *  * SPDX-License-Identifier: Apache-2.0
 *  *****************************************************************************
 */

package org.nd4j.autodiff.samediff.internal;

import lombok.Getter;
import lombok.NonNull;
import lombok.extern.slf4j.Slf4j;
import org.nd4j.common.base.Preconditions;
import org.nd4j.common.function.Predicate;
import org.nd4j.common.primitives.Pair;

import java.util.*;

@Slf4j
public abstract class AbstractDependencyTracker {
    @Getter
    private final Map> dependencies;                              //Key: the dependent. Value: all things that the key depends on
    @Getter
    private final Map>> orDependencies;                    //Key: the dependent. Value: the set of OR dependencies
    @Getter
    private final Map> reverseDependencies = new HashMap<>();     //Key: the dependee. Value: The set of all dependents that depend on this value
    @Getter
    private final Map> reverseOrDependencies = new HashMap<>();
    @Getter
    private final Set satisfiedDependencies = new HashSet<>();           //Mark the dependency as satisfied. If not in set: assumed to not be satisfied
    @Getter
    private final Set allSatisfied;                                      //Set of all dependent values (Ys) that have all dependencies satisfied
    @Getter
    private final Queue allSatisfiedQueue = new LinkedList<>();          //Queue for *new* "all satisfied" values. Values are removed using the "new all satisfied" methods


    protected AbstractDependencyTracker() {
        dependencies = (Map>) newTMap();
        orDependencies = (Map>>) newTMap();
        allSatisfied = newTSet();
    }

    /**
     * @return A new map where the dependents (i.e., Y in "X -> Y") are the key
     */
    protected abstract Map newTMap();

    /**
     * @return A new set where the dependents (i.e., Y in "X -> Y") are the key
     */
    protected abstract Set newTSet();

    /**
     * @return A String representation of the dependent object
     */
    protected abstract String toStringT(T t);

    /**
     * @return A String representation of the dependee object
     */
    protected abstract String toStringD(D d);

    /**
     * Clear all internal state for the dependency tracker
     */
    public void clear() {
        dependencies.clear();
        orDependencies.clear();
        reverseDependencies.clear();
        reverseOrDependencies.clear();
        satisfiedDependencies.clear();
        allSatisfied.clear();
        allSatisfiedQueue.clear();
    }

    /**
     * @return True if no dependencies have been defined
     */
    public boolean isEmpty() {
        return dependencies.isEmpty() && orDependencies.isEmpty() &&
                allSatisfiedQueue.isEmpty();
    }

    /**
     * @return True if the dependency has been marked as satisfied using {@link #markSatisfied(Object, boolean)}
     */
    public boolean isSatisfied(@NonNull D x) {
        return satisfiedDependencies.contains(x);
    }

    /**
     * Mark the specified value as satisfied.
     * For example, if two dependencies have been previously added (X -> Y) and (X -> A) then after the markSatisfied(X, true)
     * call, both of these dependencies are considered satisfied.
     *
     * @param x         Value to mark
     * @param satisfied Whether to mark as satisfied (true) or unsatisfied (false)
     */
    public void markSatisfied(@NonNull D x, boolean satisfied) {
        if (satisfied) {
            boolean alreadySatisfied = satisfiedDependencies.contains(x);

            if (!alreadySatisfied) {
                satisfiedDependencies.add(x);

                //Check if any Y's exist that have dependencies that are all satisfied, for X -> Y
                Set s = reverseDependencies.get(x);
                Set s2 = reverseOrDependencies.get(x);

                Set set;
                if (s != null && s2 != null) {
                    set = newTSet();
                    set.addAll(s);
                    set.addAll(s2);
                } else if (s != null) {
                    set = s;
                } else if (s2 != null) {
                    set = s2;
                } else {
                    if (log.isTraceEnabled()) {
                        log.trace("No values depend on: {}", toStringD(x));
                    }
                    return;
                }

                for (T t : set) {
                    Set required = dependencies.get(t);
                    Set> requiredOr = orDependencies.get(t);
                    boolean allSatisfied = true;
                    if (required != null) {
                        for (D d : required) {
                            if (!isSatisfied(d)) {
                                allSatisfied = false;
                                break;
                            }
                        }
                    }
                    if (allSatisfied && requiredOr != null) {
                        for (Pair p : requiredOr) {
                            if (!isSatisfied(p.getFirst()) && !isSatisfied(p.getSecond())) {
                                allSatisfied = false;
                                break;
                            }
                        }
                    }

                    if (allSatisfied && !this.allSatisfied.contains(t)) {
                        this.allSatisfied.add(t);
                        this.allSatisfiedQueue.add(t);
                    }
                }
            }

        } else {
            satisfiedDependencies.remove(x);
            if (!allSatisfied.isEmpty()) {

                Set reverse = reverseDependencies.get(x);
                if (reverse != null) {
                    for (T y : reverse) {
                        if (allSatisfied.contains(y)) {
                            allSatisfied.remove(y);
                            allSatisfiedQueue.remove(y);
                        }
                    }
                }
                Set orReverse = reverseOrDependencies.get(x);
                if (orReverse != null) {
                    for (T y : orReverse) {
                        if (allSatisfied.contains(y) && !isAllSatisfied(y)) {
                            allSatisfied.remove(y);
                            allSatisfiedQueue.remove(y);
                        }
                    }
                }
            }
        }
    }

    /**
     * Check whether any dependencies x -> y exist, for y (i.e., anything previously added by {@link #addDependency(Object, Object)}
     * or {@link #addOrDependency(Object, Object, Object)}
     *
     * @param y Dependent to check
     * @return True if Y depends on any values
     */
    public boolean hasDependency(@NonNull T y) {
        Set s1 = dependencies.get(y);
        if (s1 != null && !s1.isEmpty())
            return true;

        Set> s2 = orDependencies.get(y);
        return s2 != null && !s2.isEmpty();
    }

    /**
     * Get all dependencies x, for x -> y, and (x1 or x2) -> y
     *
     * @param y Dependent to get dependencies for
     * @return List of dependencies
     */
    public DependencyList getDependencies(@NonNull T y) {
        Set s1 = dependencies.get(y);
        Set> s2 = orDependencies.get(y);

        List l1 = (s1 == null ? null : new ArrayList<>(s1));
        List> l2 = (s2 == null ? null : new ArrayList<>(s2));

        return new DependencyList<>(y, l1, l2);
    }

    /**
     * Add a dependency: y depends on x, as in x -> y
     *
     * @param y The dependent
     * @param x The dependee that is required for Y
     */
    public void addDependency(@NonNull T y, @NonNull D x) {
        if (!dependencies.containsKey(y))
            dependencies.put(y, new HashSet());

        if (!reverseDependencies.containsKey(x))
            reverseDependencies.put(x, newTSet());

        dependencies.get(y).add(x);
        reverseDependencies.get(x).add(y);

        checkAndUpdateIfAllSatisfied(y);
    }

    protected void checkAndUpdateIfAllSatisfied(@NonNull T y) {
        boolean allSat = isAllSatisfied(y);
        if (allSat) {
            //Case where "x is satisfied" happened before x->y added
            if (!allSatisfied.contains(y)) {
                allSatisfied.add(y);
                allSatisfiedQueue.add(y);
            }
        } else if (allSatisfied.contains(y)) {
            if (!allSatisfiedQueue.contains(y)) {
                StringBuilder sb = new StringBuilder();
                sb.append("Dependent object \"").append(toStringT(y)).append("\" was previously processed after all dependencies")
                        .append(" were marked satisfied, but is now additional dependencies have been added.\n");
                DependencyList dl = getDependencies(y);
                if (dl.getDependencies() != null) {
                    sb.append("Dependencies:\n");
                    for (D d : dl.getDependencies()) {
                        sb.append(d).append(" - ").append(isSatisfied(d) ? "Satisfied" : "Not satisfied").append("\n");
                    }
                }
                if (dl.getOrDependencies() != null) {
                    sb.append("Or dependencies:\n");
                    for (Pair p : dl.getOrDependencies()) {
                        sb.append(p).append(" - satisfied=(").append(isSatisfied(p.getFirst())).append(",").append(isSatisfied(p.getSecond())).append(")");
                    }
                }
                throw new IllegalStateException(sb.toString());
            }

            //Not satisfied, but is in the queue -> needs to be removed
            allSatisfied.remove(y);
            allSatisfiedQueue.remove(y);
        }
    }

    protected boolean isAllSatisfied(@NonNull T y) {
        Set set1 = dependencies.get(y);

        boolean retVal = true;
        if (set1 != null) {
            for (D d : set1) {
                retVal = isSatisfied(d);
                if (!retVal)
                    break;
            }
        }
        if (retVal) {
            Set> set2 = orDependencies.get(y);
            if (set2 != null) {
                for (Pair p : set2) {
                    retVal = isSatisfied(p.getFirst()) || isSatisfied(p.getSecond());
                    if (!retVal)
                        break;
                }
            }
        }
        return retVal;
    }


    /**
     * Remove a dependency (x -> y)
     *
     * @param y The dependent that currently requires X
     * @param x The dependee that is no longer required for Y
     */
    public void removeDependency(@NonNull T y, @NonNull D x) {
        if (!dependencies.containsKey(y) && !orDependencies.containsKey(y))
            return;

        Set s = dependencies.get(y);
        if (s != null) {
            s.remove(x);
            if (s.isEmpty())
                dependencies.remove(y);
        }

        Set s2 = reverseDependencies.get(x);
        if (s2 != null) {
            s2.remove(y);
            if (s2.isEmpty())
                reverseDependencies.remove(x);
        }


        Set> s3 = orDependencies.get(y);
        if (s3 != null) {
            boolean removedReverse = false;
            Iterator> iter = s3.iterator();
            while (iter.hasNext()) {
                Pair p = iter.next();
                if (x.equals(p.getFirst()) || x.equals(p.getSecond())) {
                    iter.remove();

                    if (!removedReverse) {
                        Set set1 = reverseOrDependencies.get(p.getFirst());
                        Set set2 = reverseOrDependencies.get(p.getSecond());

                        set1.remove(y);
                        set2.remove(y);

                        if (set1.isEmpty())
                            reverseOrDependencies.remove(p.getFirst());
                        if (set2.isEmpty())
                            reverseOrDependencies.remove(p.getSecond());

                        removedReverse = true;
                    }
                }
            }
        }
        if (s3 != null && s3.isEmpty())
            orDependencies.remove(y);
    }

    /**
     * Add an "Or" dependency: Y requires either x1 OR x2 - i.e., (x1 or x2) -> Y
* If either x1 or x2 (or both) are marked satisfied via {@link #markSatisfied(Object, boolean)} then the * dependency is considered satisfied * * @param y Dependent * @param x1 Dependee 1 * @param x2 Dependee 2 */ public void addOrDependency(@NonNull T y, @NonNull D x1, @NonNull D x2) { if (!orDependencies.containsKey(y)) orDependencies.put(y, new HashSet>()); if (!reverseOrDependencies.containsKey(x1)) reverseOrDependencies.put(x1, newTSet()); if (!reverseOrDependencies.containsKey(x2)) reverseOrDependencies.put(x2, newTSet()); orDependencies.get(y).add(new Pair<>(x1, x2)); reverseOrDependencies.get(x1).add(y); reverseOrDependencies.get(x2).add(y); checkAndUpdateIfAllSatisfied(y); } /** * @return True if there are any new/unprocessed "all satisfied dependents" (Ys in X->Y) */ public boolean hasNewAllSatisfied() { return !allSatisfiedQueue.isEmpty(); } /** * Returns the next new dependent (Y in X->Y) that has all dependees (Xs) marked as satisfied via {@link #markSatisfied(Object, boolean)} * Throws an exception if {@link #hasNewAllSatisfied()} returns false.
* Note that once a value has been retrieved from here, no new dependencies of the form (X -> Y) can be added for this value; * the value is considered "processed" at this point. * * @return The next new "all satisfied dependent" */ public T getNewAllSatisfied() { Preconditions.checkState(hasNewAllSatisfied(), "No new/unprocessed dependents that are all satisfied"); return allSatisfiedQueue.remove(); } /** * @return As per {@link #getNewAllSatisfied()} but returns all values */ public List getNewAllSatisfiedList() { Preconditions.checkState(hasNewAllSatisfied(), "No new/unprocessed dependents that are all satisfied"); List ret = new ArrayList<>(allSatisfiedQueue); allSatisfiedQueue.clear(); return ret; } /** * As per {@link #getNewAllSatisfied()} but instead of returning the first dependee, it returns the first that matches * the provided predicate. If no value matches the predicate, null is returned * * @param predicate Predicate gor checking * @return The first value matching the predicate, or null if no values match the predicate */ public T getFirstNewAllSatisfiedMatching(@NonNull Predicate predicate) { Preconditions.checkState(hasNewAllSatisfied(), "No new/unprocessed dependents that are all satisfied"); T t = allSatisfiedQueue.peek(); if (predicate.test(t)) { t = allSatisfiedQueue.remove(); allSatisfied.remove(t); return t; } if (allSatisfiedQueue.size() > 1) { Iterator iter = allSatisfiedQueue.iterator(); while (iter.hasNext()) { t = iter.next(); if (predicate.test(t)) { iter.remove(); allSatisfied.remove(t); return t; } } } return null; //None match predicate } }




© 2015 - 2024 Weber Informatics LLC | Privacy Policy