All downloads are free. Search and download functionality uses the official Maven repository.

org.nd4j.autodiff.samediff.flow.FlowPath Maven / Gradle / Ivy

There is a newer version: 1.0.0-M2.1
Show newest version
package org.nd4j.autodiff.samediff.flow;

import lombok.NonNull;

import java.util.HashMap;
import java.util.Map;

/**
 * This class acts as holder for flow control information.
 *
 * @author [email protected]
 */
public class FlowPath {
    /** Per-node flow state, keyed by graph node name; entries are created lazily on first access. */
    protected Map<String, NodeState> states = new HashMap<>();
    /** Per-frame flow state, keyed by frame name; entries are created only by {@link #registerFrame(String)}. */
    protected Map<String, FrameState> frames = new HashMap<>();

    /**
     * This method makes sure a NodeState exists for the specified graph node, creating one if necessary.
     *
     * @param nodeName name of the graph node
     */
    public void ensureNodeStateExists(@NonNull String nodeName) {
        states.computeIfAbsent(nodeName, NodeState::new);
    }

    /**
     * Returns the NodeState of the given node, creating it first if it does not exist yet.
     * All node-level accessors go through here so none of them can fail on a not-yet-seen node.
     */
    private NodeState nodeState(String nodeName) {
        ensureNodeStateExists(nodeName);
        return states.get(nodeName);
    }

    /**
     * Returns the FrameState of a registered frame.
     *
     * @throws IllegalStateException if the frame was never registered, or was removed via {@link #forgetFrame(String)}
     */
    private FrameState frameState(String frameName) {
        FrameState frame = frames.get(frameName);
        if (frame == null) {
            throw new IllegalStateException("Frame [" + frameName + "] is not registered");
        }
        return frame;
    }

    /**
     * This method checks, if specified graph node is active (as in - located within active code branch, and was NOT left in inactive branch)
     *
     * @param nodeName name of the graph node
     * @return TRUE if the node is currently active
     */
    public boolean isActive(@NonNull String nodeName) {
        return nodeState(nodeName).isActive();
    }

    /**
     * This method allows to set specified node active or inactive.
     * PLEASE NOTE: All nodes using this node as input, will be considered inactive, if this node is set to be inactive.
     *
     * @param nodeName name of the graph node
     * @param active   new active flag
     */
    public void markActive(@NonNull String nodeName, boolean active) {
        nodeState(nodeName).setActive(active);
    }

    /**
     * This method sets active/inactive branch for divergent nodes (aka Switch)
     *
     * @param nodeName  name of the divergent node
     * @param branchIdx index of the branch to mark as active
     */
    public void setActiveBranch(@NonNull String nodeName, int branchIdx) {
        nodeState(nodeName).setActiveBranch(branchIdx);
    }

    /**
     * This method returns active branch of specific node (if any).
     * For a node without prior state, a fresh NodeState is created and its initial branch value is returned.
     *
     * @param nodeName name of the divergent node
     * @return index of the active branch
     */
    public int getActiveBranch(@NonNull String nodeName) {
        return nodeState(nodeName).getActiveBranch();
    }

    /**
     * This method returns TRUE if specified node was already executed during current pass, FALSE otherwise
     *
     * @param nodeName name of the graph node
     * @return TRUE if the node is marked as executed
     */
    public boolean wasExecuted(@NonNull String nodeName) {
        return nodeState(nodeName).isExecuted();
    }

    /**
     * This method allows to toggle wasExecuted() state for specified node
     *
     * @param nodeName name of the graph node
     * @param executed new executed flag
     */
    public void markExecuted(@NonNull String nodeName, boolean executed) {
        nodeState(nodeName).setExecuted(executed);
    }

    /**
     * This method increments the number of iterations of the specified frame by 1.
     *
     * @param frameName name of a registered frame
     * @throws IllegalStateException if the frame is not registered
     */
    public void incrementNumberOfCycles(@NonNull String frameName) {
        frameState(frameName).incrementNumberOfCycles();
    }

    /**
     * This method returns number of iterations of specified frame.
     *
     * @param frameName name of a registered frame
     * @return number of cycles counted by {@link #incrementNumberOfCycles(String)}
     * @throws IllegalStateException if the frame is not registered
     */
    public long getNumberOfCycles(@NonNull String frameName) {
        // Must read the frame map: the counter is incremented on FrameState, not NodeState.
        return frameState(frameName).getNumberOfCycles();
    }

    /**
     * This method adds Frame to tracking
     * PLEASE NOTE: Only works for first call, subsequent calls are no-op
     *
     * @param frame_name name of the frame
     */
    public void registerFrame(@NonNull String frame_name) {
        frames.computeIfAbsent(frame_name, FrameState::new);
    }

    /**
     * This method removes specified frame from tracking
     *
     * @param frame_name name of the frame
     */
    // FIXME: this approach is probably bad (for backprop) and should be reconsidered
    public void forgetFrame(@NonNull String frame_name) {
        frames.remove(frame_name);
    }

    /**
     * This method returns TRUE if frame_name already registered, false otherwise
     *
     * @param frame_name name of the frame
     * @return TRUE if the frame is currently tracked
     */
    public boolean isRegisteredFrame(@NonNull String frame_name) {
        return frames.containsKey(frame_name);
    }

    /**
     * This method checks, if rewind was planned for specified frame
     *
     * @param frameName name of a registered frame
     * @return TRUE if a rewind is planned
     * @throws IllegalStateException if the frame is not registered
     */
    public boolean isRewindPlanned(@NonNull String frameName) {
        return frameState(frameName).isRewindPlanned();
    }

    /**
     * This method checks if a rewind is both planned and has a position set (position >= 0) for the specified frame.
     *
     * @param frameName name of a registered frame
     * @return TRUE if a rewind can be performed
     * @throws IllegalStateException if the frame is not registered
     */
    public boolean isRewindPossible(@NonNull String frameName) {
        return isRewindPlanned(frameName) && getRewindPosition(frameName) >= 0;
    }

    /**
     * This method announces (or cancels) a future rewind of graph execution for the specified frame
     *
     * @param frameName  name of a registered frame
     * @param reallyPlan TRUE to plan the rewind, FALSE to clear the plan
     * @throws IllegalStateException if the frame is not registered
     */
    public void planRewind(@NonNull String frameName, boolean reallyPlan) {
        frameState(frameName).setRewindPlanned(reallyPlan);
    }

    /**
     * This method returns planned position within graph for next rewind.
     *
     * @param frameName name of a registered frame
     * @return rewind position; a negative value is treated as "not set" by this class
     * @throws IllegalStateException if the frame is not registered
     */
    public int getRewindPosition(@NonNull String frameName) {
        return frameState(frameName).getRewindPosition();
    }

    /**
     * This method allows to set position for next rewind within graph
     *
     * @param frameName name of a registered frame
     * @param position  rewind position
     * @throws IllegalStateException if the frame is not registered
     */
    public void setRewindPosition(@NonNull String frameName, int position) {
        frameState(frameName).setRewindPosition(position);
    }

    /**
     * This method allows to set position for next rewind within graph.
     * PLEASE NOTE: This method checks if the rewind position was already set (>= 0). If it was, the call is a no-op.
     *
     * @param frameName name of a registered frame
     * @param position  rewind position
     * @throws IllegalStateException if the frame is not registered
     */
    public void setRewindPositionOnce(@NonNull String frameName, int position) {
        FrameState frame = frameState(frameName);
        if (frame.getRewindPosition() >= 0) {
            return;
        }

        frame.setRewindPosition(position);
    }

    /**
     * This method marks the specified frame as active or inactive
     *
     * @param frameName      name of a registered frame
     * @param reallyActivate new active flag
     * @throws IllegalStateException if the frame is not registered
     */
    public void activateFrame(@NonNull String frameName, boolean reallyActivate) {
        frameState(frameName).setActive(reallyActivate);
    }

    /**
     * This method returns TRUE if specified frame was activated (as in: Enter/Merge was triggered)
     *
     * @param frameName name of a registered frame
     * @return TRUE if the frame is active
     * @throws IllegalStateException if the frame is not registered
     */
    public boolean isFrameActive(@NonNull String frameName) {
        return frameState(frameName).isActive();
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy