weka.gui.explorer.ClassifierPanel Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of weka-dev Show documentation
Show all versions of weka-dev Show documentation
The Waikato Environment for Knowledge Analysis (WEKA), a machine
learning workbench. This is the developer version, the "bleeding edge"
of development: new functionality is added to this version.
/*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
/*
* ClassifierPanel.java
* Copyright (C) 1999-2013 University of Waikato, Hamilton, New Zealand
*
*/
package weka.gui.explorer;
import java.awt.BorderLayout;
import java.awt.Dimension;
import java.awt.FlowLayout;
import java.awt.Font;
import java.awt.GridBagConstraints;
import java.awt.GridBagLayout;
import java.awt.GridLayout;
import java.awt.Insets;
import java.awt.Point;
import java.awt.event.ActionEvent;
import java.awt.event.ActionListener;
import java.awt.event.InputEvent;
import java.awt.event.MouseAdapter;
import java.awt.event.MouseEvent;
import java.beans.PropertyChangeEvent;
import java.beans.PropertyChangeListener;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.InputStream;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.OutputStream;
import java.text.SimpleDateFormat;
import java.util.ArrayList;
import java.util.Date;
import java.util.List;
import java.util.Random;
import java.util.Vector;
import java.util.zip.GZIPInputStream;
import java.util.zip.GZIPOutputStream;
import javax.swing.BorderFactory;
import javax.swing.ButtonGroup;
import javax.swing.DefaultComboBoxModel;
import javax.swing.JButton;
import javax.swing.JCheckBox;
import javax.swing.JComboBox;
import javax.swing.JDialog;
import javax.swing.JFileChooser;
import javax.swing.JFrame;
import javax.swing.JLabel;
import javax.swing.JMenu;
import javax.swing.JMenuItem;
import javax.swing.JOptionPane;
import javax.swing.JPanel;
import javax.swing.JPopupMenu;
import javax.swing.JRadioButton;
import javax.swing.JScrollPane;
import javax.swing.JTextArea;
import javax.swing.JTextField;
import javax.swing.JViewport;
import javax.swing.SwingConstants;
import javax.swing.event.ChangeEvent;
import javax.swing.event.ChangeListener;
import javax.swing.filechooser.FileFilter;
import weka.classifiers.AbstractClassifier;
import weka.classifiers.Classifier;
import weka.classifiers.CostMatrix;
import weka.classifiers.Evaluation;
import weka.classifiers.Sourcable;
import weka.classifiers.evaluation.CostCurve;
import weka.classifiers.evaluation.MarginCurve;
import weka.classifiers.evaluation.Prediction;
import weka.classifiers.evaluation.ThresholdCurve;
import weka.classifiers.evaluation.output.prediction.AbstractOutput;
import weka.classifiers.evaluation.output.prediction.Null;
import weka.classifiers.pmml.consumer.PMMLClassifier;
import weka.core.Attribute;
import weka.core.BatchPredictor;
import weka.core.Capabilities;
import weka.core.CapabilitiesHandler;
import weka.core.Drawable;
import weka.core.Environment;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.OptionHandler;
import weka.core.Range;
import weka.core.SerializedObject;
import weka.core.Utils;
import weka.core.Version;
import weka.core.converters.ConverterUtils.DataSource;
import weka.core.converters.IncrementalConverter;
import weka.core.converters.Loader;
import weka.core.pmml.PMMLFactory;
import weka.core.pmml.PMMLModel;
import weka.gui.CostMatrixEditor;
import weka.gui.EvaluationMetricSelectionDialog;
import weka.gui.ExtensionFileFilter;
import weka.gui.GenericObjectEditor;
import weka.gui.Logger;
import weka.gui.PropertyDialog;
import weka.gui.PropertyPanel;
import weka.gui.ResultHistoryPanel;
import weka.gui.SaveBuffer;
import weka.gui.SetInstancesPanel;
import weka.gui.SysErrLog;
import weka.gui.TaskLogger;
import weka.gui.beans.CostBenefitAnalysis;
import weka.gui.explorer.Explorer.CapabilitiesFilterChangeEvent;
import weka.gui.explorer.Explorer.CapabilitiesFilterChangeListener;
import weka.gui.explorer.Explorer.ExplorerPanel;
import weka.gui.explorer.Explorer.LogHandler;
import weka.gui.graphvisualizer.BIFFormatException;
import weka.gui.graphvisualizer.GraphVisualizer;
import weka.gui.treevisualizer.PlaceNode2;
import weka.gui.treevisualizer.TreeVisualizer;
import weka.gui.visualize.PlotData2D;
import weka.gui.visualize.ThresholdVisualizePanel;
import weka.gui.visualize.VisualizePanel;
import weka.gui.visualize.plugins.ErrorVisualizePlugin;
import weka.gui.visualize.plugins.GraphVisualizePlugin;
import weka.gui.visualize.plugins.TreeVisualizePlugin;
import weka.gui.visualize.plugins.VisualizePlugin;
/**
* This panel allows the user to select and configure a classifier, set the
* attribute of the current dataset to be used as the class, and evaluate the
* classifier using a number of testing modes (test on the training data,
* train/test on a percentage split, n-fold cross-validation, or test on a
* separately supplied test set). The results of classification runs are
* stored in a result history so that previous results remain accessible.
*
* @author Len Trigg ([email protected])
* @author Mark Hall ([email protected])
* @author Richard Kirkby ([email protected])
* @version $Revision: 10216 $
*/
public class ClassifierPanel extends JPanel implements
CapabilitiesFilterChangeListener, ExplorerPanel, LogHandler {
/** for serialization. */
static final long serialVersionUID = 6959973704963624003L;
/** the parent Explorer; initially null. */
protected Explorer m_Explorer = null;
/**
* The filename extension that should be used for model files.
* NOTE(review): public and non-final, so it can be reassigned at runtime;
* m_ModelFilter (below) captures the value current when a panel is built.
*/
public static String MODEL_FILE_EXTENSION = ".model";
/**
* The filename extension that should be used for PMML xml files.
* NOTE(review): public and non-final; m_PMMLModelFilter (below) captures the
* value current when a panel is built.
*/
public static String PMML_FILE_EXTENSION = ".xml";
// NOTE: several initializers below read fields declared above them
// (m_CEPanel <- m_ClassifierEditor, m_SaveOut <- m_Log,
// m_History <- m_OutText, COMBO_SIZE <- m_StartBut), so the declaration
// order of these fields must be preserved.
/** Lets the user configure the classifier. */
protected GenericObjectEditor m_ClassifierEditor = new GenericObjectEditor();
/** The panel showing the current classifier selection. */
protected PropertyPanel m_CEPanel = new PropertyPanel(m_ClassifierEditor);
/** The output area for classification results. */
protected JTextArea m_OutText = new JTextArea(20, 40);
/** The destination for log/status messages; defaults to a SysErrLog. */
protected Logger m_Log = new SysErrLog();
/**
* The buffer saving object for saving output. It is handed the logger held in
* m_Log at construction time.
*/
SaveBuffer m_SaveOut = new SaveBuffer(m_Log, this);
/** A panel controlling results viewing (displays into m_OutText). */
protected ResultHistoryPanel m_History = new ResultHistoryPanel(m_OutText);
/** Lets the user select the class column. */
protected JComboBox m_ClassCombo = new JComboBox();
/** Click to set test mode to cross-validation. */
protected JRadioButton m_CVBut = new JRadioButton("Cross-validation");
/** Click to set test mode to generate a % split. */
protected JRadioButton m_PercentBut = new JRadioButton("Percentage split");
/** Click to set test mode to test on training data. */
protected JRadioButton m_TrainBut = new JRadioButton("Use training set");
/** Click to set test mode to a user-specified test set. */
protected JRadioButton m_TestSplitBut = new JRadioButton("Supplied test set");
/**
* Check to save the predictions in the results list for visualizing later on.
*/
protected JCheckBox m_StorePredictionsBut = new JCheckBox(
"Store predictions for visualization");
/**
* Check to have the point size in error plots proportional to the prediction
* margin (classification only)
*/
protected JCheckBox m_errorPlotPointSizeProportionalToMargin = new JCheckBox(
"Error plot point size proportional to margin");
/** Check to output the model built from the training data. */
protected JCheckBox m_OutputModelBut = new JCheckBox("Output model");
/** Check to output true/false positives, precision/recall for each class. */
protected JCheckBox m_OutputPerClassBut = new JCheckBox(
"Output per-class stats");
/** Check to output a confusion matrix. */
protected JCheckBox m_OutputConfusionBut = new JCheckBox(
"Output confusion matrix");
/** Check to output entropy statistics. */
protected JCheckBox m_OutputEntropyBut = new JCheckBox(
"Output entropy evaluation measures");
/** Lets the user configure the ClassificationOutput. */
protected GenericObjectEditor m_ClassificationOutputEditor = new GenericObjectEditor(
true);
/** ClassificationOutput configuration. */
protected PropertyPanel m_ClassificationOutputPanel = new PropertyPanel(
m_ClassificationOutputEditor);
/** the range of attributes to output; null when none has been set. */
protected Range m_OutputAdditionalAttributesRange = null;
/** Check to evaluate w.r.t a cost matrix. */
protected JCheckBox m_EvalWRTCostsBut = new JCheckBox(
"Cost-sensitive evaluation");
/** Opens the cost matrix editor. */
protected JButton m_SetCostsBut = new JButton("Set...");
/** Label by where the cv folds are entered. */
protected JLabel m_CVLab = new JLabel("Folds", SwingConstants.RIGHT);
/** The field where the cv folds are entered. */
protected JTextField m_CVText = new JTextField("10", 3);
/** Label by where the % split is entered. */
protected JLabel m_PercentLab = new JLabel("%", SwingConstants.RIGHT);
/** The field where the % split is entered. */
protected JTextField m_PercentText = new JTextField("66", 3);
/** The button used to open a separate test dataset. */
protected JButton m_SetTestBut = new JButton("Set...");
/** The frame used to show the test set selection panel (created lazily). */
protected JFrame m_SetTestFrame;
/** The frame used to show the cost matrix editing panel (created lazily). */
protected PropertyDialog m_SetCostsFrame;
/**
* Alters the enabled/disabled status of elements associated with each radio
* button.
*/
ActionListener m_RadioListener = new ActionListener() {
@Override
public void actionPerformed(ActionEvent e) {
updateRadioLinks();
}
};
/** Button that opens the dialog with further output/visualize options. */
JButton m_MoreOptions = new JButton("More options...");
/** User specified random seed for cross validation or % split. */
protected JTextField m_RandomSeedText = new JTextField("1", 3);
/** the label for the random seed textfield. */
protected JLabel m_RandomLab = new JLabel("Random seed for XVal / % Split",
SwingConstants.RIGHT);
/** Whether randomization is turned off to preserve order. */
protected JCheckBox m_PreserveOrderBut = new JCheckBox(
"Preserve order for % Split");
/**
* Whether to output the source code (only for classifiers importing
* Sourcable).
*/
protected JCheckBox m_OutputSourceCode = new JCheckBox("Output source code");
/** The name of the generated class (only applicable to Sourcable schemes). */
protected JTextField m_SourceCodeClass = new JTextField("WekaClassifier", 10);
/** Click to start running the classifier. */
protected JButton m_StartBut = new JButton("Start");
/** Click to stop a running classifier. */
protected JButton m_StopBut = new JButton("Stop");
/**
* Stop the class combo from taking up too much space: a fixed width of 150
* and the preferred height of the Start button.
*/
private final Dimension COMBO_SIZE = new Dimension(150,
m_StartBut.getPreferredSize().height);
/** The cost matrix editor for evaluation costs. */
protected CostMatrixEditor m_CostMatrixEditor = new CostMatrixEditor();
/** The main set of instances we're playing with. */
protected Instances m_Instances;
/** The loader used to load the user-supplied test set (if any). */
protected Loader m_TestLoader;
/**
* the class index for the supplied test set; -1 until the test set panel
* reports one.
*/
protected int m_TestClassIndex = -1;
/** A thread that classification runs in; null is treated as "not running". */
protected Thread m_RunThread;
/** The current visualization object. */
protected VisualizePanel m_CurrentVis = null;
/** Filter to ensure only model files are selected. */
protected FileFilter m_ModelFilter = new ExtensionFileFilter(
MODEL_FILE_EXTENSION, "Model object files");
/** Filter to ensure only PMML model files are selected. */
protected FileFilter m_PMMLModelFilter = new ExtensionFileFilter(
PMML_FILE_EXTENSION, "PMML model files");
/**
* The file chooser for selecting model files; starts in the directory named
* by the "user.dir" system property.
*/
protected JFileChooser m_FileChooser = new JFileChooser(new File(
System.getProperty("user.dir")));
/**
* The user's list of selected evaluation metrics. Initially every metric name
* returned by Evaluation.getAllEvaluationMetricNames(); replaced when the user
* edits the selection in the evaluation options dialog.
*/
protected List m_selectedEvalMetrics = Evaluation
.getAllEvaluationMetricNames();
/*
* Register the property editors we need. Static initializer: runs once, when
* the class is initialized.
*/
static {
GenericObjectEditor.registerEditors();
}
/**
 * Creates the classifier panel.
 * <p>
 * Configures the output area and the classifier editor, initialises all
 * evaluation options from the Explorer defaults, wires up the listeners and
 * lays out the GUI.
 */
public ClassifierPanel() {
  // Connect / configure the components
  m_OutText.setEditable(false);
  m_OutText.setFont(new Font("Monospaced", Font.PLAIN, 12));
  m_OutText.setBorder(BorderFactory.createEmptyBorder(5, 5, 5, 5));
  m_OutText.addMouseListener(new MouseAdapter() {
    @Override
    public void mouseClicked(MouseEvent e) {
      // any click without button 1 selects the whole output text
      if ((e.getModifiers() & InputEvent.BUTTON1_MASK) != InputEvent.BUTTON1_MASK) {
        m_OutText.selectAll();
      }
    }
  });
  m_History.setBorder(BorderFactory
    .createTitledBorder("Result list (right-click for options)"));
  m_ClassifierEditor.setClassType(Classifier.class);
  m_ClassifierEditor.setValue(ExplorerDefaults.getClassifier());
  m_ClassifierEditor.addPropertyChangeListener(new PropertyChangeListener() {
    @Override
    public void propertyChange(PropertyChangeEvent e) {
      m_StartBut.setEnabled(true);
      // Check capabilities: disable "Start" when the chosen scheme cannot
      // handle data matching the current capabilities filter.
      Capabilities currentFilter = m_ClassifierEditor.getCapabilitiesFilter();
      Classifier classifier = (Classifier) m_ClassifierEditor.getValue();
      if (classifier != null && currentFilter != null
        && (classifier instanceof CapabilitiesHandler)) {
        Capabilities currentSchemeCapabilities = ((CapabilitiesHandler) classifier)
          .getCapabilities();
        if (!currentSchemeCapabilities.supportsMaybe(currentFilter)
          && !currentSchemeCapabilities.supports(currentFilter)) {
          m_StartBut.setEnabled(false);
        }
      }
      repaint();
    }
  });

  m_ClassCombo.setToolTipText("Select the attribute to use as the class");
  m_TrainBut.setToolTipText("Test on the same set that the classifier"
    + " is trained on");
  m_CVBut.setToolTipText("Perform a n-fold cross-validation");
  m_PercentBut.setToolTipText("Train on a percentage of the data and"
    + " test on the remainder");
  m_TestSplitBut.setToolTipText("Test on a user-specified dataset");
  m_StartBut.setToolTipText("Starts the classification");
  m_StopBut.setToolTipText("Stops a running classification");
  m_StorePredictionsBut
    .setToolTipText("Store predictions in the result list for later "
      + "visualization");
  m_errorPlotPointSizeProportionalToMargin
    .setToolTipText("In classifier errors plots the point size will be "
      + "set proportional to the absolute value of the "
      + "prediction margin (affects classification only)");
  m_OutputModelBut
    .setToolTipText("Output the model obtained from the full training set");
  m_OutputPerClassBut.setToolTipText("Output precision/recall & true/false"
    + " positives for each class");
  m_OutputConfusionBut
    .setToolTipText("Output the matrix displaying class confusions");
  m_OutputEntropyBut
    .setToolTipText("Output entropy-based evaluation measures");
  m_EvalWRTCostsBut
    .setToolTipText("Evaluate errors with respect to a cost matrix");
  m_RandomLab.setToolTipText("The seed value for randomization");
  m_RandomSeedText.setToolTipText(m_RandomLab.getToolTipText());
  m_PreserveOrderBut
    .setToolTipText("Preserves the order in a percentage split");
  m_OutputSourceCode
    .setToolTipText("Whether to output the built classifier as Java source code");
  m_SourceCodeClass.setToolTipText("The classname of the built classifier");

  m_FileChooser.addChoosableFileFilter(m_PMMLModelFilter);
  m_FileChooser.setFileFilter(m_ModelFilter);
  m_FileChooser.setFileSelectionMode(JFileChooser.FILES_ONLY);

  m_ClassificationOutputEditor.setClassType(AbstractOutput.class);
  m_ClassificationOutputEditor.setValue(new Null());

  // Initialise the evaluation options from the Explorer defaults
  m_StorePredictionsBut.setSelected(ExplorerDefaults
    .getClassifierStorePredictionsForVis());
  m_OutputModelBut.setSelected(ExplorerDefaults.getClassifierOutputModel());
  m_OutputPerClassBut.setSelected(ExplorerDefaults
    .getClassifierOutputPerClassStats());
  m_OutputConfusionBut.setSelected(ExplorerDefaults
    .getClassifierOutputConfusionMatrix());
  m_EvalWRTCostsBut.setSelected(ExplorerDefaults
    .getClassifierCostSensitiveEval());
  m_OutputEntropyBut.setSelected(ExplorerDefaults
    .getClassifierOutputEntropyEvalMeasures());
  m_RandomSeedText.setText("" + ExplorerDefaults.getClassifierRandomSeed());
  m_PreserveOrderBut.setSelected(ExplorerDefaults
    .getClassifierPreserveOrder());
  m_OutputSourceCode.addActionListener(new ActionListener() {
    @Override
    public void actionPerformed(ActionEvent e) {
      m_SourceCodeClass.setEnabled(m_OutputSourceCode.isSelected());
    }
  });
  m_OutputSourceCode.setSelected(ExplorerDefaults
    .getClassifierOutputSourceCode());
  m_SourceCodeClass.setText(ExplorerDefaults.getClassifierSourceCodeClass());
  m_SourceCodeClass.setEnabled(m_OutputSourceCode.isSelected());

  m_ClassCombo.setEnabled(false);
  m_ClassCombo.setPreferredSize(COMBO_SIZE);
  m_ClassCombo.setMaximumSize(COMBO_SIZE);
  m_ClassCombo.setMinimumSize(COMBO_SIZE);

  // Test mode codes (see "testMode" variable in startClassifier):
  // 1 = cross-validation, 2 = percentage split, 3 = training set,
  // 4 = supplied test set.
  int testMode = ExplorerDefaults.getClassifierTestMode();
  m_PercentBut.setSelected(testMode == 2);
  m_TrainBut.setSelected(testMode == 3);
  m_TestSplitBut.setSelected(testMode == 4);
  // Cross-validation is the fallback, so an unrecognised default still
  // leaves exactly one test option selected.
  m_CVBut.setSelected(testMode != 2 && testMode != 3 && testMode != 4);
  m_PercentText.setText("" + ExplorerDefaults.getClassifierPercentageSplit());
  m_CVText.setText("" + ExplorerDefaults.getClassifierCrossvalidationFolds());
  updateRadioLinks();
  ButtonGroup bg = new ButtonGroup();
  bg.add(m_TrainBut);
  bg.add(m_CVBut);
  bg.add(m_PercentBut);
  bg.add(m_TestSplitBut);
  m_TrainBut.addActionListener(m_RadioListener);
  m_CVBut.addActionListener(m_RadioListener);
  m_PercentBut.addActionListener(m_RadioListener);
  m_TestSplitBut.addActionListener(m_RadioListener);
  m_SetTestBut.addActionListener(new ActionListener() {
    @Override
    public void actionPerformed(ActionEvent e) {
      setTestSet();
    }
  });

  m_EvalWRTCostsBut.addActionListener(new ActionListener() {
    @Override
    public void actionPerformed(ActionEvent e) {
      m_SetCostsBut.setEnabled(m_EvalWRTCostsBut.isSelected());
      if ((m_SetCostsFrame != null) && (!m_EvalWRTCostsBut.isSelected())) {
        m_SetCostsFrame.setVisible(false);
      }
    }
  });
  m_CostMatrixEditor.setValue(new CostMatrix(1));
  m_SetCostsBut.setEnabled(m_EvalWRTCostsBut.isSelected());
  m_SetCostsBut.addActionListener(new ActionListener() {
    @Override
    public void actionPerformed(ActionEvent e) {
      openCostMatrixEditor();
    }
  });

  m_StartBut.setEnabled(false);
  m_StopBut.setEnabled(false);
  m_StartBut.addActionListener(new ActionListener() {
    @Override
    public void actionPerformed(ActionEvent e) {
      boolean proceed = true;
      if (Explorer.m_Memory.memoryIsLow()) {
        proceed = Explorer.m_Memory.showMemoryIsLow();
      }
      if (proceed) {
        startClassifier();
      }
    }
  });
  m_StopBut.addActionListener(new ActionListener() {
    @Override
    public void actionPerformed(ActionEvent e) {
      stopClassifier();
    }
  });

  m_ClassCombo.addActionListener(new ActionListener() {
    @Override
    public void actionPerformed(ActionEvent e) {
      int selected = m_ClassCombo.getSelectedIndex();
      if (selected != -1) {
        // per-class stats and confusion matrix only apply to nominal classes
        boolean isNominal = m_Instances.attribute(selected).isNominal();
        m_OutputPerClassBut.setEnabled(isNominal);
        m_OutputConfusionBut.setEnabled(isNominal);
      }
      updateCapabilitiesFilter(m_ClassifierEditor.getCapabilitiesFilter());
    }
  });

  m_History.setHandleRightClicks(false);
  // see if we can popup a menu for the selected result
  m_History.getList().addMouseListener(new MouseAdapter() {
    @Override
    public void mouseClicked(MouseEvent e) {
      if (((e.getModifiers() & InputEvent.BUTTON1_MASK) != InputEvent.BUTTON1_MASK)
        || e.isAltDown()) {
        int index = m_History.getList().locationToIndex(e.getPoint());
        if (index != -1) {
          String name = m_History.getNameAtIndex(index);
          visualize(name, e.getX(), e.getY());
        } else {
          visualize(null, e.getX(), e.getY());
        }
      }
    }
  });

  m_MoreOptions.addActionListener(new ActionListener() {
    @Override
    public void actionPerformed(ActionEvent e) {
      openMoreOptionsDialog();
    }
  });

  // Layout the GUI
  JPanel p1 = new JPanel();
  p1.setBorder(BorderFactory.createCompoundBorder(
    BorderFactory.createTitledBorder("Classifier"),
    BorderFactory.createEmptyBorder(0, 5, 5, 5)));
  p1.setLayout(new BorderLayout());
  p1.add(m_CEPanel, BorderLayout.NORTH);

  JPanel p2 = new JPanel();
  GridBagLayout gbL = new GridBagLayout();
  p2.setLayout(gbL);
  p2.setBorder(BorderFactory.createCompoundBorder(
    BorderFactory.createTitledBorder("Test options"),
    BorderFactory.createEmptyBorder(0, 5, 5, 5)));
  GridBagConstraints gbC = new GridBagConstraints();
  gbC.anchor = GridBagConstraints.WEST;
  gbC.gridy = 0;
  gbC.gridx = 0;
  gbL.setConstraints(m_TrainBut, gbC);
  p2.add(m_TrainBut);

  gbC = new GridBagConstraints();
  gbC.anchor = GridBagConstraints.WEST;
  gbC.gridy = 1;
  gbC.gridx = 0;
  gbL.setConstraints(m_TestSplitBut, gbC);
  p2.add(m_TestSplitBut);

  gbC = new GridBagConstraints();
  gbC.anchor = GridBagConstraints.EAST;
  gbC.fill = GridBagConstraints.HORIZONTAL;
  gbC.gridy = 1;
  gbC.gridx = 1;
  gbC.gridwidth = 2;
  gbC.insets = new Insets(2, 10, 2, 0);
  gbL.setConstraints(m_SetTestBut, gbC);
  p2.add(m_SetTestBut);

  gbC = new GridBagConstraints();
  gbC.anchor = GridBagConstraints.WEST;
  gbC.gridy = 2;
  gbC.gridx = 0;
  gbL.setConstraints(m_CVBut, gbC);
  p2.add(m_CVBut);

  gbC = new GridBagConstraints();
  gbC.anchor = GridBagConstraints.EAST;
  gbC.fill = GridBagConstraints.HORIZONTAL;
  gbC.gridy = 2;
  gbC.gridx = 1;
  gbC.insets = new Insets(2, 10, 2, 10);
  gbL.setConstraints(m_CVLab, gbC);
  p2.add(m_CVLab);

  gbC = new GridBagConstraints();
  gbC.anchor = GridBagConstraints.EAST;
  gbC.fill = GridBagConstraints.HORIZONTAL;
  gbC.gridy = 2;
  gbC.gridx = 2;
  gbC.weightx = 100;
  gbC.ipadx = 20;
  gbL.setConstraints(m_CVText, gbC);
  p2.add(m_CVText);

  gbC = new GridBagConstraints();
  gbC.anchor = GridBagConstraints.WEST;
  gbC.gridy = 3;
  gbC.gridx = 0;
  gbL.setConstraints(m_PercentBut, gbC);
  p2.add(m_PercentBut);

  gbC = new GridBagConstraints();
  gbC.anchor = GridBagConstraints.EAST;
  gbC.fill = GridBagConstraints.HORIZONTAL;
  gbC.gridy = 3;
  gbC.gridx = 1;
  gbC.insets = new Insets(2, 10, 2, 10);
  gbL.setConstraints(m_PercentLab, gbC);
  p2.add(m_PercentLab);

  gbC = new GridBagConstraints();
  gbC.anchor = GridBagConstraints.EAST;
  gbC.fill = GridBagConstraints.HORIZONTAL;
  gbC.gridy = 3;
  gbC.gridx = 2;
  gbC.weightx = 100;
  gbC.ipadx = 20;
  gbL.setConstraints(m_PercentText, gbC);
  p2.add(m_PercentText);

  gbC = new GridBagConstraints();
  gbC.anchor = GridBagConstraints.WEST;
  gbC.fill = GridBagConstraints.HORIZONTAL;
  gbC.gridy = 4;
  gbC.gridx = 0;
  gbC.weightx = 100;
  gbC.gridwidth = 3;
  gbC.insets = new Insets(3, 0, 1, 0);
  gbL.setConstraints(m_MoreOptions, gbC);
  p2.add(m_MoreOptions);

  // Any launcher plugins?
  JButton pluginBut = createLauncherPluginButton();

  JPanel buttons = new JPanel();
  buttons.setLayout(new GridLayout(2, 2));
  buttons.add(m_ClassCombo);
  m_ClassCombo.setBorder(BorderFactory.createEmptyBorder(5, 5, 5, 5));
  JPanel ssButs = new JPanel();
  ssButs.setBorder(BorderFactory.createEmptyBorder(5, 5, 5, 5));
  if (pluginBut == null) {
    ssButs.setLayout(new GridLayout(1, 2, 5, 5));
  } else {
    ssButs.setLayout(new FlowLayout(FlowLayout.LEFT));
  }
  ssButs.add(m_StartBut);
  ssButs.add(m_StopBut);
  if (pluginBut != null) {
    ssButs.add(pluginBut);
  }
  buttons.add(ssButs);

  JPanel p3 = new JPanel();
  p3.setBorder(BorderFactory.createTitledBorder("Classifier output"));
  p3.setLayout(new BorderLayout());
  final JScrollPane js = new JScrollPane(m_OutText);
  p3.add(js, BorderLayout.CENTER);
  // Scroll to the bottom whenever the output's height changes (new text was
  // appended), but leave the position alone when the user merely scrolls.
  js.getViewport().addChangeListener(new ChangeListener() {
    private int lastHeight;

    @Override
    public void stateChanged(ChangeEvent e) {
      JViewport vp = (JViewport) e.getSource();
      int h = vp.getViewSize().height;
      if (h != lastHeight) { // i.e. an addition not just a user scrolling
        lastHeight = h;
        int x = h - vp.getExtentSize().height;
        vp.setViewPosition(new Point(0, x));
      }
    }
  });

  JPanel mondo = new JPanel();
  gbL = new GridBagLayout();
  mondo.setLayout(gbL);
  gbC = new GridBagConstraints();
  gbC.fill = GridBagConstraints.HORIZONTAL;
  gbC.gridy = 0;
  gbC.gridx = 0;
  gbL.setConstraints(p2, gbC);
  mondo.add(p2);
  gbC = new GridBagConstraints();
  gbC.anchor = GridBagConstraints.NORTH;
  gbC.fill = GridBagConstraints.HORIZONTAL;
  gbC.gridy = 1;
  gbC.gridx = 0;
  gbL.setConstraints(buttons, gbC);
  mondo.add(buttons);
  gbC = new GridBagConstraints();
  gbC.fill = GridBagConstraints.BOTH;
  gbC.gridy = 2;
  gbC.gridx = 0;
  gbC.weightx = 0;
  gbL.setConstraints(m_History, gbC);
  mondo.add(m_History);
  gbC = new GridBagConstraints();
  gbC.fill = GridBagConstraints.BOTH;
  gbC.gridy = 0;
  gbC.gridx = 1;
  gbC.gridheight = 3;
  gbC.weightx = 100;
  gbC.weighty = 100;
  gbL.setConstraints(p3, gbC);
  mondo.add(p3);

  setLayout(new BorderLayout());
  add(p1, BorderLayout.NORTH);
  add(mondo, BorderLayout.CENTER);
}

/**
 * Shows the cost matrix editor, creating it on first use. Before it is shown,
 * the matrix is resized to the number of values of the selected class
 * attribute, if a dataset with a valid class selection is available.
 */
private void openCostMatrixEditor() {
  // "Set..." stays disabled while the editor is open; the window listener
  // below re-enables it when the editor window is closed.
  m_SetCostsBut.setEnabled(false);
  if (m_SetCostsFrame == null) {
    if (PropertyDialog.getParentDialog(ClassifierPanel.this) != null) {
      m_SetCostsFrame = new PropertyDialog(PropertyDialog
        .getParentDialog(ClassifierPanel.this), m_CostMatrixEditor, 100,
        100);
    } else {
      m_SetCostsFrame = new PropertyDialog(PropertyDialog
        .getParentFrame(ClassifierPanel.this), m_CostMatrixEditor, 100,
        100);
    }
    m_SetCostsFrame.setTitle("Cost Matrix Editor");
    m_SetCostsFrame.addWindowListener(new java.awt.event.WindowAdapter() {
      @Override
      public void windowClosing(java.awt.event.WindowEvent p) {
        m_SetCostsBut.setEnabled(m_EvalWRTCostsBut.isSelected());
        if ((m_SetCostsFrame != null) && (!m_EvalWRTCostsBut.isSelected())) {
          m_SetCostsFrame.setVisible(false);
        }
      }
    });
  }
  // Do we need to change the size of the matrix? Only possible once a
  // dataset is loaded and a class attribute is selected; previously this
  // threw a NullPointerException when no data was loaded yet.
  int classIndex = m_ClassCombo.getSelectedIndex();
  if (m_Instances != null && classIndex >= 0
    && classIndex < m_Instances.numAttributes()) {
    int numClasses = m_Instances.attribute(classIndex).numValues();
    if (numClasses != ((CostMatrix) m_CostMatrixEditor.getValue())
      .numColumns()) {
      m_CostMatrixEditor.setValue(new CostMatrix(numClasses));
    }
  }
  // Show the editor once, after the matrix has its final size.
  m_SetCostsFrame.setVisible(true);
}

/**
 * Opens the "Classifier evaluation options" dialog. The "More options..."
 * button stays disabled until the dialog is closed via "OK" or the window's
 * close control.
 */
private void openMoreOptionsDialog() {
  m_MoreOptions.setEnabled(false);
  JPanel moreOptionsPanel = new JPanel();
  moreOptionsPanel.setBorder(BorderFactory.createEmptyBorder(0, 5, 5, 5));
  moreOptionsPanel.setLayout(new GridLayout(0, 1));
  moreOptionsPanel.add(m_OutputModelBut);
  moreOptionsPanel.add(m_OutputPerClassBut);
  moreOptionsPanel.add(m_OutputEntropyBut);
  moreOptionsPanel.add(m_OutputConfusionBut);
  moreOptionsPanel.add(m_StorePredictionsBut);
  moreOptionsPanel.add(m_errorPlotPointSizeProportionalToMargin);
  JPanel classOutPanel = new JPanel(new FlowLayout(FlowLayout.LEFT));
  classOutPanel.add(new JLabel("Output predictions"));
  classOutPanel.add(m_ClassificationOutputPanel);
  moreOptionsPanel.add(classOutPanel);
  JPanel costMatrixOption = new JPanel(new FlowLayout(FlowLayout.LEFT));
  costMatrixOption.add(m_EvalWRTCostsBut);
  costMatrixOption.add(m_SetCostsBut);
  moreOptionsPanel.add(costMatrixOption);
  JPanel seedPanel = new JPanel(new FlowLayout(FlowLayout.LEFT));
  seedPanel.add(m_RandomLab);
  seedPanel.add(m_RandomSeedText);
  moreOptionsPanel.add(seedPanel);
  moreOptionsPanel.add(m_PreserveOrderBut);
  JPanel sourcePanel = new JPanel(new FlowLayout(FlowLayout.LEFT));
  // source code output is only offered for Sourcable schemes
  m_OutputSourceCode.setEnabled(m_ClassifierEditor.getValue() instanceof Sourcable);
  m_SourceCodeClass.setEnabled(m_OutputSourceCode.isEnabled()
    && m_OutputSourceCode.isSelected());
  sourcePanel.add(m_OutputSourceCode);
  sourcePanel.add(m_SourceCodeClass);
  moreOptionsPanel.add(sourcePanel);

  JPanel all = new JPanel();
  all.setLayout(new BorderLayout());
  JButton oK = new JButton("OK");
  JPanel okP = new JPanel();
  okP.setBorder(BorderFactory.createEmptyBorder(5, 5, 5, 5));
  okP.setLayout(new GridLayout(1, 1, 5, 5));
  okP.add(oK);
  all.add(moreOptionsPanel, BorderLayout.CENTER);
  all.add(okP, BorderLayout.SOUTH);

  final JDialog jd = new JDialog(PropertyDialog
    .getParentFrame(ClassifierPanel.this), "Classifier evaluation options");
  jd.getContentPane().setLayout(new BorderLayout());
  jd.getContentPane().add(all, BorderLayout.CENTER);
  jd.addWindowListener(new java.awt.event.WindowAdapter() {
    @Override
    public void windowClosing(java.awt.event.WindowEvent w) {
      jd.dispose();
      m_MoreOptions.setEnabled(true);
    }
  });
  oK.addActionListener(new ActionListener() {
    @Override
    public void actionPerformed(ActionEvent a) {
      m_MoreOptions.setEnabled(true);
      jd.dispose();
    }
  });

  // The evaluation-metrics row is added before the dialog is packed, so the
  // packed size includes it (previously it was added after the last pack).
  final JButton editEvalMetrics = new JButton("Evaluation metrics...");
  JPanel evalP = new JPanel();
  evalP.setLayout(new BorderLayout());
  evalP.setBorder(BorderFactory.createEmptyBorder(5, 5, 5, 5));
  evalP.add(editEvalMetrics, BorderLayout.CENTER);
  editEvalMetrics
    .setToolTipText("Enable/disable output of specific evaluation metrics");
  moreOptionsPanel.add(evalP);
  editEvalMetrics.addActionListener(new ActionListener() {
    @Override
    public void actionPerformed(ActionEvent e) {
      EvaluationMetricSelectionDialog esd = new EvaluationMetricSelectionDialog(
        jd, m_selectedEvalMetrics);
      esd.setLocation(m_MoreOptions.getLocationOnScreen());
      esd.pack();
      esd.setVisible(true);
      m_selectedEvalMetrics = esd.getSelectedEvalMetrics();
    }
  });

  jd.pack();
  // panel height is only available now
  m_ClassificationOutputPanel.setPreferredSize(new Dimension(300,
    m_ClassificationOutputPanel.getHeight()));
  jd.pack();
  jd.setLocation(m_MoreOptions.getLocationOnScreen());
  jd.setVisible(true);
}

/**
 * Creates the button for any registered ClassifierPanelLaunchHandlerPlugin
 * implementations. A single plugin gets a button labelled with its launch
 * command; several plugins share a "Launchers..." button that opens a popup
 * menu. Plugins that cannot be instantiated are reported and skipped.
 *
 * @return the launcher button, or null if no plugin could be loaded
 */
private JButton createLauncherPluginButton() {
  Vector<String> pluginsVector = GenericObjectEditor
    .getClassnames(ClassifierPanelLaunchHandlerPlugin.class.getName());

  if (pluginsVector.size() == 1) {
    try {
      // Display as a single button
      final ClassifierPanelLaunchHandlerPlugin plugin = (ClassifierPanelLaunchHandlerPlugin) Class
        .forName(pluginsVector.elementAt(0)).getDeclaredConstructor()
        .newInstance();
      plugin.setClassifierPanel(this);
      JButton pluginBut = new JButton(plugin.getLaunchCommand());
      pluginBut.addActionListener(new ActionListener() {
        @Override
        public void actionPerformed(ActionEvent e) {
          plugin.launch();
        }
      });
      return pluginBut;
    } catch (Exception ex) {
      ex.printStackTrace();
      return null;
    }
  }

  if (pluginsVector.size() > 1) {
    // make a popup menu
    int okPluginCount = 0;
    final java.awt.PopupMenu pluginPopup = new java.awt.PopupMenu();
    for (String className : pluginsVector) {
      try {
        final ClassifierPanelLaunchHandlerPlugin plugin = (ClassifierPanelLaunchHandlerPlugin) Class
          .forName(className).getDeclaredConstructor().newInstance();
        plugin.setClassifierPanel(this);
        java.awt.MenuItem popI = new java.awt.MenuItem(
          plugin.getLaunchCommand());
        popI.addActionListener(new ActionListener() {
          @Override
          public void actionPerformed(ActionEvent e) {
            plugin.launch();
          }
        });
        pluginPopup.add(popI);
        // only count plugins that actually made it into the menu
        okPluginCount++;
      } catch (Exception ex) {
        ex.printStackTrace();
      }
    }
    if (okPluginCount > 0) {
      final JButton pluginBut = new JButton("Launchers...");
      pluginBut.add(pluginPopup);
      pluginBut.addActionListener(new ActionListener() {
        @Override
        public void actionPerformed(ActionEvent e) {
          pluginPopup.show(pluginBut, 0, 0);
        }
      });
      return pluginBut;
    }
  }
  return null;
}
/**
 * Enables only the inputs that belong to the currently selected test option:
 * the "Set..." button for a supplied test set, the folds field and label for
 * cross-validation, and the percentage field and label for a percentage
 * split. Hides the test-set window when "Supplied test set" is not selected.
 */
protected void updateRadioLinks() {
  boolean suppliedTestSet = m_TestSplitBut.isSelected();
  m_SetTestBut.setEnabled(suppliedTestSet);
  if (!suppliedTestSet && m_SetTestFrame != null) {
    m_SetTestFrame.setVisible(false);
  }

  boolean crossValidation = m_CVBut.isSelected();
  m_CVText.setEnabled(crossValidation);
  m_CVLab.setEnabled(crossValidation);

  boolean percentageSplit = m_PercentBut.isSelected();
  m_PercentText.setEnabled(percentageSplit);
  m_PercentLab.setEnabled(percentageSplit);
}
/**
 * Sets the Logger to receive informational messages.
 * <p>
 * The save buffer is rebuilt as well. It is handed a Logger when it is
 * constructed, so without this it would keep using the panel's initial
 * default logger instead of the new one.
 *
 * @param newLog the Logger that will now get info messages
 */
@Override
public void setLog(Logger newLog) {
  m_Log = newLog;
  // keep the save buffer pointed at the same logger as the panel
  m_SaveOut = new SaveBuffer(m_Log, this);
}
/**
 * Tells the panel to use a new set of instances.
 * <p>
 * Fills the class combo box with one "(type) name" entry per attribute and
 * selects the dataset's class attribute, or the last attribute if no class is
 * set. With a dataset that has no attributes, the combo box and "Start" are
 * disabled. In every case "Stop" reflects whether a classification run is
 * still active, so a running job can always be stopped.
 *
 * @param inst a set of Instances
 */
@Override
public void setInstances(Instances inst) {
  m_Instances = inst;
  int numAttributes = m_Instances.numAttributes();
  String[] attribNames = new String[numAttributes];
  for (int i = 0; i < numAttributes; i++) {
    Attribute att = m_Instances.attribute(i);
    attribNames[i] = "(" + Attribute.typeToStringShort(att) + ") "
      + att.name();
  }
  m_ClassCombo.setModel(new DefaultComboBoxModel(attribNames));

  if (numAttributes > 0) {
    int classIndex = inst.classIndex();
    m_ClassCombo.setSelectedIndex(classIndex == -1 ? numAttributes - 1
      : classIndex);
    m_ClassCombo.setEnabled(true);
    m_StartBut.setEnabled(m_RunThread == null);
  } else {
    // nothing to choose a class from: same state as a freshly built panel
    m_ClassCombo.setEnabled(false);
    m_StartBut.setEnabled(false);
  }
  // never disable "Stop" while a run is in progress
  m_StopBut.setEnabled(m_RunThread != null);
}
/**
 * Shows the window for choosing a user-supplied test set.
 * <p>
 * The window is created lazily on the first call. It hosts a
 * {@code SetInstancesPanel} that shares the Preprocess panel's file chooser.
 * If a test loader already exists, the panel shows that loader's structure.
 * Whenever the panel fires a property change, its loader and class index are
 * copied into {@code m_TestLoader} and {@code m_TestClassIndex}. Later calls
 * only make the existing window visible again.
 */
protected void setTestSet() {
  if (m_SetTestFrame == null) {
    final SetInstancesPanel testPanel = new SetInstancesPanel(true, true,
      m_Explorer.getPreprocessPanel().m_FileChooser);

    // display the structure of a previously chosen test set, if any
    if (m_TestLoader != null) {
      try {
        if (m_TestLoader.getStructure() != null) {
          testPanel.setInstances(m_TestLoader.getStructure());
        }
      } catch (Exception ex) {
        ex.printStackTrace();
      }
    }

    // keep m_TestLoader / m_TestClassIndex in sync with whatever the user
    // selects in the test-set window
    testPanel.addPropertyChangeListener(new PropertyChangeListener() {
      @Override
      public void propertyChange(PropertyChangeEvent evt) {
        m_TestLoader = testPanel.getLoader();
        m_TestClassIndex = testPanel.getClassIndex();
      }
    });

    JFrame frame = new JFrame("Test Instances");
    m_SetTestFrame = frame;
    testPanel.setParentFrame(frame); // enables the panel's Close button
    frame.getContentPane().setLayout(new BorderLayout());
    frame.getContentPane().add(testPanel, BorderLayout.CENTER);
    frame.pack();
  }
  m_SetTestFrame.setVisible(true);
}
/**
 * Writes the banner that introduces a block of predictions, then lets the
 * output generator print its own header.
 *
 * @param outBuff the buffer the banner line is appended to
 * @param classificationOutput the prediction output generator; its
 *          {@code printHeader()} is always called, but the banner is only
 *          written when {@code generatesOutput()} returns true
 * @param title the name of the predicted data, inserted into the banner
 */
protected void printPredictionsHeader(StringBuffer outBuff,
  AbstractOutput classificationOutput, String title) {
  boolean producesText = classificationOutput.generatesOutput();
  if (producesText) {
    outBuff.append("=== Predictions on ").append(title).append(" ===\n\n");
  }
  classificationOutput.printHeader();
}
/**
 * Prepares an {@link Evaluation}, and optionally the plot instances, for the
 * given classifier and data.
 * <p>
 * <b>InputMappedClassifier.</b> The classifier's model header is obtained for
 * {@code inst} and handed to {@code classificationOutput} (when non-null).
 * Unless {@code onlySetPriors} is set, a fresh Evaluation is built on that
 * header, using {@code costMatrix} if one is given. If the evaluation's header
 * then differs from {@code inst}, every instance is mapped through the
 * classifier. The mapped copy is then used for the priors and the plot data.
 * <p>
 * <b>All other cases.</b> The priors and plot data come straight from
 * {@code inst}.
 *
 * @param eval the evaluation to configure; replaced by a new one for an
 *          InputMappedClassifier unless {@code onlySetPriors} is true
 * @param classifier the classifier being evaluated
 * @param inst the data whose class distribution provides the priors
 * @param costMatrix optional cost matrix (may be null)
 * @param plotInstances the plot data to configure (may be null)
 * @param classificationOutput optional prediction output (may be null)
 * @param onlySetPriors if true, the supplied evaluation is kept and
 *          {@code plotInstances} is left untouched
 * @return the evaluation that should be used from now on
 * @throws Exception if creating the evaluation or mapping the instances fails
 */
protected static Evaluation setupEval(Evaluation eval, Classifier classifier,
  Instances inst, CostMatrix costMatrix,
  ClassifierErrorsPlotInstances plotInstances,
  AbstractOutput classificationOutput, boolean onlySetPriors)
  throws Exception {

  // the data the priors (and plot) are computed from; only replaced when an
  // InputMappedClassifier needs the instances mapped to its model structure
  Instances priorData = inst;

  if (classifier instanceof weka.classifiers.misc.InputMappedClassifier) {
    weka.classifiers.misc.InputMappedClassifier mappedClassifier =
      (weka.classifiers.misc.InputMappedClassifier) classifier;
    Instances modelHeader =
      mappedClassifier.getModelHeader(new Instances(inst, 0));

    if (classificationOutput != null) {
      classificationOutput.setHeader(modelHeader);
    }
    if (!onlySetPriors) {
      eval = (costMatrix != null)
        ? new Evaluation(new Instances(modelHeader, 0), costMatrix)
        : new Evaluation(new Instances(modelHeader, 0));
    }

    if (!eval.getHeader().equalHeaders(inst)) {
      // When the InputMappedClassifier is loading a model, the training
      // instances must be mapped to the structure the mapped classifier
      // expects. This is only so that the structure and priors held by the
      // evaluation object are correct with respect to that classifier.
      priorData =
        mappedClassifier.getModelHeader(new Instances(modelHeader, 0));
      for (int i = 0; i < inst.numInstances(); i++) {
        priorData.add(mappedClassifier.constructMappedInstance(inst.instance(i)));
      }
    }
  }

  eval.setPriors(priorData);
  if (!onlySetPriors && plotInstances != null) {
    plotInstances.setInstances(priorData);
    plotInstances.setClassifier(classifier);
    plotInstances.setClassIndex(priorData.classIndex());
    plotInstances.setEvaluation(eval);
  }
  return eval;
}
/**
* Starts running the currently configured classifier with the current
* settings. This is run in a separate thread, and will only start if there is
* no classifier already running. The classifier output is sent to the results
* history panel.
*/
protected void startClassifier() {
if (m_RunThread == null) {
synchronized (this) {
m_StartBut.setEnabled(false);
m_StopBut.setEnabled(true);
}
m_RunThread = new Thread() {
@Override
public void run() {
m_CEPanel.addToHistory();
// Copy the current state of things
m_Log.statusMessage("Setting up...");
CostMatrix costMatrix = null;
Instances inst = new Instances(m_Instances);
DataSource source = null;
Instances userTestStructure = null;
ClassifierErrorsPlotInstances plotInstances = null;
// for timing
long trainTimeStart = 0, trainTimeElapsed = 0;
long testTimeStart = 0, testTimeElapsed = 0;
try {
if (m_TestLoader != null && m_TestLoader.getStructure() != null) {
m_TestLoader.reset();
source = new DataSource(m_TestLoader);
userTestStructure = source.getStructure();
userTestStructure.setClassIndex(m_TestClassIndex);
}
} catch (Exception ex) {
ex.printStackTrace();
}
if (m_EvalWRTCostsBut.isSelected()) {
costMatrix = new CostMatrix(
(CostMatrix) m_CostMatrixEditor.getValue());
}
boolean outputModel = m_OutputModelBut.isSelected();
boolean outputConfusion = m_OutputConfusionBut.isSelected();
boolean outputPerClass = m_OutputPerClassBut.isSelected();
boolean outputSummary = true;
boolean outputEntropy = m_OutputEntropyBut.isSelected();
boolean saveVis = m_StorePredictionsBut.isSelected();
boolean outputPredictionsText = (m_ClassificationOutputEditor
.getValue().getClass() != Null.class);
String grph = null;
int testMode = 0;
int numFolds = 10;
double percent = 66;
int classIndex = m_ClassCombo.getSelectedIndex();
inst.setClassIndex(classIndex);
Classifier classifier = (Classifier) m_ClassifierEditor.getValue();
Classifier template = null;
try {
template = AbstractClassifier.makeCopy(classifier);
} catch (Exception ex) {
m_Log.logMessage("Problem copying classifier: " + ex.getMessage());
}
Classifier fullClassifier = null;
StringBuffer outBuff = new StringBuffer();
AbstractOutput classificationOutput = null;
if (outputPredictionsText) {
classificationOutput = (AbstractOutput) m_ClassificationOutputEditor
.getValue();
Instances header = new Instances(inst, 0);
header.setClassIndex(classIndex);
classificationOutput.setHeader(header);
classificationOutput.setBuffer(outBuff);
}
String name = (new SimpleDateFormat("HH:mm:ss - "))
.format(new Date());
String cname = "";
String cmd = "";
Evaluation eval = null;
try {
if (m_CVBut.isSelected()) {
testMode = 1;
numFolds = Integer.parseInt(m_CVText.getText());
if (numFolds <= 1) {
throw new Exception("Number of folds must be greater than 1");
}
} else if (m_PercentBut.isSelected()) {
testMode = 2;
percent = Double.parseDouble(m_PercentText.getText());
if ((percent <= 0) || (percent >= 100)) {
throw new Exception("Percentage must be between 0 and 100");
}
} else if (m_TrainBut.isSelected()) {
testMode = 3;
} else if (m_TestSplitBut.isSelected()) {
testMode = 4;
// Check the test instance compatibility
if (source == null) {
throw new Exception("No user test set has been specified");
}
if (!(classifier instanceof weka.classifiers.misc.InputMappedClassifier)) {
if (!inst.equalHeaders(userTestStructure)) {
boolean wrapClassifier = false;
if (!Utils
.getDontShowDialog("weka.gui.explorer.ClassifierPanel.AutoWrapInInputMappedClassifier")) {
JCheckBox dontShow = new JCheckBox(
"Do not show this message again");
Object[] stuff = new Object[2];
stuff[0] = "Train and test set are not compatible.\n"
+ "Would you like to automatically wrap the classifier in\n"
+ "an \"InputMappedClassifier\" before proceeding?.\n";
stuff[1] = dontShow;
int result = JOptionPane.showConfirmDialog(
ClassifierPanel.this, stuff, "ClassifierPanel",
JOptionPane.YES_OPTION);
if (result == JOptionPane.YES_OPTION) {
wrapClassifier = true;
}
if (dontShow.isSelected()) {
String response = (wrapClassifier) ? "yes" : "no";
Utils
.setDontShowDialogResponse(
"weka.gui.explorer.ClassifierPanel.AutoWrapInInputMappedClassifier",
response);
}
} else {
// What did the user say - do they want to autowrap or not?
String response = Utils
.getDontShowDialogResponse("weka.gui.explorer.ClassifierPanel.AutoWrapInInputMappedClassifier");
if (response != null && response.equalsIgnoreCase("yes")) {
wrapClassifier = true;
}
}
if (wrapClassifier) {
weka.classifiers.misc.InputMappedClassifier temp = new weka.classifiers.misc.InputMappedClassifier();
// pass on the known test structure so that we get the
// correct mapping report from the toString() method
// of InputMappedClassifier
temp.setClassifier(classifier);
temp.setTestStructure(userTestStructure);
classifier = temp;
} else {
throw new Exception(
"Train and test set are not compatible\n"
+ inst.equalHeadersMsg(userTestStructure));
}
}
}
} else {
throw new Exception("Unknown test mode");
}
cname = classifier.getClass().getName();
if (cname.startsWith("weka.classifiers.")) {
name += cname.substring("weka.classifiers.".length());
} else {
name += cname;
}
cmd = classifier.getClass().getName();
if (classifier instanceof OptionHandler) {
cmd += " "
+ Utils.joinOptions(((OptionHandler) classifier).getOptions());
}
// set up the structure of the plottable instances for
// visualization
plotInstances = ExplorerDefaults.getClassifierErrorsPlotInstances();
plotInstances.setInstances(inst);
plotInstances.setClassifier(classifier);
plotInstances.setClassIndex(inst.classIndex());
plotInstances.setSaveForVisualization(saveVis);
plotInstances
.setPointSizeProportionalToMargin(m_errorPlotPointSizeProportionalToMargin
.isSelected());
// Output some header information
m_Log.logMessage("Started " + cname);
m_Log.logMessage("Command: " + cmd);
if (m_Log instanceof TaskLogger) {
((TaskLogger) m_Log).taskStarted();
}
outBuff.append("=== Run information ===\n\n");
outBuff.append("Scheme: " + cname);
if (classifier instanceof OptionHandler) {
String[] o = ((OptionHandler) classifier).getOptions();
outBuff.append(" " + Utils.joinOptions(o));
}
outBuff.append("\n");
outBuff.append("Relation: " + inst.relationName() + '\n');
outBuff.append("Instances: " + inst.numInstances() + '\n');
outBuff.append("Attributes: " + inst.numAttributes() + '\n');
if (inst.numAttributes() < 100) {
for (int i = 0; i < inst.numAttributes(); i++) {
outBuff.append(" " + inst.attribute(i).name()
+ '\n');
}
} else {
outBuff.append(" [list of attributes omitted]\n");
}
outBuff.append("Test mode: ");
switch (testMode) {
case 3: // Test on training
outBuff.append("evaluate on training data\n");
break;
case 1: // CV mode
outBuff.append("" + numFolds + "-fold cross-validation\n");
break;
case 2: // Percent split
outBuff.append("split " + percent + "% train, remainder test\n");
break;
case 4: // Test on user split
if (source.isIncremental()) {
outBuff.append("user supplied test set: "
+ " size unknown (reading incrementally)\n");
} else {
outBuff.append("user supplied test set: "
+ source.getDataSet().numInstances() + " instances\n");
}
break;
}
if (costMatrix != null) {
outBuff.append("Evaluation cost matrix:\n")
.append(costMatrix.toString()).append("\n");
}
outBuff.append("\n");
m_History.addResult(name, outBuff);
m_History.setSingle(name);
// Build the model and output it.
if (outputModel || (testMode == 3) || (testMode == 4)) {
m_Log.statusMessage("Building model on training data...");
trainTimeStart = System.currentTimeMillis();
classifier.buildClassifier(inst);
trainTimeElapsed = System.currentTimeMillis() - trainTimeStart;
}
if (outputModel) {
outBuff
.append("=== Classifier model (full training set) ===\n\n");
outBuff.append(classifier.toString() + "\n");
outBuff.append("\nTime taken to build model: "
+ Utils.doubleToString(trainTimeElapsed / 1000.0, 2)
+ " seconds\n\n");
m_History.updateResult(name);
if (classifier instanceof Drawable) {
grph = null;
try {
grph = ((Drawable) classifier).graph();
} catch (Exception ex) {
}
}
// copy full model for output
SerializedObject so = new SerializedObject(classifier);
fullClassifier = (Classifier) so.getObject();
}
switch (testMode) {
case 3: // Test on training
m_Log.statusMessage("Evaluating on training data...");
eval = new Evaluation(inst, costMatrix);
// make adjustments if the classifier is an InputMappedClassifier
eval = setupEval(eval, classifier, inst, costMatrix,
plotInstances, classificationOutput, false);
eval.setMetricsToDisplay(m_selectedEvalMetrics);
// plotInstances.setEvaluation(eval);
plotInstances.setUp();
if (outputPredictionsText) {
printPredictionsHeader(outBuff, classificationOutput,
"training set");
}
testTimeStart = System.currentTimeMillis();
if (classifier instanceof BatchPredictor) {
Instances toPred = new Instances(inst);
for (int i = 0; i < toPred.numInstances(); i++) {
toPred.instance(i).setClassMissing();
}
double[][] predictions = ((BatchPredictor) classifier)
.distributionsForInstances(toPred);
plotInstances.process(inst, predictions, eval);
if (outputPredictionsText) {
for (int jj = 0; jj < inst.numInstances(); jj++) {
classificationOutput.printClassification(predictions[jj],
inst.instance(jj), jj);
}
}
} else {
for (int jj = 0; jj < inst.numInstances(); jj++) {
plotInstances.process(inst.instance(jj), classifier, eval);
if (outputPredictionsText) {
classificationOutput.printClassification(classifier,
inst.instance(jj), jj);
}
if ((jj % 100) == 0) {
m_Log
.statusMessage("Evaluating on training data. Processed "
+ jj + " instances...");
}
}
}
testTimeElapsed = System.currentTimeMillis() - testTimeStart;
if (outputPredictionsText) {
classificationOutput.printFooter();
}
if (outputPredictionsText
&& classificationOutput.generatesOutput()) {
outBuff.append("\n");
}
outBuff.append("=== Evaluation on training set ===\n");
break;
case 1: // CV mode
m_Log.statusMessage("Randomizing instances...");
int rnd = 1;
try {
rnd = Integer.parseInt(m_RandomSeedText.getText().trim());
// System.err.println("Using random seed "+rnd);
} catch (Exception ex) {
m_Log.logMessage("Trouble parsing random seed value");
rnd = 1;
}
Random random = new Random(rnd);
inst.randomize(random);
if (inst.attribute(classIndex).isNominal()) {
m_Log.statusMessage("Stratifying instances...");
inst.stratify(numFolds);
}
eval = new Evaluation(inst, costMatrix);
// make adjustments if the classifier is an InputMappedClassifier
eval = setupEval(eval, classifier, inst, costMatrix,
plotInstances, classificationOutput, false);
eval.setMetricsToDisplay(m_selectedEvalMetrics);
// plotInstances.setEvaluation(eval);
plotInstances.setUp();
if (outputPredictionsText) {
printPredictionsHeader(outBuff, classificationOutput,
"test data");
}
// Make some splits and do a CV
for (int fold = 0; fold < numFolds; fold++) {
m_Log.statusMessage("Creating splits for fold " + (fold + 1)
+ "...");
Instances train = inst.trainCV(numFolds, fold, random);
// make adjustments if the classifier is an
// InputMappedClassifier
eval = setupEval(eval, classifier, train, costMatrix,
plotInstances, classificationOutput, true);
eval.setMetricsToDisplay(m_selectedEvalMetrics);
// eval.setPriors(train);
m_Log.statusMessage("Building model for fold " + (fold + 1)
+ "...");
Classifier current = null;
try {
current = AbstractClassifier.makeCopy(template);
} catch (Exception ex) {
m_Log.logMessage("Problem copying classifier: "
+ ex.getMessage());
}
current.buildClassifier(train);
Instances test = inst.testCV(numFolds, fold);
m_Log.statusMessage("Evaluating model for fold " + (fold + 1)
+ "...");
if (classifier instanceof BatchPredictor) {
Instances toPred = new Instances(test);
for (int i = 0; i < toPred.numInstances(); i++) {
toPred.instance(i).setClassMissing();
}
double[][] predictions = ((BatchPredictor) current)
.distributionsForInstances(toPred);
plotInstances.process(test, predictions, eval);
if (outputPredictionsText) {
for (int jj = 0; jj < test.numInstances(); jj++) {
classificationOutput.printClassification(predictions[jj],
test.instance(jj), jj);
}
}
} else {
for (int jj = 0; jj < test.numInstances(); jj++) {
plotInstances.process(test.instance(jj), current, eval);
if (outputPredictionsText) {
classificationOutput.printClassification(current,
test.instance(jj), jj);
}
}
}
}
if (outputPredictionsText) {
classificationOutput.printFooter();
}
if (outputPredictionsText) {
outBuff.append("\n");
}
if (inst.attribute(classIndex).isNominal()) {
outBuff.append("=== Stratified cross-validation ===\n");
} else {
outBuff.append("=== Cross-validation ===\n");
}
break;
case 2: // Percent split
if (!m_PreserveOrderBut.isSelected()) {
m_Log.statusMessage("Randomizing instances...");
try {
rnd = Integer.parseInt(m_RandomSeedText.getText().trim());
} catch (Exception ex) {
m_Log.logMessage("Trouble parsing random seed value");
rnd = 1;
}
inst.randomize(new Random(rnd));
}
int trainSize = (int) Math.round(inst.numInstances() * percent
/ 100);
int testSize = inst.numInstances() - trainSize;
Instances train = new Instances(inst, 0, trainSize);
Instances test = new Instances(inst, trainSize, testSize);
m_Log.statusMessage("Building model on training split ("
+ trainSize + " instances)...");
Classifier current = null;
try {
current = AbstractClassifier.makeCopy(template);
} catch (Exception ex) {
m_Log.logMessage("Problem copying classifier: "
+ ex.getMessage());
}
current.buildClassifier(train);
eval = new Evaluation(train, costMatrix);
// make adjustments if the classifier is an InputMappedClassifier
eval = setupEval(eval, classifier, train, costMatrix,
plotInstances, classificationOutput, false);
eval.setMetricsToDisplay(m_selectedEvalMetrics);
// plotInstances.setEvaluation(eval);
plotInstances.setUp();
m_Log.statusMessage("Evaluating on test split...");
if (outputPredictionsText) {
printPredictionsHeader(outBuff, classificationOutput,
"test split");
}
testTimeStart = System.currentTimeMillis();
if (classifier instanceof BatchPredictor) {
Instances toPred = new Instances(test);
for (int i = 0; i < toPred.numInstances(); i++) {
toPred.instance(i).setClassMissing();
}
double[][] predictions = ((BatchPredictor) current)
.distributionsForInstances(toPred);
plotInstances.process(test, predictions, eval);
if (outputPredictionsText) {
for (int jj = 0; jj < test.numInstances(); jj++) {
classificationOutput.printClassification(predictions[jj],
test.instance(jj), jj);
}
}
} else {
for (int jj = 0; jj < test.numInstances(); jj++) {
plotInstances.process(test.instance(jj), current, eval);
if (outputPredictionsText) {
classificationOutput.printClassification(current,
test.instance(jj), jj);
}
if ((jj % 100) == 0) {
m_Log.statusMessage("Evaluating on test split. Processed "
+ jj + " instances...");
}
}
}
testTimeElapsed = System.currentTimeMillis() - testTimeStart;
if (outputPredictionsText) {
classificationOutput.printFooter();
}
if (outputPredictionsText) {
outBuff.append("\n");
}
outBuff.append("=== Evaluation on test split ===\n");
break;
case 4: // Test on user split
m_Log.statusMessage("Evaluating on test data...");
eval = new Evaluation(inst, costMatrix);
// make adjustments if the classifier is an InputMappedClassifier
eval = setupEval(eval, classifier, inst, costMatrix,
plotInstances, classificationOutput, false);
eval.setMetricsToDisplay(m_selectedEvalMetrics);
// plotInstances.setEvaluation(eval);
plotInstances.setUp();
if (outputPredictionsText) {
printPredictionsHeader(outBuff, classificationOutput,
"test set");
}
Instance instance;
int jj = 0;
Instances batchInst = null;
int batchSize = 100;
if (classifier instanceof BatchPredictor) {
batchInst = new Instances(userTestStructure, 0);
String batchSizeS = ((BatchPredictor) classifier)
.getBatchSize();
if (batchSizeS != null && batchSizeS.length() > 0) {
try {
batchSizeS = Environment.getSystemWide().substitute(
batchSizeS);
} catch (Exception ex) {
}
try {
batchSize = Integer.parseInt(batchSizeS);
} catch (NumberFormatException ex) {
// just go with the default
}
}
}
testTimeStart = System.currentTimeMillis();
while (source.hasMoreElements(userTestStructure)) {
instance = source.nextElement(userTestStructure);
if (classifier instanceof BatchPredictor) {
batchInst.add(instance);
if (batchInst.numInstances() == batchSize) {
Instances toPred = new Instances(batchInst);
for (int i = 0; i < toPred.numInstances(); i++) {
toPred.instance(i).setClassMissing();
}
double[][] predictions = ((BatchPredictor) classifier)
.distributionsForInstances(toPred);
plotInstances.process(batchInst, predictions, eval);
if (outputPredictionsText) {
for (int kk = 0; kk < batchInst.numInstances(); kk++) {
classificationOutput.printClassification(
predictions[kk], batchInst.instance(kk), kk);
}
}
jj += batchInst.numInstances();
m_Log.statusMessage("Evaluating on test data. Processed "
+ jj + " instances...");
batchInst.delete();
}
} else {
plotInstances.process(instance, classifier, eval);
if (outputPredictionsText) {
classificationOutput.printClassification(classifier,
instance, jj);
}
if ((++jj % 100) == 0) {
m_Log.statusMessage("Evaluating on test data. Processed "
+ jj + " instances...");
}
}
}
if (classifier instanceof BatchPredictor
&& batchInst.numInstances() > 0) {
// finish the last batch
Instances toPred = new Instances(batchInst);
for (int i = 0; i < toPred.numInstances(); i++) {
toPred.instance(i).setClassMissing();
}
double[][] predictions = ((BatchPredictor) classifier)
.distributionsForInstances(toPred);
plotInstances.process(batchInst, predictions, eval);
if (outputPredictionsText) {
for (int kk = 0; kk < batchInst.numInstances(); kk++) {
classificationOutput.printClassification(predictions[kk],
batchInst.instance(kk), kk);
}
}
}
testTimeElapsed = System.currentTimeMillis() - testTimeStart;
if (outputPredictionsText) {
classificationOutput.printFooter();
}
if (outputPredictionsText) {
outBuff.append("\n");
}
outBuff.append("=== Evaluation on test set ===\n");
break;
default:
throw new Exception("Test mode not implemented");
}
if (testMode != 1) {
String mode = "";
if (testMode == 2) {
mode = "training split";
} else if (testMode == 3) {
mode = "training data";
} else if (testMode == 4) {
mode = "supplied test set";
}
outBuff.append("\nTime taken to test model on " + mode + ": "
+ Utils.doubleToString(testTimeElapsed / 1000.0, 2)
+ " seconds\n\n");
}
if (outputSummary) {
outBuff.append(eval.toSummaryString(outputEntropy) + "\n");
}
if (inst.attribute(classIndex).isNominal()) {
if (outputPerClass) {
outBuff.append(eval.toClassDetailsString() + "\n");
}
if (outputConfusion) {
outBuff.append(eval.toMatrixString() + "\n");
}
}
if ((fullClassifier instanceof Sourcable)
&& m_OutputSourceCode.isSelected()) {
outBuff.append("=== Source code ===\n\n");
outBuff.append(Evaluation.wekaStaticWrapper(
((Sourcable) fullClassifier), m_SourceCodeClass.getText()));
}
m_History.updateResult(name);
m_Log.logMessage("Finished " + cname);
m_Log.statusMessage("OK");
} catch (Exception ex) {
ex.printStackTrace();
m_Log.logMessage(ex.getMessage());
JOptionPane.showMessageDialog(ClassifierPanel.this,
"Problem evaluating classifier:\n" + ex.getMessage(),
"Evaluate classifier", JOptionPane.ERROR_MESSAGE);
m_Log.statusMessage("Problem evaluating classifier");
} finally {
try {
if (!saveVis && outputModel) {
ArrayList
© 2015 - 2024 Weber Informatics LLC | Privacy Policy