weka.gui.beans.SerializedModelSaver Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of weka-stable Show documentation
Show all versions of weka-stable Show documentation
The Waikato Environment for Knowledge Analysis (WEKA), a machine
learning workbench. This is the stable version. Apart from bugfixes, this version
does not receive any other updates.
/*
* This program is free software; you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation; either version 2 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program; if not, write to the Free Software
* Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
*/
/*
* SerializedModelSaver.java
* Copyright (C) 2008 University of Waikato, Hamilton, New Zealand
*
*/
package weka.gui.beans;
import java.awt.BorderLayout;
import java.beans.EventSetDescriptor;
import java.io.BufferedOutputStream;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Vector;
import javax.swing.JPanel;
import weka.core.Environment;
import weka.core.EnvironmentHandler;
import weka.core.Instances;
import weka.core.Tag;
import weka.core.Utils;
import weka.core.xml.KOML;
import weka.core.xml.XStream;
/**
 * A bean that saves serialized models.
 * <p>
 * The bean listens for batch classifier, incremental classifier and batch
 * clusterer events and writes the model carried by the event (plus the
 * header of the data, when one is available) to a file in the configured
 * directory. The file format is selected via a {@link Tag} taken from
 * {@link #s_fileFormatsAvailable}.
 *
 * @author Mark Hall (mhall{[at]}pentaho{[dot]}org)
 * @version $Revision: 10534 $
 */
public class SerializedModelSaver extends JPanel implements BeanCommon,
  Visible, BatchClassifierListener, IncrementalClassifierListener,
  BatchClustererListener, EnvironmentHandler, Serializable {

  /** for serialization */
  private static final long serialVersionUID = 3956528599473814287L;

  /**
   * Default visual for this bean. The text set here is replaced by
   * {@link #useDefaultVisual()}, which the constructor calls.
   */
  protected BeanVisual m_visual = new BeanVisual("AbstractDataSink",
    BeanVisual.ICON_PATH + "SerializedModelSaver.gif", BeanVisual.ICON_PATH
      + "SerializedModelSaver_animated.gif");

  /**
   * Non null if this object is a target for any events. Provides for the
   * simplest case when only one incoming connection is allowed.
   */
  protected Object m_listenee = null;

  /**
   * The log for this bean (transient, so it is null after de-serialization
   * until {@link #setLog(weka.gui.Logger)} is called).
   */
  protected transient weka.gui.Logger m_logger = null;

  /**
   * The prefix for the file name (model + training set info will be appended)
   */
  private String m_filenamePrefix = "";

  /**
   * The directory to hold the saved model(s)
   */
  private File m_directory = new File(System.getProperty("user.dir"));

  /**
   * The currently selected file format.
   */
  private Tag m_fileFormat;

  /** ID of the binary (Java serialization) file format */
  public final static int BINARY = 0;

  /** ID of the KOML XML file format */
  public final static int KOMLV = 1;

  /** ID of the XStream XML file format */
  public final static int XSTREAM = 2;

  /** the extension for serialized models (binary Java serialization) */
  public final static String FILE_EXTENSION = "model";

  /**
   * relative path for the directory (relative to the user.dir (startup
   * directory))?
   */
  private boolean m_useRelativePath = false;

  /**
   * Available file formats. The binary format is always the first entry; the
   * KOML and XStream entries are only added when the corresponding
   * {@code isPresent()} check succeeds at class-initialization time.
   */
  public static ArrayList<Tag> s_fileFormatsAvailable;

  static {
    s_fileFormatsAvailable = new ArrayList<Tag>();
    s_fileFormatsAvailable.add(new Tag(BINARY, Messages.getInstance()
      .getString("SerializedModelSaver_FileFormatsAvailable_Text_First")
      + FILE_EXTENSION
      + Messages.getInstance().getString(
        "SerializedModelSaver_FileFormatsAvailable_Text_Second"), "", false));
    if (KOML.isPresent()) {
      s_fileFormatsAvailable.add(new Tag(KOMLV, Messages.getInstance()
        .getString("SerializedModelSaver_FileFormatsAvailable_Text_Third")
        + KOML.FILE_EXTENSION
        + FILE_EXTENSION
        + Messages.getInstance().getString(
          "SerializedModelSaver_FileFormatsAvailable_Text_Fourth"), "", false));
    }
    if (XStream.isPresent()) {
      s_fileFormatsAvailable.add(new Tag(XSTREAM, Messages.getInstance()
        .getString("SerializedModelSaver_FileFormatsAvailable_Text_Fifth")
        + XStream.FILE_EXTENSION
        + FILE_EXTENSION
        + Messages.getInstance().getString(
          "SerializedModelSaver_FileFormatsAvailable_Text_Sixth"), "", false));
    }
  }

  /**
   * The environment variables (transient; re-initialised to the system-wide
   * environment on de-serialization, see
   * {@link #readObject(ObjectInputStream)}).
   */
  protected transient Environment m_env;

  /**
   * Constructor. Installs the default visual, selects the first available
   * file format (binary) and uses the system-wide environment.
   */
  public SerializedModelSaver() {
    useDefaultVisual();
    setLayout(new BorderLayout());
    add(m_visual, BorderLayout.CENTER);
    m_fileFormat = s_fileFormatsAvailable.get(0);
    m_env = Environment.getSystemWide();
  }

  /**
   * Set a custom (descriptive) name for this bean
   *
   * @param name the name to use
   */
  @Override
  public void setCustomName(String name) {
    m_visual.setText(name);
  }

  /**
   * Get the custom (descriptive) name for this bean (if one has been set)
   *
   * @return the custom name (or the default name)
   */
  @Override
  public String getCustomName() {
    return m_visual.getText();
  }

  /**
   * Use the default images and the default label for this bean.
   */
  @Override
  public void useDefaultVisual() {
    m_visual.loadIcons(BeanVisual.ICON_PATH + "SerializedModelSaver.gif",
      BeanVisual.ICON_PATH + "SerializedModelSaver_animated.gif");
    m_visual.setText("SerializedModelSaver");
  }

  /**
   * Set the visual for this bean.
   *
   * @param newVisual the {@code BeanVisual} to use
   */
  @Override
  public void setVisual(BeanVisual newVisual) {
    m_visual = newVisual;
  }

  /**
   * Get the visual being used by this bean.
   *
   * @return the current {@code BeanVisual}
   */
  @Override
  public BeanVisual getVisual() {
    return m_visual;
  }

  /**
   * Returns true if, at this time, the object will accept a connection
   * according to the supplied EventSetDescriptor.
   *
   * @param esd the EventSetDescriptor
   * @return true if the object will accept a connection
   */
  @Override
  public boolean connectionAllowed(EventSetDescriptor esd) {
    return connectionAllowed(esd.getName());
  }

  /**
   * Returns true if, at this time, the object will accept a connection
   * according to the supplied event name. The event name itself is not
   * inspected: a connection is allowed exactly when no source is currently
   * registered.
   *
   * @param eventName the event
   * @return true if the object will accept a connection
   */
  @Override
  public boolean connectionAllowed(String eventName) {
    return (m_listenee == null);
  }

  /**
   * Notify this object that it has been registered as a listener with a source
   * with respect to the supplied event name. The source is only recorded when
   * a connection is currently allowed.
   *
   * @param eventName the event
   * @param source the source with which this object has been registered as a
   *          listener
   */
  @Override
  public synchronized void connectionNotification(String eventName,
    Object source) {
    if (connectionAllowed(eventName)) {
      m_listenee = source;
    }
  }

  /**
   * Notify this object that it has been deregistered as a listener with a
   * source with respect to the supplied event name. The recorded source is
   * only cleared when it is the very object passed in.
   *
   * @param eventName the event
   * @param source the source with which this object has been registered as a
   *          listener
   */
  @Override
  public synchronized void disconnectionNotification(String eventName,
    Object source) {
    if (m_listenee == source) {
      m_listenee = null;
    }
  }

  /**
   * Set a log for this bean.
   *
   * @param logger the {@code weka.gui.Logger} to use
   */
  @Override
  public void setLog(weka.gui.Logger logger) {
    m_logger = logger;
  }

  /**
   * Stop any processing that the bean might be doing.
   */
  @Override
  public void stop() {
    // tell the listenee (upstream bean) to stop
    if (m_listenee instanceof BeanCommon) {
      ((BeanCommon) m_listenee).stop();
    }
  }

  /**
   * Returns true if, at this time, the bean is busy with some task (i.e.
   * perhaps a worker thread is performing some calculation). This
   * implementation always reports not busy.
   *
   * @return true if the bean is busy.
   */
  @Override
  public boolean isBusy() {
    return false;
  }

  /**
   * makes sure that the filename is valid, i.e., replaces slashes, backslashes
   * and colons with underscores ("_").
   *
   * @param filename the filename to cleanse
   * @return the cleansed filename
   */
  protected String sanitizeFilename(String filename) {
    return filename.replaceAll("\\\\", "_").replaceAll(":", "_")
      .replaceAll("/", "_");
  }

  /**
   * Returns the class name of the supplied object without its package part.
   *
   * @param o the object whose class name is wanted
   * @return the simple part of {@code o.getClass().getName()}
   */
  private static String shortClassName(Object o) {
    String fullName = o.getClass().getName();
    return fullName.substring(fullName.lastIndexOf('.') + 1);
  }

  /**
   * Joins a directory name and a file name into an absolute path.
   *
   * @param dirName the (already environment-substituted) directory
   * @param fileName the file name to place inside the directory
   * @return the absolute path of the directory, the platform separator and
   *         the file name
   */
  private static String inDirectory(String dirName, String fileName) {
    return new File(dirName).getAbsolutePath() + File.separator + fileName;
  }

  /**
   * Substitutes environment variables in a string. When the substitution
   * fails, the problem is reported (to the logger if one is set, otherwise to
   * System.err) and null is returned so that the caller can abandon the save.
   *
   * @param raw the string that may contain environment variables
   * @param logPrefixKey message key for the text preceding the status prefix
   * @param logSuffixKey message key for the text following the status prefix
   * @param statusKey message key for the status line text
   * @return the substituted string, or null if substitution failed
   */
  private String substituteOrReport(String raw, String logPrefixKey,
    String logSuffixKey, String statusKey) {
    try {
      return m_env.substitute(raw);
    } catch (Exception ex) {
      String message = Messages.getInstance().getString(logPrefixKey)
        + statusMessagePrefix()
        + Messages.getInstance().getString(logSuffixKey) + ex.getMessage();
      if (m_logger != null) {
        m_logger.logMessage(message);
        m_logger.statusMessage(statusMessagePrefix()
          + Messages.getInstance().getString(statusKey));
      } else {
        System.err.println(message);
      }
      return null;
    }
  }

  /**
   * Accept and save a batch trained clusterer. Nothing is saved when the
   * event carries no data set, is flagged as a TEST event, or carries
   * structure only.
   *
   * @param ce the BatchClustererEvent containing the clusterer
   */
  @Override
  public void acceptClusterer(BatchClustererEvent ce) {
    if (ce.getTestSet() == null
      || ce.getTestOrTrain() == BatchClustererEvent.TEST
      || ce.getTestSet().isStructureOnly()) {
      return;
    }

    Instances trainHeader = ce.getTestSet().getDataSet().stringFreeStructure();
    String titleString = shortClassName(ce.getClusterer());

    String prefix = substituteOrReport(m_filenamePrefix,
      "SerializedModelSaver_AcceptClusterer_Message_Text_First",
      "SerializedModelSaver_AcceptClusterer_Message_Text_Second",
      "SerializedModelSaver_AcceptClusterer_StatusMessage_Text_First");
    if (prefix == null) {
      return; // substitution failed and has been reported
    }

    String fileName = sanitizeFilename("" + prefix + titleString + "_"
      + ce.getSetNumber() + "_" + ce.getMaxSetNumber());

    String dirName = substituteOrReport(m_directory.getPath(),
      "SerializedModelSaver_AcceptClusterer_Message_Text_Third",
      "SerializedModelSaver_AcceptClusterer_Message_Text_Fourth",
      "SerializedModelSaver_AcceptClusterer_StatusMessage_Text_Second");
    if (dirName == null) {
      return; // substitution failed and has been reported
    }

    saveModel(inDirectory(dirName, fileName), trainHeader, ce.getClusterer());
  }

  /**
   * Accept and save an incrementally trained classifier. The model is only
   * saved when the event's status is BATCH_FINISHED.
   *
   * @param ce the IncrementalClassifierEvent containing the classifier
   */
  @Override
  public void acceptClassifier(final IncrementalClassifierEvent ce) {
    if (ce.getStatus() != IncrementalClassifierEvent.BATCH_FINISHED) {
      // Only save model when the end of the stream is reached
      return;
    }

    Instances header = ce.getStructure();
    String titleString = shortClassName(ce.getClassifier());

    String prefix = substituteOrReport(m_filenamePrefix,
      "SerializedModelSaver_AcceptClassifier_Message_Text_First",
      "SerializedModelSaver_AcceptClassifier_Message_Text_Second",
      "SerializedModelSaver_AcceptClassifier_StatusMessage_Text_First");
    if (prefix == null) {
      return; // substitution failed and has been reported
    }

    String fileName = sanitizeFilename("" + prefix + titleString);

    String dirName = substituteOrReport(m_directory.getPath(),
      "SerializedModelSaver_AcceptClassifier_Message_Text_Third",
      "SerializedModelSaver_AcceptClassifier_Message_Text_Fourth",
      "SerializedModelSaver_AcceptClassifier_StatusMessage_Text_Second");
    if (dirName == null) {
      return; // substitution failed and has been reported
    }

    saveModel(inDirectory(dirName, fileName), header, ce.getClassifier());
  }

  /**
   * Accept and save a batch trained classifier. Nothing is saved when the
   * event carries no training set or the training set is structure only.
   *
   * @param ce the BatchClassifierEvent containing the classifier
   */
  @Override
  public void acceptClassifier(final BatchClassifierEvent ce) {
    if (ce.getTrainSet() == null || ce.getTrainSet().isStructureOnly()) {
      return;
    }

    // capacity 0: keep only the header of the training data
    Instances trainHeader = new Instances(ce.getTrainSet().getDataSet(), 0);
    String titleString = shortClassName(ce.getClassifier());

    String prefix = substituteOrReport(m_filenamePrefix,
      "SerializedModelSaver_AcceptClassifier_Message_Text_Fifth",
      "SerializedModelSaver_AcceptClassifier_Message_Text_Sixth",
      "SerializedModelSaver_AcceptClassifier_StatusMessage_Text_Third");
    if (prefix == null) {
      return; // substitution failed and has been reported
    }

    String fileName = sanitizeFilename("" + prefix + titleString + "_"
      + ce.getSetNumber() + "_" + ce.getMaxSetNumber());

    String dirName = substituteOrReport(m_directory.getPath(),
      "SerializedModelSaver_AcceptClassifier_Message_Text_Seventh",
      "SerializedModelSaver_AcceptClassifier_Message_Text_Eighth",
      "SerializedModelSaver_AcceptClassifier_StatusMessage_Text_Fourth");
    if (dirName == null) {
      return; // substitution failed and has been reported
    }

    saveModel(inDirectory(dirName, fileName), trainHeader, ce.getClassifier());
  }

  /**
   * Helper routine to actually save the models. The format-specific extension
   * is appended to the supplied file name. Failures are reported rather than
   * propagated.
   *
   * @param fileName the path of the file to write, without extension
   * @param trainHeader the header to store alongside the model (may be null)
   * @param model the model to save
   */
  private void saveModel(String fileName, Instances trainHeader, Object model) {
    m_fileFormat = validateFileFormat(m_fileFormat);
    if (m_fileFormat == null) {
      // default to binary if validation fails
      m_fileFormat = s_fileFormatsAvailable.get(0);
    }

    try {
      switch (m_fileFormat.getID()) {
      case KOMLV:
        fileName = fileName + KOML.FILE_EXTENSION + FILE_EXTENSION;
        saveKOML(new File(fileName), model, trainHeader);
        break;
      case XSTREAM:
        fileName = fileName + XStream.FILE_EXTENSION + FILE_EXTENSION;
        saveXStream(new File(fileName), model, trainHeader);
        break;
      default:
        fileName = fileName + "." + FILE_EXTENSION;
        saveBinary(new File(fileName), model, trainHeader);
        break;
      }
    } catch (Exception ex) {
      System.err.println(Messages.getInstance().getString(
        "SerializedModelSaver_SaveModel_Error_Text"));
      if (m_logger != null) {
        m_logger.logMessage(Messages.getInstance().getString(
          "SerializedModelSaver_SaveModel_LogMessage_Text_First")
          + statusMessagePrefix()
          + Messages.getInstance().getString(
            "SerializedModelSaver_SaveModel_LogMessage_Text_Second")
          + " : "
          + ex.getMessage());
        m_logger.statusMessage(statusMessagePrefix()
          + Messages.getInstance().getString(
            "SerializedModelSaver_SaveModel_StatusMessage_Text"));
      }
    }
  }

  /**
   * Save a model in binary form. The model is written first, followed by the
   * header when one is supplied. The underlying file is closed even when
   * writing fails.
   *
   * @param saveTo the file name to save to
   * @param model the model to save
   * @param header the header of the data that was used to train the model
   *          (optional)
   * @throws IOException if the file cannot be opened or written
   */
  public static void saveBinary(File saveTo, Object model, Instances header)
    throws IOException {
    FileOutputStream fileOut = new FileOutputStream(saveTo);
    try {
      ObjectOutputStream os = new ObjectOutputStream(new BufferedOutputStream(
        fileOut));
      os.writeObject(model);
      // now the header
      if (header != null) {
        os.writeObject(header);
      }
      // closing here (rather than only in finally) surfaces flush failures
      os.close();
    } finally {
      // releases the file handle if anything above threw; harmless otherwise
      fileOut.close();
    }
  }

  /**
   * Save a model in KOML deep object serialized XML form. The model and, if
   * supplied, the header are placed in a Vector which is then written.
   *
   * @param saveTo the file name to save to
   * @param model the model to save
   * @param header the header of the data that was used to train the model
   *          (optional)
   * @throws Exception if writing fails
   */
  public static void saveKOML(File saveTo, Object model, Instances header)
    throws Exception {
    Vector<Object> v = new Vector<Object>();
    v.add(model);
    if (header != null) {
      v.add(header);
    }
    v.trimToSize();
    KOML.write(saveTo.getAbsolutePath(), v);
  }

  /**
   * Save a model in XStream deep object serialized XML form. The model and,
   * if supplied, the header are placed in a Vector which is then written.
   *
   * @param saveTo the file name to save to
   * @param model the model to save
   * @param header the header of the data that was used to train the model
   *          (optional)
   * @throws Exception if writing fails
   */
  public static void saveXStream(File saveTo, Object model, Instances header)
    throws Exception {
    Vector<Object> v = new Vector<Object>();
    v.add(model);
    if (header != null) {
      v.add(header);
    }
    v.trimToSize();
    XStream.write(saveTo.getAbsolutePath(), v);
  }

  /**
   * Get the directory that the model(s) will be saved into
   *
   * @return the directory to save to
   */
  public File getDirectory() {
    return m_directory;
  }

  /**
   * Set the directory that the model(s) will be saved into. When relative
   * paths are enabled, an attempt is made to convert the directory to a
   * relative path; if that fails the directory is kept as supplied.
   *
   * @param d the directory to save to
   */
  public void setDirectory(File d) {
    m_directory = d;
    if (m_useRelativePath) {
      try {
        m_directory = Utils.convertToRelativePath(m_directory);
      } catch (Exception ignored) {
        // deliberate best effort: m_directory still holds the path as
        // supplied, which remains usable for saving
      }
    }
  }

  /**
   * Set whether to use relative paths for the directory. I.e. relative to the
   * startup (user.dir) directory. Only affects subsequent calls to
   * {@link #setDirectory(File)}.
   *
   * @param rp true if relative paths are to be used
   */
  public void setUseRelativePath(boolean rp) {
    m_useRelativePath = rp;
  }

  /**
   * Get whether to use relative paths for the directory. I.e. relative to the
   * startup (user.dir) directory
   *
   * @return true if relative paths are to be used
   */
  public boolean getUseRelativePath() {
    return m_useRelativePath;
  }

  /**
   * Get the prefix to prepend to the model file names.
   *
   * @return the prefix to prepend
   */
  public String getPrefix() {
    return m_filenamePrefix;
  }

  /**
   * Set the prefix to prepend to the model file names.
   *
   * @param p the prefix to prepend
   */
  public void setPrefix(String p) {
    m_filenamePrefix = p;
  }

  /**
   * Global info for this bean. Gets displayed in the GUI.
   *
   * @return information about this bean.
   */
  public String globalInfo() {
    return Messages.getInstance().getString(
      "SerializedModelSaver_GlobalInfo_Text");
  }

  /**
   * Set the file format to use for saving.
   *
   * @param ff the file format to use
   */
  public void setFileFormat(Tag ff) {
    m_fileFormat = ff;
  }

  /**
   * Get the file format to use for saving.
   *
   * @return the file format to use
   */
  public Tag getFileFormat() {
    return m_fileFormat;
  }

  /**
   * Validate the file format. After this bean is deserialized, classes for XML
   * serialization may not be in the classpath any more.
   *
   * @param ff the current file format to validate (may be null)
   * @return the supplied format, or null if it is null or refers to an XML
   *         format whose {@code isPresent()} check fails
   */
  public Tag validateFileFormat(Tag ff) {
    if (ff == null) {
      // nothing to validate; callers treat null as "fall back to binary"
      return null;
    }

    int id = ff.getID();
    if (id == KOMLV && !KOML.isPresent()) {
      return null;
    }
    if (id == XSTREAM && !XStream.isPresent()) {
      return null;
    }
    return ff;
  }

  /**
   * Builds the prefix used for log and status messages of this bean.
   *
   * @return the custom name, a "$", this object's hash code and a "|"
   */
  private String statusMessagePrefix() {
    return getCustomName() + "$" + hashCode() + "|";
  }

  /**
   * Set environment variables to use.
   *
   * @param env the environment variables to use
   */
  @Override
  public void setEnvironment(Environment env) {
    m_env = env;
  }

  /**
   * Custom de-serialization in order to set default environment variables on
   * de-serialization (the environment field is transient).
   *
   * @param aStream the stream to read this object from
   * @throws IOException if reading fails
   * @throws ClassNotFoundException if a class of a serialized object cannot
   *           be found
   */
  private void readObject(ObjectInputStream aStream) throws IOException,
    ClassNotFoundException {
    aStream.defaultReadObject();

    // set a default environment to use
    m_env = Environment.getSystemWide();
  }
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy