All Downloads are FREE. Search and download functionalities are using the official Maven repository.

weka.classifiers.timeseries.WekaForecaster Maven / Gradle / Ivy

/*
 *   This program is free software: you can redistribute it and/or modify
 *   it under the terms of the GNU General Public License as published by
 *   the Free Software Foundation, either version 3 of the License, or
 *   (at your option) any later version.
 *
 *   This program is distributed in the hope that it will be useful,
 *   but WITHOUT ANY WARRANTY; without even the implied warranty of
 *   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *   GNU General Public License for more details.
 *
 *   You should have received a copy of the GNU General Public License
 *   along with this program.  If not, see <http://www.gnu.org/licenses/>.
 */


/*
 *    WekaForecaster.java
 *    Copyright (C) 2010-2016 University of Waikato, Hamilton, New Zealand
 */

package weka.classifiers.timeseries;

import weka.classifiers.AbstractClassifier;
import weka.classifiers.Classifier;
import weka.classifiers.evaluation.NumericPrediction;
import weka.classifiers.functions.LinearRegression;
import weka.classifiers.timeseries.core.*;
import weka.filters.supervised.attribute.TSLagMaker;
import weka.core.Attribute;
import weka.core.DenseInstance;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Option;
import weka.core.OptionHandler;
import weka.core.Utils;
import weka.core.logging.Logger;
import weka.filters.Filter;
import weka.filters.unsupervised.attribute.Remove;
import weka.filters.unsupervised.attribute.RemoveType;

import java.io.PrintStream;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Enumeration;
import java.util.List;
import java.util.Map;
import java.util.Vector;

/**
 * Class that implements time series forecasting using a Weka regression scheme.
 * Makes use of the TSLagMaker class to handle all lagged attribute creation,
 * periodic attributes etc.
 * 
 * @author Mark Hall (mhall{[at]}pentaho{[dot]}com)
 * @version $Revision: 52593 $
 */
public class WekaForecaster extends AbstractForecaster implements TSLagUser,
  ConfidenceIntervalForecaster, OverlayForecaster, IncrementallyPrimeable,
  OptionHandler, Serializable {

  /** For serialization */
  private static final long serialVersionUID = 5562710925011828590L;

  /** The format of the original incoming instances */
  protected Instances m_originalHeader;

  /**
   * A temporary header used when updating base learners that implement
   * PrimingDataLearner
   */
  protected Instances m_tempHeader;

  /**
   * A copy of the input data provided to primeForecaster(). Declared
   * transient, so it is skipped by default Java serialization.
   */
  protected transient Instances m_primedInput;

  /** The format of the transformed data */
  protected Instances m_transformedHeader;

  /** The base regression scheme to use (a LinearRegression by default) */
  protected Classifier m_forecaster = new LinearRegression();

  /**
   * The individual forecasters for each target. The save/load and state
   * methods of this class call getWrappedClassifier() on every entry.
   */
  protected List m_singleTargetForecasters;

  /** True if the forecaster has been built */
  protected boolean m_modelBuilt = false;

  /** True if an artificial time index has been added to the data */
  protected boolean m_useArtificialTimeIndex = false;

  /**
   * The estimator used for calculating confidence limits.
   */
  protected ErrorBasedConfidenceIntervalEstimator m_confidenceLimitEstimator;

  /**
   * Number of steps ahead to calculate confidence limits for (0 = don't
   * calculate confidence limits). Defaults to 0.
   */
  protected int m_calculateConfLimitsSteps = 0;

  /** Confidence level to compute confidence limits at (0.95 by default) */
  protected double m_confidenceLevel = 0.95;

  /**
   * For removing any date attributes (TSLagMaker will remap date timestamps to
   * numeric)
   */
  protected RemoveType m_dateRemover;

  /**
   * Holds a list of training instance indexes that contained missing target
   * values that were replaced via interpolation
   */
  protected List m_missingTargetList;

  /**
   * Holds a list of training instance indexes that contained missing date
   * values (if a date time stamp is being used)
   */
  protected List m_missingTimeStampList;

  /**
   * NOTE(review): undocumented in the original source. Presumably holds
   * information about the rows with missing time stamps that accompanies
   * m_missingTimeStampList - confirm against the code that populates it.
   */
  protected List m_missingTimeStampRows;

  /**
   * Logging object
   */
  protected Logger m_log;

  /**
   * The lag maker to use. Declared without an access modifier, so it is
   * package-private; it is read and replaced through getTSLagMaker() and
   * setTSLagMaker().
   */
  TSLagMaker m_lagMaker = new TSLagMaker();

  // used by the incremental method when detecting missing values in
  // targets/date. All of the fields below are transient, so they are skipped
  // by default Java serialization.

  // NOTE(review): presumably the most recent instance seen while priming
  // incrementally - confirm against the incremental priming code.
  private transient Instance m_previousPrimeInstance = null;

  // NOTE(review): presumably buffers instances while missing values are being
  // resolved - confirm against the incremental priming code.
  private transient Instances m_missingBuffer = null;

  // NOTE(review): presumably set when priming data started with missing
  // values - confirm against the incremental priming code.
  private transient boolean m_hadLeadingMissingPrime = false;

  // NOTE(review): meaning not determinable from this part of the file.
  private transient boolean m_first = false;

  // NOTE(review): presumably records that at least one non-missing time stamp
  // has been seen - confirm against the incremental priming code.
  private transient boolean m_atLeastOneNonMissingTimeStamp = false;
  /**
   * Command-line entry point. Creates a new WekaForecaster and passes it,
   * together with the supplied arguments, to runForecaster(). Any exception
   * raised while doing so is caught and its stack trace is printed.
   *
   * @param args general and scheme-specific command line arguments
   */
  public static void main(String[] args) {
    try {
      final WekaForecaster forecaster = new WekaForecaster();
      forecaster.runForecaster(forecaster, args);
    } catch (Exception failure) {
      failure.printStackTrace();
    }
  }

    /**
     * Tells whether the configured base learner is an instance of
     * BaseModelSerializer, i.e. whether it needs special serialization.
     *
     * @return true if the base learner is a BaseModelSerializer, false
     *         otherwise (including when no base learner is set)
     */
    public boolean baseModelHasSerializer() {
        final boolean needsSpecialSerialization =
            m_forecaster instanceof BaseModelSerializer;
        return needsSpecialSerialization;
    }

  /**
   * Saves the classifier wrapped by each single-target forecaster. Nothing is
   * written unless the base learner requires special serialization (see
   * baseModelHasSerializer()). The classifier at position i is written to
   * the file named filepath + ".base" + i.
   *
   * @param filepath the path prefix of the files to save the base models to
   * @throws Exception if serializing one of the wrapped classifiers fails
   */
  public void saveBaseModel(String filepath) throws Exception {
    if (!baseModelHasSerializer()) {
      return;
    }

    for (int target = 0; target < m_singleTargetForecasters.size(); target++) {
      BaseModelSerializer serializer =
        (BaseModelSerializer) m_singleTargetForecasters.get(target).getWrappedClassifier();
      serializer.serializeModel(filepath + ".base" + target);
    }
  }

  /**
   * Loads the serialized classifier for each single-target forecaster.
   * Nothing is read unless the base learner requires special serialization
   * (see baseModelHasSerializer()). The classifier at position i is loaded
   * from the file named filepath + ".base" + i.
   *
   * @param filepath the path prefix of the files to load the base models from
   * @throws Exception if loading one of the wrapped classifiers fails
   */
  public void loadBaseModel(String filepath) throws Exception {
    if (!baseModelHasSerializer()) {
      return;
    }

    for (int target = 0; target < m_singleTargetForecasters.size(); target++) {
      BaseModelSerializer serializer =
        (BaseModelSerializer) m_singleTargetForecasters.get(target).getWrappedClassifier();
      serializer.loadSerializedModel(filepath + ".base" + target);
    }
  }

  /**
   * Serializes the state of the classifier wrapped by each single-target
   * forecaster. Nothing is written unless the base learner is state dependent
   * (see usesState()). The state at position i is written to the file named
   * filepath + ".state" + i.
   *
   * @param filepath the path prefix of the files to save the model state to
   * @throws Exception if serializing the state of a wrapped classifier fails
   */
  public void serializeState(String filepath) throws Exception {
    if (!usesState()) {
      return;
    }

    for (int target = 0; target < m_singleTargetForecasters.size(); target++) {
      StateDependentPredictor predictor =
        (StateDependentPredictor) m_singleTargetForecasters.get(target).getWrappedClassifier();
      predictor.serializeState(filepath + ".state" + target);
    }
  }

  /**
   * Loads previously serialized state into the classifier wrapped by each
   * single-target forecaster. Nothing is read unless the base learner is
   * state dependent (see usesState()). The state at position i is loaded
   * from the file named filepath + ".state" + i.
   *
   * @param filepath the path prefix of the files to load the model state from
   * @throws Exception if loading the state of a wrapped classifier fails
   */
  public void loadSerializedState(String filepath) throws Exception {
    if (!usesState()) {
      return;
    }

    for (int target = 0; target < m_singleTargetForecasters.size(); target++) {
      StateDependentPredictor predictor =
        (StateDependentPredictor) m_singleTargetForecasters.get(target).getWrappedClassifier();
      predictor.loadSerializedState(filepath + ".state" + target);
    }
  }

  /**
   * Tells whether the configured base learner is an instance of
   * StateDependentPredictor, i.e. whether state-related operations apply.
   *
   * @return true if the base learner is a StateDependentPredictor, false
   *         otherwise (including when no base learner is set)
   */
  public boolean usesState() {
    final boolean stateDependent =
      m_forecaster instanceof StateDependentPredictor;
    return stateDependent;
  }

  /**
   * Resets the model state by calling clearPreviousState() on the classifier
   * wrapped by each single-target forecaster. Does nothing unless the base
   * learner is state dependent (see usesState()).
   */
  public void clearPreviousState() {
    if (!usesState()) {
      return;
    }

    for (int target = 0; target < m_singleTargetForecasters.size(); target++) {
      StateDependentPredictor predictor =
        (StateDependentPredictor) m_singleTargetForecasters.get(target).getWrappedClassifier();
      predictor.clearPreviousState();
    }
  }

  /**
   * Loads state into the model. The entry at position i of the supplied list
   * is handed to the classifier wrapped by the single-target forecaster at
   * position i. Does nothing unless the base learner is state dependent (see
   * usesState()).
   *
   * @param previousState the states to set, read by index with one entry per
   *          single-target forecaster
   */
  public void setPreviousState(List previousState) {
    if (!usesState()) {
      return;
    }

    for (int target = 0; target < m_singleTargetForecasters.size(); target++) {
      StateDependentPredictor predictor =
        (StateDependentPredictor) m_singleTargetForecasters.get(target).getWrappedClassifier();
      predictor.setPreviousState(previousState.get(target));
    }
  }

  /**
   * Gets the last set state of the model, collected from the classifier
   * wrapped by each single-target forecaster in list order.
   *
   * @return the state of the model to be used in the next prediction; an
   *         empty list if the base learner is not state dependent (see
   *         usesState())
   */
  public List getPreviousState() {
    List collected = new ArrayList<>();
    if (!usesState()) {
      return collected;
    }

    for (int target = 0; target < m_singleTargetForecasters.size(); target++) {
      StateDependentPredictor predictor =
        (StateDependentPredictor) m_singleTargetForecasters.get(target).getWrappedClassifier();
      // Entries are appended in forecaster order, so index == position.
      collected.add(predictor.getPreviousState());
    }
    return collected;
  }

  /**
   * Provides a short name that describes the underlying algorithm. The name
   * is the forecaster spec (see getForecasterSpec()) with a fixed set of Weka
   * package fragments removed.
   *
   * @return a short description of this forecaster, or the empty string if
   *         no base learner is set
   */
  @Override
  public String getAlgorithmName() {
    if (m_forecaster == null) {
      return "";
    }

    // Fragments are removed one after another, in exactly this order.
    final String[] packageFragments = { "weka.classifiers.", "functions.",
      "bayes.", "rules.", "trees.", "meta.", "lazy.", "supportVector." };

    String shortName = getForecasterSpec();
    for (String fragment : packageFragments) {
      shortName = shortName.replace(fragment, "");
    }
    return shortName;
  }

  /**
   * Returns the TSLagMaker held by this forecaster. Options pertaining to lag
   * creation, periodic attributes etc. are set via the lag maker.
   *
   * @return the TSLagMaker currently in use.
   */
  @Override
  public TSLagMaker getTSLagMaker() {
    return this.m_lagMaker;
  }

  /**
   * Replaces the TSLagMaker held by this forecaster. Options pertaining to
   * lag creation, periodic attributes etc. are set via the lag maker.
   *
   * @param lagMaker the TSLagMaker to use from now on.
   */
  @Override
  public void setTSLagMaker(TSLagMaker lagMaker) {
    this.m_lagMaker = lagMaker;
  }

  /**
   * Returns an enumeration describing the available options.
   *
   * @return an enumeration of all the available options.
   */
  @Override
  public Enumeration