All Downloads are FREE. The search and download functionality uses the official Maven repository.

hivemall.common.ConversionState Maven / Gradle / Ivy

The newest version!
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
package hivemall.common;

import javax.annotation.Nonnegative;
import javax.annotation.Nonnull;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;

/**
 * Tracks the cumulative loss and error of an iterative training procedure and decides when the
 * loss has stopped improving.
 *
 * <p>
 * Expected usage, as far as this class shows: call {@link #incrLoss(double)} /
 * {@link #incrError(double)} while processing an iteration, call {@link #isConverged(long)} at
 * the end of the iteration, and call {@link #next()} before starting the following one.
 *
 * <p>
 * Convergence is only reported after the relative change of the cumulative loss has stayed below
 * the configured threshold for two consecutive {@link #isConverged(long)} calls.
 *
 * <p>
 * This class performs no synchronization of its own.
 */
public final class ConversionState {
    private static final Log logger = LogFactory.getLog(ConversionState.class);

    /** Whether to check convergence */
    private final boolean conversionCheck;
    /** Threshold on the relative loss change below which an iteration counts as "not improving" */
    private final double convergenceRate;

    /** Set when the previous check was below the threshold; a second such check means converged */
    private boolean readyToFinishIterations;

    /** The cumulative errors in the training (never reset by {@link #next()}) */
    private double totalErrors;
    /** The cumulative losses of the current and of the previous iteration */
    private double currLosses, prevLosses;

    /** 1-based index of the current iteration */
    private int curIter;

    /**
     * Creates a state with convergence checking enabled and a convergence rate of {@code 0.005}.
     */
    public ConversionState() {
        this(true, 0.005d);
    }

    /**
     * @param conversionCheck whether {@link #isConverged(long)} should ever report convergence
     * @param convergenceRate threshold that the relative loss change must fall below
     */
    public ConversionState(boolean conversionCheck, double convergenceRate) {
        this.conversionCheck = conversionCheck;
        this.convergenceRate = convergenceRate;
        this.readyToFinishIterations = false;
        this.totalErrors = 0.d;
        this.currLosses = 0.d;
        // No previous iteration yet; +Infinity guarantees the first loss is never "increased".
        this.prevLosses = Double.POSITIVE_INFINITY;
        this.curIter = 1;
    }

    /** @return the errors accumulated through {@link #incrError(double)} */
    public double getTotalErrors() {
        return totalErrors;
    }

    /** @return the loss accumulated so far in the current iteration */
    public double getCumulativeLoss() {
        return currLosses;
    }

    /**
     * @param numInstances number of instances to average the current cumulative loss over
     * @return {@code currLosses / numInstances}, or {@code 0} when {@code numInstances} is 0
     */
    public double getAverageLoss(@Nonnegative final long numInstances) {
        if (numInstances == 0) {
            return 0.d;
        }
        return currLosses / numInstances;
    }

    /**
     * @return the cumulative loss of the previous iteration;
     *         {@link Double#POSITIVE_INFINITY} until {@link #next()} has been called once
     */
    public double getPreviousLoss() {
        return prevLosses;
    }

    /** Adds {@code error} to the total error. */
    public void incrError(double error) {
        this.totalErrors += error;
    }

    /** Adds {@code loss} to the cumulative loss of the current iteration. */
    public void incrLoss(double loss) {
        this.currLosses += loss;
    }

    /** Scales the cumulative loss of the current iteration by {@code multi}. */
    public void multiplyLoss(double multi) {
        this.currLosses = currLosses * multi;
    }

    /** @return {@code true} if the current cumulative loss exceeds the previous one */
    public boolean isLossIncreased() {
        return currLosses > prevLosses;
    }

    /**
     * Checks whether training has converged, logging the iteration status at INFO level.
     *
     * <p>
     * Returns {@code true} only when the change rate is below the threshold in this call and was
     * also below it in the preceding call. An increased loss, or a change rate at or above the
     * threshold, resets that streak.
     *
     * @param observedTrainingExamples number of examples seen, used only for the log message
     * @return {@code true} if converged; always {@code false} when convergence checking is
     *         disabled
     */
    public boolean isConverged(final long observedTrainingExamples) {
        if (!conversionCheck) {
            return false;
        }

        if (currLosses > prevLosses) {
            if (logger.isInfoEnabled()) {
                logger.info("Iteration #" + curIter + " current cumulative loss `" + currLosses
                        + "` > previous cumulative loss `" + prevLosses + '`');
            }
            this.readyToFinishIterations = false;
            return false;
        }

        // A NaN change rate (first iteration, prevLosses == +Infinity) compares as "not below".
        final boolean belowThreshold = getChangeRate() < convergenceRate;
        if (belowThreshold && readyToFinishIterations) {
            // NOTE: never be true at the first iteration where prevLosses == Double.POSITIVE_INFINITY
            if (logger.isInfoEnabled()) {
                logger.info("Training converged at " + curIter + "-th iteration!\n"
                        + getInfo(observedTrainingExamples));
            }
            return true;
        }

        if (logger.isInfoEnabled()) {
            logger.info(getInfo(observedTrainingExamples));
        }
        // Arm (or disarm) the flag so that convergence needs two consecutive small changes.
        this.readyToFinishIterations = belowThreshold;
        return false;
    }

    /**
     * Computes the relative decrease of the cumulative loss,
     * {@code (prevLosses - currLosses) / prevLosses}.
     *
     * <p>
     * When both losses are exactly zero the loss did not change, so {@code 0} is returned instead
     * of the {@code NaN} that {@code 0.0 / 0.0} would produce; otherwise a zero-loss model could
     * never be reported as converged. The result is still {@code NaN} while
     * {@code prevLosses} is infinite (i.e. before the first {@link #next()}).
     */
    double getChangeRate() {
        if (prevLosses == 0.d && currLosses == 0.d) {
            return 0.d;
        }
        return (prevLosses - currLosses) / prevLosses;
    }

    /**
     * Advances to the next iteration: the current cumulative loss becomes the previous loss, the
     * current loss is reset to zero and the iteration counter is incremented.
     */
    public void next() {
        this.prevLosses = currLosses;
        this.currLosses = 0.d;
        this.curIter++;
    }

    /** @return the 1-based index of the current iteration */
    public int getCurrentIteration() {
        return curIter;
    }

    /**
     * Builds a one-line, human readable summary of the current iteration.
     *
     * @param observedTrainingExamples number of examples used to compute the average loss
     * @return the summary string
     */
    @Nonnull
    public String getInfo(@Nonnegative final long observedTrainingExamples) {
        final StringBuilder buf = new StringBuilder();
        buf.append("Iteration #").append(curIter).append(" | ");
        buf.append("average loss=").append(getAverageLoss(observedTrainingExamples)).append(", ");
        buf.append("current cumulative loss=").append(currLosses).append(", ");
        buf.append("previous cumulative loss=").append(prevLosses).append(", ");
        buf.append("change rate=").append(getChangeRate()).append(", ");
        buf.append("#trainingExamples=").append(observedTrainingExamples);
        return buf.toString();
    }

}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy