All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.nd4j.linalg.solvers.VectorizedBackTrackLineSearchMinimum Maven / Gradle / Ivy

There is a newer version: 1.0.0-M2.1
Show newest version
package org.nd4j.linalg.solvers;

/* Copyright (C) 2002 Univ. of Massachusetts Amherst, Computer Science Dept.
This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit).
http://www.cs.umass.edu/~mccallum/mallet
This software is provided under the terms of the Common Public License,
version 1.0, as published by http://www.opensource.org.  For further
information, see the file `LICENSE' included with this distribution. */


/**
 @author Aron Culotta [email protected]

 Adapted from mallet with original authors above.
 Modified to be a vectorized version that uses jblas matrices
 for computation rather than the mallet ops.

 */

/**
 Numerical Recipes in C: p.385. lnsrch. A simple backtracking line
 search. No attempt at accurately finding the true minimum is
 made. The goal is only to ensure that BackTrackLineSearch will
 return a position of higher value.

 @author Adam Gibson


 */

import org.apache.commons.math3.util.FastMath;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.ops.transforms.Transforms;
import org.nd4j.linalg.solvers.api.OptimizableByGradientValueMatrix;
import org.nd4j.linalg.util.LinAlgExceptions;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;


//"Line Searches and Backtracking", p385, "Numeric Recipes in C"

public class VectorizedBackTrackLineSearchMinimum
{
    private static Logger logger = LoggerFactory.getLogger(VectorizedBackTrackLineSearchMinimum.class.getName());

    OptimizableByGradientValueMatrix function;

    public VectorizedBackTrackLineSearchMinimum(OptimizableByGradientValueMatrix optimizable) {
        this.function = optimizable;
    }

    final int maxIterations = 100;
    double stpmax = 100;
    final double EPS = 3.0e-12;

    // termination conditions: either
    //   a) abs(delta x/x) < REL_TOLX for all coordinates
    //   b) abs(delta x) < ABS_TOLX for all coordinates
    //   c) sufficient function increase (uses ALF)
    private double relTolx = 1e-10;
    private double absTolx = 1e-4; // tolerance on absolute value difference
    final double ALF = 1e-4;

    public void setStpmax(double stpmax) {
        this.stpmax = stpmax;
    }


    public double getStpmax() {
        return stpmax;
    }

    /**
     * Sets the tolerance of relative diff in function value.
     *  Line search converges if abs(delta x / x) < tolx
     *  for all coordinates. */
    public void setRelTolx (double tolx) { relTolx = tolx; }

    /**
     * Sets the tolerance of absolute diff in function value.
     *  Line search converges if abs(delta x) < tolx
     *  for all coordinates. */
    public void setAbsTolx (double tolx) { absTolx = tolx; }

    // initialStep is ignored.  This is b/c if the initial step is not 1.0,
    //   it sometimes confuses the backtracking for reasons I don't
    //   understand.  (That is, the jump gets LARGER on iteration 1.)

    // returns fraction of step size (alam) if found a good step
    // returns 0.0 if could not step in direction
    public double optimize (INDArray line, int lineSearchIteration,INDArray g,INDArray params)
    {
        INDArray x, oldParameters;
        double slope, test, alamin, alam, alam2, tmplam;
        double rhs1, rhs2, a, b, disc, oldAlam;
        double f, fold, f2;
        oldParameters = params.dup();


        alam2 = tmplam = 0.0;
        f2 = fold = function.getValue();
        if (logger.isDebugEnabled()) {
            logger.trace ("ENTERING BACKTRACK\n");
            logger.trace("Entering BackTrackLnSrch, value=" + fold +",\ndirection.oneNorm:"
                    +	line.norm1(Integer.MAX_VALUE) + "  direction.infNorm:"+ FastMath.max(Double.NEGATIVE_INFINITY, (double) Transforms.abs(line).max(Integer.MAX_VALUE).element()));
        }

        LinAlgExceptions.assertValidNum(g);
        double sum = (double) line.norm2(Integer.MAX_VALUE).element();
//        if(sum > stpmax) {
//            logger.warn("attempted step too big. scaling: sum= " + sum +
//                    ", stpmax= "+ stpmax);
//            line.muli(stpmax / sum);
//        }

        //dot product
        slope = Nd4j.getBlasWrapper().dot(g, line);
        logger.debug("slope = " + slope);

        // find maximum lambda
        // converge when (delta x) / x < REL_TOLX for all coordinates.
        //  the largest step size that triggers this threshold is
        //  precomputed and saved in alamin
        INDArray maxOldParams = Nd4j.create(line.length());
        for(int i = 0;i < line.length(); i++) {
            maxOldParams.putScalar(i,Math.max(Math.abs((double) oldParameters.getScalar(i).element()), 1.0));

        }

        INDArray testMatrix = Transforms.abs(line).div(maxOldParams);

        test = (double) testMatrix.max(Integer.MAX_VALUE).element();


        alamin = relTolx / test;

        alam  = 1.0;
        oldAlam = 0.0;
        int iteration = 0;
        // look for step size in direction given by "line"
        for(iteration = 0; iteration < maxIterations; iteration++) {
            function.setCurrentIteration(lineSearchIteration);
            // x = oldParameters + alam*line
            // initially, alam = 1.0, i.e. take full Newton step
            logger.trace("BackTrack loop iteration " + iteration +" : alam="+
                    alam+" oldAlam=" + oldAlam);
            logger.trace ("before step, x.1norm: " + params.norm1(Integer.MAX_VALUE) +
                    "\nalam: " + alam + "\noldAlam: " + oldAlam);
            assert(alam != oldAlam) : "alam == oldAlam";



            params.addi(line.mul(alam - oldAlam));  // step

            logger.debug ("after step, x.1norm: " + params.norm1(Integer.MAX_VALUE));

            // check for convergence
            //convergence on delta x
            if ((alam < alamin) || smallAbsDiff (oldParameters, params)) {
                //				if ((alam < alamin)) {
                function.setParameters(oldParameters);
                f = function.getValue();
                logger.trace("EXITING BACKTRACK: Jump too small (alamin="+ alamin + "). Exiting and using xold. Value=" + f);
                return 0.0;
            }

            function.setParameters(params);
            oldAlam = alam;
            f = function.getValue();

            logger.debug("value = " + f);

            // sufficient function increase (Wolf condition)
            if(f >= fold + ALF * alam * slope) {

                logger.debug("EXITING BACKTRACK: value=" + f);

                if (f= .1*Lambda_1

        }

        if(iteration >= maxIterations)
            throw new IllegalStateException ("Too many iterations.");
        return 0.0;
    }

    // returns true iff we've converged based on absolute x difference
    private boolean smallAbsDiff (INDArray x, INDArray xold)
    {

        for (int i = 0; i < x.length(); i++) {
            double comp = Math.abs ((double) x.getScalar(i).sub(xold.getScalar(i)).element());
            if ( comp > absTolx) {
                return false;
            }
        }
        return true;
    }

}





© 2015 - 2024 Weber Informatics LLC | Privacy Policy