weka.classifiers.timeseries.eval.graph.JFreeChartDriver Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of timeseriesForecasting Show documentation
Show all versions of timeseriesForecasting Show documentation
Provides a time series forecasting environment for Weka. Includes a wrapper for Weka regression schemes that automates the process of creating lagged variables and date-derived periodic variables and provides the ability to do closed-loop forecasting. New evaluation routines are provided by a special evaluation module and graphing of predictions/forecasts are provided via the JFreeChart library. Includes both command-line and GUI user interfaces. Sample time series data can be found in ${WEKA_HOME}/packages/timeseriesForecasting/sample-data.
The newest version!
/*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see .
*/
/*
* JFreeChartDriver.java
* Copyright (C) 2010-2016 University of Waikato, Hamilton, New Zealand
*/
package weka.classifiers.timeseries.eval.graph;
import java.awt.BasicStroke;
import java.awt.Image;
import java.io.File;
import java.util.List;
import javax.swing.JPanel;
import org.jfree.chart.ChartPanel;
import org.jfree.chart.ChartUtilities;
import org.jfree.chart.JFreeChart;
import org.jfree.chart.axis.DateAxis;
import org.jfree.chart.axis.NumberAxis;
import org.jfree.chart.axis.ValueAxis;
import org.jfree.chart.plot.XYPlot;
import org.jfree.chart.renderer.xy.XYErrorRenderer;
import org.jfree.chart.title.TextTitle;
import org.jfree.data.xy.XYIntervalSeries;
import org.jfree.data.xy.XYIntervalSeriesCollection;
import weka.classifiers.evaluation.NumericPrediction;
import weka.classifiers.timeseries.AbstractForecaster;
import weka.classifiers.timeseries.TSForecaster;
import weka.classifiers.timeseries.WekaForecaster;
import weka.filters.supervised.attribute.TSLagMaker;
import weka.classifiers.timeseries.core.TSLagUser;
import weka.classifiers.timeseries.eval.ErrorModule;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Utils;
/**
* A Graph driver that uses the JFreeChart library.
*
* @author Mark Hall (mhall{[at]}pentaho{[dot]}com)
* @version $Revision: 52593 $
*/
public class JFreeChartDriver extends GraphDriver {
protected JFreeChart getPredictedTargetsChart(TSForecaster forecaster,
ErrorModule preds, List targetNames, int stepNumber,
int instanceNumOffset, Instances data) {
if (forecaster instanceof TSLagUser && data != null) {
TSLagMaker lagMaker = ((TSLagUser) forecaster).getTSLagMaker();
if (lagMaker.getAdjustForTrends()
&& !lagMaker.isUsingAnArtificialTimeIndex()) {
// fill in any missing time stamps only
data = new Instances(data);
data = weka.classifiers.timeseries.core.Utils.replaceMissing(data,
null, lagMaker.getTimeStampField(), true,
lagMaker.getPeriodicity(), lagMaker.getSkipEntries());
}
}
// set up a collection of predicted and actual series
XYIntervalSeriesCollection xyDataset = new XYIntervalSeriesCollection();
for (String target : targetNames) {
XYIntervalSeries targetSeries = new XYIntervalSeries(target + "-actual",
false, false);
xyDataset.addSeries(targetSeries);
targetSeries = new XYIntervalSeries(target + "-predicted", false, false);
xyDataset.addSeries(targetSeries);
}
ValueAxis timeAxis = null;
NumberAxis valueAxis = new NumberAxis("");
valueAxis.setAutoRangeIncludesZero(false);
int timeIndex = -1;
boolean timeAxisIsDate = false;
if (forecaster instanceof TSLagUser && data != null) {
TSLagMaker lagMaker = ((TSLagUser) forecaster).getTSLagMaker();
if (!lagMaker.isUsingAnArtificialTimeIndex()
&& lagMaker.getAdjustForTrends()) {
String timeName = lagMaker.getTimeStampField();
if (data.attribute(timeName).isDate()) {
timeAxis = new DateAxis("");
timeAxisIsDate = true;
timeIndex = data.attribute(timeName).index();
}
}
}
if (timeAxis == null) {
timeAxis = new NumberAxis("");
((NumberAxis) timeAxis).setAutoRangeIncludesZero(false);
}
// now populate the series
boolean hasConfidenceIntervals = false;
for (int i = 0; i < targetNames.size(); i++) {
String targetName = targetNames.get(i);
List predsForI = preds
.getPredictionsForTarget(targetName);
int predIndex = xyDataset.indexOf(targetName + "-predicted");
int actualIndex = xyDataset.indexOf(targetName + "-actual");
XYIntervalSeries predSeries = xyDataset.getSeries(predIndex);
XYIntervalSeries actualSeries = xyDataset.getSeries(actualIndex);
for (int j = 0; j < predsForI.size(); j++) {
double x = Utils.missingValue();
if (timeAxisIsDate) {
if (instanceNumOffset + j + stepNumber - 1 < data.numInstances()) {
x = data.instance(instanceNumOffset + j + stepNumber - 1).value(
timeIndex);
}
} else {
x = instanceNumOffset + j + stepNumber;
}
double yPredicted = predsForI.get(j).predicted();
double yHigh = yPredicted;
double yLow = yPredicted;
double[][] conf = predsForI.get(j).predictionIntervals();
if (conf.length > 0) {
yLow = conf[0][0];
yHigh = conf[0][1];
hasConfidenceIntervals = true;
}
if (!Utils.isMissingValue(x) && !Utils.isMissingValue(yPredicted)) {
if (predSeries != null) {
predSeries.add(x, x, x, yPredicted, yLow, yHigh);
}
// System.err.println("* " + yPredicted + " " + x);
}
double yActual = predsForI.get(j).actual();
if (!Utils.isMissingValue(x) && !Utils.isMissingValue(yActual)) {
if (actualSeries != null) {
actualSeries.add(x, x, x, yActual, yActual, yActual);
}
}
}
}
// set up the chart
String title = "" + stepNumber + " step-ahead predictions for: ";
for (String s : targetNames) {
title += s + ",";
}
title = title.substring(0, title.lastIndexOf(","));
/*
* String algoSpec = forecaster.getAlgorithmName(); title += " (" + algoSpec
* + ")";
*/
if (forecaster instanceof WekaForecaster && hasConfidenceIntervals) {
double confPerc = ((WekaForecaster) forecaster).getConfidenceLevel() * 100.0;
title += " [" + Utils.doubleToString(confPerc, 0) + "% conf. intervals]";
}
XYErrorRenderer renderer = new XYErrorRenderer();
renderer.setBaseLinesVisible(true);
renderer.setDrawXError(false);
renderer.setDrawYError(true);
// renderer.setShapesFilled(true);
XYPlot plot = new XYPlot(xyDataset, timeAxis, valueAxis, renderer);
JFreeChart chart = new JFreeChart(title, JFreeChart.DEFAULT_TITLE_FONT,
plot, true);
chart.setBackgroundPaint(java.awt.Color.white);
TextTitle chartTitle = chart.getTitle();
String fontName = chartTitle.getFont().getFontName();
java.awt.Font newFont = new java.awt.Font(fontName, java.awt.Font.PLAIN, 12);
chartTitle.setFont(newFont);
return chart;
}
/**
* Save a chart to a file.
*
* @param chart the chart to save
* @param filename the filename to save to
* @param width width of the saved image
* @param height height of the saved image
* @throws Exception if the chart can't be saved for some reason
*/
@Override
public void saveChartToFile(JPanel chart, String filename, int width,
int height) throws Exception {
if (!(chart instanceof ChartPanel)) {
throw new Exception("Chart is not a JFreeChart!");
}
if (filename.toLowerCase().lastIndexOf(".png") < 0) {
filename = filename + ".png";
}
ChartUtilities.saveChartAsPNG(new File(filename),
((ChartPanel) chart).getChart(), width, height);
}
/**
* Get an image representation of the supplied chart.
*
* @param chart the chart to get an image of.
* @param width width of the chart
* @param height height of the chart
* @return an Image of the chart
* @throws Exception if the image can't be created for some reason
*/
@Override
public Image getImageFromChart(JPanel chart, int width, int height)
throws Exception {
if (!(chart instanceof ChartPanel)) {
throw new Exception("Chart is not a JFreeChart!");
}
Image result = ((ChartPanel) chart).getChart().createBufferedImage(width,
height);
return result;
}
/**
* Return the graph encapsulated in a panel.
*
* @param width the width in pixels of the graph
* @param height the height in pixels of the graph
* @param forecaster the forecaster
* @param preds an ErrorModule that contains predictions for all targets for
* the specified step ahead. Targets are in the same order returned
* by TSForecaster.getFieldsToForecast()
* @param targetNames the list of target names to plot
* @param stepNumber which step ahead to graph for specified targets
* @param instanceNumOffset how far into the data the predictions start from
* @param data the instances that these predictions correspond to (may be
* null)
* @return an image of the graph
*/
@Override
public JPanel getGraphPanelTargets(TSForecaster forecaster,
ErrorModule preds, List targetNames, int stepNumber,
int instanceNumOffset, Instances data) throws Exception {
JFreeChart chart = getPredictedTargetsChart(forecaster, preds, targetNames,
stepNumber, instanceNumOffset, data);
ChartPanel result = new ChartPanel(chart, false, true, true, true, false);
return result;
}
protected JFreeChart getPredictedStepsChart(TSForecaster forecaster,
List preds, String targetName, List stepsToPlot,
int instanceNumOffset, Instances data) {
if (forecaster instanceof TSLagUser && data != null) {
TSLagMaker lagMaker = ((TSLagUser) forecaster).getTSLagMaker();
if (lagMaker.getAdjustForTrends()
&& !lagMaker.isUsingAnArtificialTimeIndex()) {
// fill in any missing time stamps only
data = new Instances(data);
data = weka.classifiers.timeseries.core.Utils.replaceMissing(data,
null, lagMaker.getTimeStampField(), true,
lagMaker.getPeriodicity(), lagMaker.getSkipEntries());
}
}
// set up a collection of predicted series
XYIntervalSeriesCollection xyDataset = new XYIntervalSeriesCollection();
XYIntervalSeries targetSeries = new XYIntervalSeries(targetName, false,
false);
xyDataset.addSeries(targetSeries);
// for (int i = 0; i < preds.size(); i++) {
for (int z = 0; z < stepsToPlot.size(); z++) {
int i = stepsToPlot.get(z);
i--;
// ignore out of range steps
if (i < 0 || i >= preds.size()) {
continue;
}
String step = "-steps";
if (i == 0) {
step = "-step";
}
targetSeries = new XYIntervalSeries(targetName + "_" + (i + 1) + step
+ "-ahead", false, false);
xyDataset.addSeries(targetSeries);
}
ValueAxis timeAxis = null;
NumberAxis valueAxis = new NumberAxis("");
valueAxis.setAutoRangeIncludesZero(false);
int timeIndex = -1;
boolean timeAxisIsDate = false;
if (forecaster instanceof TSLagUser && data != null) {
TSLagMaker lagMaker = ((TSLagUser) forecaster).getTSLagMaker();
if (!lagMaker.isUsingAnArtificialTimeIndex()
&& lagMaker.getAdjustForTrends()) {
String timeName = lagMaker.getTimeStampField();
if (data.attribute(timeName).isDate()) {
timeAxis = new DateAxis("");
timeAxisIsDate = true;
timeIndex = data.attribute(timeName).index();
}
}
}
if (timeAxis == null) {
timeAxis = new NumberAxis("");
((NumberAxis) timeAxis).setAutoRangeIncludesZero(false);
}
// now populate the series
// for (int i = 0; i < preds.size(); i++) {
boolean doneActual = false;
boolean hasConfidenceIntervals = false;
for (int z = 0; z < stepsToPlot.size(); z++) {
int i = stepsToPlot.get(z);
i--;
// ignore out of range steps
if (i < 0 || i >= preds.size()) {
continue;
}
ErrorModule predsForStepI = preds.get(i);
List predsForTargetAtI = predsForStepI
.getPredictionsForTarget(targetName);
String step = "-steps";
if (i == 0) {
step = "-step";
}
int predIndex = xyDataset.indexOf(targetName + "_" + (i + 1) + step
+ "-ahead");
XYIntervalSeries predSeries = xyDataset.getSeries(predIndex);
XYIntervalSeries actualSeries = null;
if (!doneActual) {
int actualIndex = xyDataset.indexOf(targetName);
actualSeries = xyDataset.getSeries(actualIndex);
}
for (int j = 0; j < predsForTargetAtI.size(); j++) {
double x = Utils.missingValue();
if (timeAxisIsDate) {
if (instanceNumOffset + j + i < data.numInstances()) {
x = data.instance(instanceNumOffset + j + i).value(timeIndex);
}
} else {
x = instanceNumOffset + j + i;
}
double yPredicted = predsForTargetAtI.get(j).predicted();
double yHigh = yPredicted;
double yLow = yPredicted;
double[][] conf = predsForTargetAtI.get(j).predictionIntervals();
if (conf.length > 0) {
yLow = conf[0][0];
yHigh = conf[0][1];
hasConfidenceIntervals = true;
}
if (!Utils.isMissingValue(x) && !Utils.isMissingValue(yPredicted)) {
if (predSeries != null) {
predSeries.add(x, x, x, yPredicted, yLow, yHigh);
}
// System.err.println("* " + yPredicted + " " + x);
}
if (!doneActual && actualSeries != null) {
double yActual = predsForTargetAtI.get(j).actual();
if (!Utils.isMissingValue(x) && !Utils.isMissingValue(yActual)) {
actualSeries.add(x, x, x, yActual, yActual, yActual);
}
}
}
if (actualSeries != null) {
doneActual = true;
}
}
// set up the chart
String title = "";
for (int i : stepsToPlot) {
title += i + ",";
}
title = title.substring(0, title.lastIndexOf(","));
title += " step-ahead predictions for " + targetName;
/*
* String algoSpec = forecaster.getAlgorithmName(); title += " (" + algoSpec
* + ")";
*/
if (forecaster instanceof WekaForecaster && hasConfidenceIntervals) {
double confPerc = ((WekaForecaster) forecaster).getConfidenceLevel() * 100.0;
title += " [" + Utils.doubleToString(confPerc, 0) + "% conf. intervals]";
}
XYErrorRenderer renderer = new XYErrorRenderer();
renderer.setBaseLinesVisible(true);
// renderer.setShapesFilled(true);
XYPlot plot = new XYPlot(xyDataset, timeAxis, valueAxis, renderer);
JFreeChart chart = new JFreeChart(title, JFreeChart.DEFAULT_TITLE_FONT,
plot, true);
chart.setBackgroundPaint(java.awt.Color.white);
TextTitle chartTitle = chart.getTitle();
String fontName = chartTitle.getFont().getFontName();
java.awt.Font newFont = new java.awt.Font(fontName, java.awt.Font.PLAIN, 12);
chartTitle.setFont(newFont);
return chart;
}
/**
* Return the graph encapsulated in a JPanel.
*
* @param forecaster the forecaster
* @param preds a list of ErrorModules, one for each consecutive step ahead
* prediction set
* @param targetName the name of the target field to plot
* @param stepsToPlot a list of step numbers for the step-ahead prediction
* sets to plot to plot for the specified target.
* @param instanceNumOffset how far into the data the predictions start from
* @param data the instances that these predictions correspond to (may be
* null)
* @return an image of the graph.
*/
@Override
public JPanel getGraphPanelSteps(TSForecaster forecaster,
List preds, String targetName, List stepsToPlot,
int instanceNumOffset, Instances data) throws Exception {
JFreeChart chart = getPredictedStepsChart(forecaster, preds, targetName,
stepsToPlot, instanceNumOffset, data);
ChartPanel result = new ChartPanel(chart, false, true, true, true, false);
return result;
}
protected JFreeChart getFutureForecastChart(TSForecaster forecaster,
List> preds, List targetNames,
Instances history) {
if (forecaster instanceof TSLagUser && history != null) {
TSLagMaker lagMaker = ((TSLagUser) forecaster).getTSLagMaker();
if (lagMaker.getAdjustForTrends()
&& !lagMaker.isUsingAnArtificialTimeIndex()) {
// fill in any missing time stamps only
history = new Instances(history);
history = weka.classifiers.timeseries.core.Utils.replaceMissing(
history, null, lagMaker.getTimeStampField(), true,
lagMaker.getPeriodicity(), lagMaker.getSkipEntries());
}
}
// set up a collection of series
XYIntervalSeriesCollection xyDataset = new XYIntervalSeriesCollection();
if (history != null) {
// add actual historical data values
for (String targetName : targetNames) {
XYIntervalSeries targetSeries = new XYIntervalSeries(targetName, false,
false);
xyDataset.addSeries(targetSeries);
}
}
// add predicted series
for (String targetName : targetNames) {
XYIntervalSeries targetSeries = new XYIntervalSeries(targetName
+ "-predicted", false, false);
xyDataset.addSeries(targetSeries);
}
ValueAxis timeAxis = null;
NumberAxis valueAxis = new NumberAxis("");
valueAxis.setAutoRangeIncludesZero(false);
int timeIndex = -1;
boolean timeAxisIsDate = false;
double artificialTimeStart = 0;
double lastRealTimeValue = Utils.missingValue();
if (forecaster instanceof TSLagUser && history != null) {
TSLagMaker lagMaker = ((TSLagUser) forecaster).getTSLagMaker();
if (!lagMaker.isUsingAnArtificialTimeIndex()
&& lagMaker.getAdjustForTrends()) {
String timeName = lagMaker.getTimeStampField();
if (history.attribute(timeName).isDate()) {
timeAxis = new DateAxis("");
timeAxisIsDate = true;
timeIndex = history.attribute(timeName).index();
}
} else {
try {
artificialTimeStart = (history != null) ? 1 : lagMaker
.getArtificialTimeStartValue() + 1;
} catch (Exception ex) {
}
}
}
if (timeAxis == null) {
timeAxis = new NumberAxis("");
((NumberAxis) timeAxis).setAutoRangeIncludesZero(false);
}
boolean hasConfidenceIntervals = false;
// now populate the series
if (history != null) {
// do the actuals first
for (int i = 0; i < history.numInstances(); i++) {
Instance current = history.instance(i);
for (String targetName : targetNames) {
int dataIndex = history.attribute(targetName.trim()).index();
if (dataIndex >= 0) {
XYIntervalSeries actualSeries = null;
int actualIndex = xyDataset.indexOf(targetName);
actualSeries = xyDataset.getSeries(actualIndex);
double x = Utils.missingValue();
if (timeAxisIsDate) {
x = current.value(timeIndex);
if (!Utils.isMissingValue(x)) {
lastRealTimeValue = x;
}
} else {
x = artificialTimeStart;
}
double y = Utils.missingValue();
y = current.value(dataIndex);
if (!Utils.isMissingValue(x) && !Utils.isMissingValue(y)) {
if (actualSeries != null) {
actualSeries.add(x, x, x, y, y, y);
}
}
}
}
if (!timeAxisIsDate) {
artificialTimeStart++;
}
}
}
// now do the futures
List forecasterTargets = AbstractForecaster.stringToList(forecaster
.getFieldsToForecast());
// loop over the steps
for (int j = 0; j < preds.size(); j++) {
List predsForStepJ = preds.get(j);
// advance the real time index (if appropriate)
if (timeAxisIsDate) {
lastRealTimeValue = ((TSLagUser) forecaster).getTSLagMaker()
.advanceSuppliedTimeValue(lastRealTimeValue);
}
for (String targetName : targetNames) {
// look up this requested target in the list that the forecaster
// has predicted
int predIndex = forecasterTargets.indexOf(targetName.trim());
if (predIndex >= 0) {
NumericPrediction predsForTargetAtStepJ = predsForStepJ
.get(predIndex);
XYIntervalSeries predSeries = null;
int datasetIndex = xyDataset.indexOf(targetName + "-predicted");
predSeries = xyDataset.getSeries(datasetIndex);
if (predSeries != null) {
double y = predsForTargetAtStepJ.predicted();
double x = Utils.missingValue();
double yHigh = y;
double yLow = y;
double[][] conf = predsForTargetAtStepJ.predictionIntervals();
if (conf.length > 0) {
yLow = conf[0][0];
yHigh = conf[0][1];
hasConfidenceIntervals = true;
}
if (!timeAxisIsDate) {
x = artificialTimeStart;
} else {
x = lastRealTimeValue;
}
if (!Utils.isMissingValue(x) && !Utils.isMissingValue(y)) {
predSeries.add(x, x, x, y, yLow, yHigh);
}
}
}
}
// advance the artificial time index (if appropriate)
if (!timeAxisIsDate) {
artificialTimeStart++;
}
}
String title = "Future forecast for: ";
for (String s : targetNames) {
title += s + ",";
}
title = title.substring(0, title.lastIndexOf(","));
/*
* String algoSpec = forecaster.getAlgorithmName(); title += " (" + algoSpec
* + ")";
*/
if (forecaster instanceof WekaForecaster && hasConfidenceIntervals) {
double confPerc = ((WekaForecaster) forecaster).getConfidenceLevel() * 100.0;
title += " [" + Utils.doubleToString(confPerc, 0) + "% conf. intervals]";
}
XYErrorRenderer renderer = new XYErrorRenderer();
// renderer.setShapesFilled(true);
XYPlot plot = new XYPlot(xyDataset, timeAxis, valueAxis, renderer);
// renderer = (XYErrorRenderer)plot.getRenderer();
if (history != null) {
for (String targetName : targetNames) {
XYIntervalSeries predSeries = null;
int predIndex = xyDataset.indexOf(targetName + "-predicted");
predSeries = xyDataset.getSeries(predIndex);
XYIntervalSeries actualSeries = null;
int actualIndex = xyDataset.indexOf(targetName);
actualSeries = xyDataset.getSeries(actualIndex);
if (actualSeries != null && predSeries != null) {
// match the color of the actual series
java.awt.Paint actualPaint = renderer.lookupSeriesPaint(actualIndex);
renderer.setSeriesPaint(predIndex, actualPaint);
// now set the line style to dashed
BasicStroke dashed = new BasicStroke(1.5f, BasicStroke.CAP_BUTT,
BasicStroke.JOIN_MITER, 10.0f, new float[] { 5.0f }, 0.0f);
renderer.setSeriesStroke(predIndex, dashed);
}
}
}
renderer.setBaseLinesVisible(true);
renderer.setDrawXError(false);
renderer.setDrawYError(true);
JFreeChart chart = new JFreeChart(title, JFreeChart.DEFAULT_TITLE_FONT,
plot, true);
chart.setBackgroundPaint(java.awt.Color.white);
TextTitle chartTitle = chart.getTitle();
String fontName = chartTitle.getFont().getFontName();
java.awt.Font newFont = new java.awt.Font(fontName, java.awt.Font.PLAIN, 12);
chartTitle.setFont(newFont);
return chart;
}
/**
* Return the graph encapsulated in a JPanel
*
* @param forecaster the forecaster
* @param preds a list of list of predictions for *all* targets. The outer
* list is indexed by step number (i.e. the first entry is the 1-step
* ahead forecasts, the second is the 2-steps ahead forecasts etc.)
* and the inner list is indexed by target in the same order as the
* list of targets returned by TSForecaster.getFieldsToForecast().
* @param targetNames the list of target names to plot
* @param history a set of instances from which predictions are assumed to
* follow on from. May be null, in which case just the predictions
* are plotted.
* @return an image of the graph
*/
@Override
public JPanel getPanelFutureForecast(TSForecaster forecaster,
List> preds, List targetNames,
Instances history) throws Exception {
JFreeChart chart = getFutureForecastChart(forecaster, preds, targetNames,
history);
ChartPanel result = new ChartPanel(chart, false, true, true, true, false);
return result;
}
}