All downloads are FREE. Search and download functionality uses the official Maven repository.

gov.sandia.cognition.learning.algorithm.minimization.matrix.IterativeMatrixSolver Maven / Gradle / Ivy

There is a newer version: 4.0.1
Show newest version
/*
 * File:                IterativeMatrixSolver.java
 * Authors:             Jeremy D. Wendt
 * Company:             Sandia National Laboratories
 * Project:             Cognitive Foundry
 * 
 * Copyright 2016, Sandia Corporation.
 * Under the terms of Contract DE-AC04-94AL85000, there is a non-exclusive
 * license for use of this work by or on behalf of the U.S. Government. 
 * Export of this program may require a license from the United States
 * Government. See CopyrightHistory.txt for complete details.
 */

package gov.sandia.cognition.learning.algorithm.minimization.matrix;

import gov.sandia.cognition.algorithm.IterativeAlgorithmListener;
import gov.sandia.cognition.annotation.PublicationReference;
import gov.sandia.cognition.annotation.PublicationType;
import gov.sandia.cognition.learning.algorithm.minimization.FunctionMinimizer;
import gov.sandia.cognition.learning.data.InputOutputPair;
import gov.sandia.cognition.math.matrix.Vector;
import gov.sandia.cognition.util.CloneableSerializable;
import java.util.HashSet;
import java.util.Set;

/**
 * Base class for all iterative matrix solvers that takes care of most of the
 * basic iterative logic and the function minimizer interface.
 *
 * Subclasses supply the algorithm-specific pieces through
 * {@link #initializeSolver}, {@link #iterate}, and {@link #completeSolver}.
 * This class drives the iteration loop, enforces the tolerance and
 * maximum-iteration limits, and notifies registered listeners.
 * 
 * @author Jeremy D. Wendt
 * @since 4.0.0
 * @param <Operator> The operator for the solver.
 */
@PublicationReference(author = "Jonathan Richard Shewchuk",
    title = "An Introduction to the Conjugate Gradient Method Without the Agonizing Pain",
    type = PublicationType.WebPage,
    year = 1994,
    url = "http://www.cs.cmu.edu/~quake-papers/painless-conjugate-gradient.pdf")
abstract public class IterativeMatrixSolver<Operator extends MatrixVectorMultiplier>
    implements FunctionMinimizer<Vector, Vector, Operator>
{

    /**
     * Tolerance used by the constructors that do not take one explicitly.
     */
    private static final double DEFAULT_TOLERANCE = 1e-10;

    /**
     * When no maximum iteration count is supplied, the limit defaults to this
     * many iterations per dimension of the initial guess.
     */
    private static final int DEFAULT_ITERATIONS_PER_DIMENSION = 10;

    /**
     * The tolerance of the error accepted before stopping iterations.
     */
    protected double tolerance;

    /**
     * The initial guess for the left-hand-side vector (x).
     */
    protected Vector x0;

    /**
     * The right-hand-side vector (b).
     */
    protected Vector rhs;

    /**
     * Execution will stop after this number of iterations even if it has not
     * converged.
     */
    protected int maxIterations;

    /**
     * Listeners to the algorithms progress have the opportunity to stop the
     * algorithm after a specified number of iterations.
     */
    protected Set<IterativeAlgorithmListener> listeners;

    /**
     * Counts the number of iterations executed thus far (-1 before the first
     * call to learn).
     */
    protected int iterationCounter;

    /**
     * If set to true, the algorithm will stop after the current iteration
     * completes.
     */
    protected boolean shouldStop;

    /**
     * True if the most recent call to learn ended because the residual fell
     * below the tolerance. Tracked explicitly because the iteration count
     * alone cannot distinguish "converged on the final permitted iteration"
     * from "ran out of iterations".
     */
    private boolean converged;

    /**
     * Stores the input rhs vector and the resulting x vector from the most
     * recent "learn" call.
     */
    private InputOutputPair<Vector, Vector> result;

    /**
     * Unsupported null constructor.
     *
     * @throws UnsupportedOperationException always
     */
    private IterativeMatrixSolver()
    {
        throw new UnsupportedOperationException("Do not call this method.");
    }

    /**
     * Initializes a solver with basic necessary values. The tolerance
     * defaults to 1e-10 and the maximum number of iterations defaults to ten
     * times the dimensionality of x0.
     *
     * @param x0 The initial guess for x (copied)
     * @param rhs The "b" to solve (copied)
     */
    protected IterativeMatrixSolver(Vector x0,
        Vector rhs)
    {
        this(x0, rhs, DEFAULT_TOLERANCE,
            x0.getDimensionality() * DEFAULT_ITERATIONS_PER_DIMENSION);
    }

    /**
     * Initializes a solver with a few more values. The maximum number of
     * iterations defaults to ten times the dimensionality of x0.
     *
     * @param x0 The initial guess for x (copied)
     * @param rhs The "b" to solve (copied)
     * @param tolerance The minimum acceptable error
     */
    protected IterativeMatrixSolver(Vector x0,
        Vector rhs,
        double tolerance)
    {
        this(x0, rhs, tolerance,
            x0.getDimensionality() * DEFAULT_ITERATIONS_PER_DIMENSION);
    }

    /**
     * Initializes a solver with all user-definable parameters.
     *
     * @param x0 The initial guess for x (copied)
     * @param rhs The "b" to solve (copied)
     * @param tolerance The minimum acceptable error (must be non-negative)
     * @param maxIterations The maximum number of iterations (must be positive)
     * @throws IllegalArgumentException if tolerance or maxIterations is
     * invalid
     */
    protected IterativeMatrixSolver(Vector x0,
        Vector rhs,
        double tolerance,
        int maxIterations)
    {
        this.x0 = x0.clone();
        this.rhs = rhs.clone();
        setTolerance(tolerance);
        setMaxIterations(maxIterations);
        listeners = new HashSet<IterativeAlgorithmListener>();
        iterationCounter = -1;
        shouldStop = false;
        converged = false;
        result = null;
    }

    /**
     * Protected copy constructor. The vectors and the listener set are copied
     * so that the new solver does not share mutable state with the original;
     * the listeners themselves and the most recent result are shared.
     *
     * @param copy The "self" to copy
     */
    protected IterativeMatrixSolver(IterativeMatrixSolver<Operator> copy)
    {
        this.x0 = (copy.x0 == null) ? null : copy.x0.clone();
        this.rhs = (copy.rhs == null) ? null : copy.rhs.clone();
        this.setTolerance(copy.tolerance);
        this.setMaxIterations(copy.maxIterations);
        // A fresh set, so adding/removing listeners on one solver does not
        // silently affect the other.
        this.listeners = (copy.listeners == null) ? null
            : new HashSet<IterativeAlgorithmListener>(copy.listeners);
        this.iterationCounter = copy.iterationCounter;
        this.shouldStop = copy.shouldStop;
        this.converged = copy.converged;
        this.result = copy.result;
    }

    /**
     * Shell that solves for Ax = b (x0 and rhs passed in on initialization, A
     * is contained in function).
     *
     * Iterates until the residual returned by {@link #iterate()} is below the
     * tolerance, {@link #stop()} is called, or maxIterations is reached.
     *
     * @param function Matrix wrapper
     * @return The input b and resulting x found.
     * @throws IllegalArgumentException if function rejects x0 and rhs via
     * canEvaluateAgainst
     */
    @Override
    final public InputOutputPair<Vector, Vector> learn(
        Operator function)
    {
        if (!function.canEvaluateAgainst(x0, rhs))
        {
            throw new IllegalArgumentException("Input matrix solves for a "
                + "different dimensionality than the input x0 and rhs");
        }
        iterationCounter = 0;
        shouldStop = false;
        converged = false;
        result = null;
        for (IterativeAlgorithmListener listener : listeners)
        {
            listener.algorithmStarted(this);
        }
        initializeSolver(function);
        while ((!shouldStop) && (iterationCounter < maxIterations))
        {
            ++iterationCounter;
            for (IterativeAlgorithmListener listener : listeners)
            {
                listener.stepStarted(this);
            }
            double residual = iterate();
            for (IterativeAlgorithmListener listener : listeners)
            {
                listener.stepEnded(this);
            }
            if (residual < tolerance)
            {
                converged = true;
                break;
            }
        }
        result = completeSolver();
        for (IterativeAlgorithmListener listener : listeners)
        {
            listener.algorithmEnded(this);
        }

        return result;
    }

    /**
     * Called before iterations begin in learn. Iterative solvers can solve for
     * initial state and should store function away.
     *
     * @param function The matrix wrapper to save for iterate.
     */
    abstract protected void initializeSolver(Operator function);

    /**
     * Called during each step of the iterative solver. Take one step forward in
     * the algorithm.
     *
     * @return the residual after this step.
     */
    abstract protected double iterate();

    /**
     * Called after the final iteration. The solver should clean up any
     * intermediate results and return the final results.
     *
     * @return the final results of the algorithm.
     */
    abstract protected InputOutputPair<Vector, Vector> completeSolver();

    /**
     * @see FunctionMinimizer#clone()
     */
    @Override
    abstract public CloneableSerializable clone();

    /**
     * @see FunctionMinimizer#getTolerance()
     */
    @Override
    final public double getTolerance()
    {
        return tolerance;
    }

    /**
     * Sets the minimum tolerance before iterations complete (must be
     * non-negative and not NaN). If set to zero, you'll likely go all
     * iterations (to maxIterations) in most cases due to numerical precision
     * issues.
     *
     * @param tolerance The minimum tolerance acceptable before returning the
     * result.
     * @throws IllegalArgumentException if tolerance is negative or NaN
     */
    @Override
    final public void setTolerance(double tolerance)
    {
        // Written as a negated ">=" so NaN (for which every comparison is
        // false) is rejected along with negative values.
        if (!(tolerance >= 0))
        {
            throw new IllegalArgumentException("Tolerance must be non-negative.");
        }
        this.tolerance = tolerance;
    }

    /**
     * Returns the initial guess at "x"
     *
     * @return a copy of the initial guess at "x"
     */
    @Override
    final public Vector getInitialGuess()
    {
        return x0.clone();
    }

    /**
     * Sets the initial guess ("x0")
     *
     * @param initialGuess the initial guess ("x0"); a copy is stored
     */
    @Override
    final public void setInitialGuess(Vector initialGuess)
    {
        x0 = initialGuess.clone();
    }

    /**
     * @see FunctionMinimizer#getMaxIterations()
     */
    @Override
    final public int getMaxIterations()
    {
        return maxIterations;
    }

    /**
     * Sets the maximum number of iterations before this will stop iterating. It
     * will stop sooner if the residual is below the minimum residual. The
     * number of iterations must be positive (>0).
     *
     * @param maxIterations The maximum number of iterations
     * @throws IllegalArgumentException if maxIterations is not positive
     */
    @Override
    final public void setMaxIterations(int maxIterations)
    {
        if (maxIterations <= 0)
        {
            throw new IllegalArgumentException("Max iterations must be positive");
        }
        this.maxIterations = maxIterations;
    }

    /**
     * @see FunctionMinimizer#getResult()
     */
    @Override
    public InputOutputPair<Vector, Vector> getResult()
    {
        return result;
    }

    /**
     * @see FunctionMinimizer#getIteration()
     */
    @Override
    public int getIteration()
    {
        return iterationCounter;
    }

    /**
     * @see
     * FunctionMinimizer#addIterativeAlgorithmListener(gov.sandia.cognition.algorithm.IterativeAlgorithmListener)
     */
    @Override
    final public void addIterativeAlgorithmListener(
        IterativeAlgorithmListener listener)
    {
        listeners.add(listener);
    }

    /**
     * @see
     * FunctionMinimizer#removeIterativeAlgorithmListener(gov.sandia.cognition.algorithm.IterativeAlgorithmListener)
     */
    @Override
    final public void removeIterativeAlgorithmListener(
        IterativeAlgorithmListener listener)
    {
        listeners.remove(listener);
    }

    /**
     * Execution will stop after the current iteration completes.
     */
    @Override
    public void stop()
    {
        shouldStop = true;
    }

    /**
     * Returns true if the most recent call to learn stopped because the
     * residual fell below the acceptable tolerance (vs. due to stop being
     * called or exceeding maxIterations). This includes convergence on the
     * final permitted iteration. Returns false if learn has not been called.
     *
     * @return true if execution stopped because the residual was below
     * acceptable tolerance.
     */
    @Override
    final public boolean isResultValid()
    {
        return converged;
    }

    /**
     * Compares all configuration and run state of two solvers.
     *
     * @param o The object to compare against
     * @return true if o is an IterativeMatrixSolver with equal state
     */
    @Override
    public boolean equals(Object o)
    {
        if (this == o)
        {
            return true;
        }
        if (!(o instanceof IterativeMatrixSolver))
        {
            return false;
        }
        IterativeMatrixSolver<?> other = (IterativeMatrixSolver<?>) o;
        // Double.compare (rather than !=) keeps equals consistent with
        // hashCode, which hashes the raw bit pattern of tolerance.
        return (Double.compare(tolerance, other.tolerance) == 0)
            && nullSafeEquals(x0, other.x0)
            && nullSafeEquals(rhs, other.rhs)
            && (maxIterations == other.maxIterations)
            && nullSafeEquals(listeners, other.listeners)
            && (iterationCounter == other.iterationCounter)
            && (shouldStop == other.shouldStop)
            && (converged == other.converged)
            && nullSafeEquals(result, other.result);
    }

    /**
     * @see Object#hashCode()
     */
    @Override
    public int hashCode()
    {
        int hash = 1;
        hash = hash * 17
            + Long.valueOf(Double.doubleToLongBits(tolerance)).hashCode();
        hash = hash * 17 + ((x0 == null) ? 0 : x0.hashCode());
        hash = hash * 17 + ((rhs == null) ? 0 : rhs.hashCode());
        hash = hash * 17 + Long.valueOf(maxIterations).hashCode();
        hash = hash * 17 + ((listeners == null) ? 0 : listeners.hashCode());
        hash = hash * 17 + Long.valueOf(iterationCounter).hashCode();
        hash = hash * 17 + Boolean.valueOf(shouldStop).hashCode();
        hash = hash * 17 + Boolean.valueOf(converged).hashCode();
        hash = hash * 17 + ((result == null) ? 0 : result.hashCode());

        return hash;
    }

    /**
     * Compares two possibly-null references for equality.
     *
     * @param a The first reference (may be null)
     * @param b The second reference (may be null)
     * @return true if both are null or a.equals(b)
     */
    private static boolean nullSafeEquals(Object a, Object b)
    {
        return (a == null) ? (b == null) : a.equals(b);
    }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy