All Downloads are FREE. Search and download functionalities are using the official Maven repository.

edu.cmu.tetradapp.model.SemEstimatorWrapper Maven / Gradle / Ivy

The newest version!
///////////////////////////////////////////////////////////////////////////////
// For information as to what this class does, see the Javadoc, below.       //
// Copyright (C) 1998, 1999, 2000, 2001, 2002, 2003, 2004, 2005, 2006,       //
// 2007, 2008, 2009, 2010, 2014, 2015, 2022 by Peter Spirtes, Richard        //
// Scheines, Joseph Ramsey, and Clark Glymour.                               //
//                                                                           //
// This program is free software; you can redistribute it and/or modify      //
// it under the terms of the GNU General Public License as published by      //
// the Free Software Foundation; either version 2 of the License, or         //
// (at your option) any later version.                                       //
//                                                                           //
// This program is distributed in the hope that it will be useful,           //
// but WITHOUT ANY WARRANTY; without even the implied warranty of            //
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the             //
// GNU General Public License for more details.                              //
//                                                                           //
// You should have received a copy of the GNU General Public License         //
// along with this program; if not, write to the Free Software               //
// Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA //
///////////////////////////////////////////////////////////////////////////////
package edu.cmu.tetradapp.model;

import edu.cmu.tetrad.data.*;
import edu.cmu.tetrad.graph.Dag;
import edu.cmu.tetrad.graph.Graph;
import edu.cmu.tetrad.graph.Node;
import edu.cmu.tetrad.graph.NodeType;
import edu.cmu.tetrad.sem.*;
import edu.cmu.tetrad.util.*;
import edu.cmu.tetradapp.session.SessionModel;

import javax.swing.*;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serial;
import java.util.LinkedList;
import java.util.List;

/**
 * Wraps a SemEstimator for use in the Tetrad application.
 *
 * @author josephramsey
 * @version $Id: $Id
 */
public class SemEstimatorWrapper implements SessionModel {
    @Serial
    private static final long serialVersionUID = 23L;

    /**
     * The parameters for the estimator.
     */
    private final Parameters params;

    /**
     * The SEM PM for the estimator.
     */
    private final SemPm semPm;

    /**
     * The name of the estimator.
     */
    private String name;
    /**
     * The estimator itself.
     */
    private SemEstimator semEstimator;

    //==============================CONSTRUCTORS==========================//

    /**
     * Private constructor for serialization only. Problem is, for the real constructors, I'd like to call the degrees
     * of freedom check, which pops up a dialog. This is irritating when running unit tests. jdramsey 8/29/07
     *
     * @param dataModel a {@link edu.cmu.tetrad.data.DataModel} object
     * @param semPm     a {@link edu.cmu.tetrad.sem.SemPm} object
     * @param params    a {@link edu.cmu.tetrad.util.Parameters} object
     */
    public SemEstimatorWrapper(DataModel dataModel, SemPm semPm, Parameters params) {
        this.params = params;
        this.semPm = semPm;

        if (dataModel instanceof DataSet dataSet) {
            SemEstimator estimator = new SemEstimator(dataSet, semPm, getOptimizer());
            estimator.setNumRestarts(getParams().getInt("numRestarts", 1));
            estimator.setScoreType((ScoreType) getParams().get("scoreType", ScoreType.Fgls));
            estimator.estimate();
            if (!degreesOfFreedomCheck(semPm)) {
                throw new IllegalArgumentException("Cannot proceed.");
            }
            this.semEstimator = estimator;
        } else if (dataModel instanceof ICovarianceMatrix) {
            ICovarianceMatrix covMatrix = new CovarianceMatrix((ICovarianceMatrix) dataModel);
            SemEstimator estimator = new SemEstimator(covMatrix, semPm, getOptimizer());
            estimator.setNumRestarts(getParams().getInt("numRestarts", 1));
            estimator.setScoreType((ScoreType) getParams().get("scoreType", ScoreType.Fml));
            estimator.estimate();
            if (!degreesOfFreedomCheck(semPm)) {
                throw new IllegalArgumentException("Cannot proceed.");
            }
            this.semEstimator = estimator;
        } else {
            throw new IllegalArgumentException("Data must consist of continuous data sets or covariance matrices.");
        }
    }

    /**
     * 

Constructor for SemEstimatorWrapper.

* * @param dataWrapper a {@link edu.cmu.tetradapp.model.DataWrapper} object * @param semPmWrapper a {@link edu.cmu.tetradapp.model.SemPmWrapper} object * @param params a {@link edu.cmu.tetrad.util.Parameters} object */ public SemEstimatorWrapper(DataWrapper dataWrapper, SemPmWrapper semPmWrapper, Parameters params) { this(dataWrapper.getSelectedDataModel(), semPmWrapper.getSemPm(), params); log(); } /** * Constructs a SemEstimatorWrapper object. * * @param simulation a Simulation object * @param semPmWrapper a SemPmWrapper object * @param parameters a Parameters object */ public SemEstimatorWrapper(Simulation simulation, SemPmWrapper semPmWrapper, Parameters parameters) { this(new DataWrapper(simulation, parameters), semPmWrapper, parameters); } /** * Generates a simple exemplar of this class to test serialization. * * @return a {@link edu.cmu.tetradapp.model.SemEstimatorWrapper} object * @see TetradSerializableUtils */ public static SemEstimatorWrapper serializableInstance() { List variables = new LinkedList<>(); ContinuousVariable x = new ContinuousVariable("X"); variables.add(x); DataSet dataSet = new BoxDataSet(new VerticalDoubleDataBox(10, variables.size()), variables); for (int i = 0; i < dataSet.getNumRows(); i++) { for (int j = 0; j < dataSet.getNumColumns(); j++) { dataSet.setDouble(i, j, RandomUtil.getInstance().nextDouble()); } } Dag dag = new Dag(); dag.addNode(x); SemPm pm = new SemPm(dag); Parameters params1 = new Parameters(); return new SemEstimatorWrapper(dataSet, pm, params1); } private static boolean containsCovarParam(SemPm semPm) { boolean containsCovarParam = false; List params = semPm.getParameters(); for (Parameter param : params) { if (param.getType() == ParamType.COVAR) { containsCovarParam = true; break; } } return containsCovarParam; } private boolean degreesOfFreedomCheck(SemPm semPm) { if (semPm.getDof() < 1) { int ret = JOptionPane.showConfirmDialog(JOptionUtils.centeringComp(), "This model has non-positive degrees of freedom (DOF = " + semPm.getDof() + "). " + "\nEstimation will be uninformative. Are you sure you want to proceed?", "Please confirm", JOptionPane.YES_NO_OPTION, JOptionPane.WARNING_MESSAGE); return ret == JOptionPane.YES_OPTION; } return true; } //============================PUBLIC METHODS=========================// /** *

Getter for the field semEstimator.

* * @return a {@link edu.cmu.tetrad.sem.SemEstimator} object */ public SemEstimator getSemEstimator() { return this.semEstimator; } /** *

Setter for the field semEstimator.

* * @param semEstimator a {@link edu.cmu.tetrad.sem.SemEstimator} object */ public void setSemEstimator(SemEstimator semEstimator) { this.semEstimator = semEstimator; } /** *

getEstimatedSemIm.

* * @return a {@link edu.cmu.tetrad.sem.SemIm} object */ public SemIm getEstimatedSemIm() { return this.semEstimator.getEstimatedSem(); } /** *

getSemOptimizerType.

* * @return a {@link java.lang.String} object */ public String getSemOptimizerType() { return getParams().getString("semOptimizerType", "Regression"); } /** *

setSemOptimizerType.

* * @param type a {@link java.lang.String} object */ public void setSemOptimizerType(String type) { getParams().set("semOptimizerType", type); } /** *

getGraph.

* * @return a {@link edu.cmu.tetrad.graph.Graph} object */ public Graph getGraph() { return this.semEstimator.getEstimatedSem().getSemPm().getGraph(); } /** *

Getter for the field name.

* * @return a {@link java.lang.String} object */ public String getName() { return this.name; } /** * {@inheritDoc} */ public void setName(String name) { this.name = name; } //=============================== Private methods =======================// private void log() { TetradLogger.getInstance().log("SEM Estimator:"); String message3 = "" + getEstimatedSemIm(); TetradLogger.getInstance().log(message3); String message2 = "ChiSq = " + getEstimatedSemIm().getChiSquare(); TetradLogger.getInstance().log(message2); String message1 = "DOF = " + getEstimatedSemIm().getSemPm().getDof(); TetradLogger.getInstance().log(message1); String message = "P = " + getEstimatedSemIm().getPValue(); TetradLogger.getInstance().log(message); } /** * Writes the object to the specified ObjectOutputStream. * * @param out The ObjectOutputStream to write the object to. * @throws IOException If an I/O error occurs. */ @Serial private void writeObject(ObjectOutputStream out) throws IOException { try { out.defaultWriteObject(); } catch (IOException e) { TetradLogger.getInstance().log("Failed to serialize object: " + getClass().getCanonicalName() + ", " + e.getMessage()); throw e; } } /** * Reads the object from the specified ObjectInputStream. This method is used during deserialization * to restore the state of the object. * * @param in The ObjectInputStream to read the object from. * @throws IOException If an I/O error occurs. * @throws ClassNotFoundException If the class of the serialized object cannot be found. */ @Serial private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException { try { in.defaultReadObject(); } catch (IOException e) { TetradLogger.getInstance().log("Failed to deserialize object: " + getClass().getCanonicalName() + ", " + e.getMessage()); throw e; } } /** *

Getter for the field params.

* * @return a {@link edu.cmu.tetrad.util.Parameters} object */ public Parameters getParams() { return this.params; } private SemOptimizer getOptimizer() { SemOptimizer optimizer; String type = getParams().getString("semOptimizerType", "Regression"); if ("Regression".equals(type)) { SemOptimizer defaultOptimization = getDefaultOptimization(); if (!(defaultOptimization instanceof SemOptimizerRegression)) { optimizer = defaultOptimization; type = getType(defaultOptimization); getParams().set("semOptimizerType", type); } else { optimizer = new SemOptimizerRegression(); } } else if ("EM".equals(type)) { optimizer = new SemOptimizerEm(); } else if ("Powell".equals(type)) { optimizer = new SemOptimizerPowell(); } else if ("Random Search".equals(type)) { optimizer = new SemOptimizerScattershot(); } else if ("RICF".equals(type)) { optimizer = new SemOptimizerRicf(); } else { if (this.semPm != null) { optimizer = getDefaultOptimization(); String _type = getType(optimizer); if (_type != null) { getParams().set("semOptimizerType", _type); } } else { optimizer = null; } } return optimizer; } private String getType(SemOptimizer optimizer) { String _type = null; if (optimizer instanceof SemOptimizerRegression) { _type = "Regression"; } else if (optimizer instanceof SemOptimizerEm) { _type = "EM"; } else if (optimizer instanceof SemOptimizerPowell) { _type = "Powell"; } else if (optimizer instanceof SemOptimizerScattershot) { _type = "Random Search"; } else if (optimizer instanceof SemOptimizerRicf) { _type = "RICF"; } return _type; } private boolean containsFixedParam(SemPm semPm) { return new SemIm(semPm).getNumFixedParams() > 0; } /** *

getScoreType.

* * @return a {@link edu.cmu.tetrad.sem.ScoreType} object */ public ScoreType getScoreType() { return (ScoreType) this.params.get("scoreType", ScoreType.SemBic); } /** *

setScoreType.

* * @param scoreType a {@link edu.cmu.tetrad.sem.ScoreType} object */ public void setScoreType(ScoreType scoreType) { this.params.set("scoreType", scoreType); } /** *

getNumRestarts.

* * @return a int */ public int getNumRestarts() { return getParams().getInt("numRestarts", 1); } /** *

setNumRestarts.

* * @param numRestarts a int */ public void setNumRestarts(int numRestarts) { getParams().set("numRestarts", numRestarts); } private SemOptimizer getDefaultOptimization() { if (this.semPm == null) { throw new NullPointerException("Sorry, I didn't see a SEM PM as parent to the estimator; perhaps the parents are wrong."); } boolean containsLatent = false; for (Node node : this.semPm.getGraph().getNodes()) { if (node.getNodeType() == NodeType.LATENT) { containsLatent = true; break; } } SemOptimizer optimizer; if (containsFixedParam(this.semPm) || this.semPm.getGraph().paths().existsDirectedCycle() || SemEstimatorWrapper.containsCovarParam(this.semPm)) { optimizer = new SemOptimizerPowell(); } else if (containsLatent) { optimizer = new SemOptimizerEm(); } else { optimizer = new SemOptimizerRegression(); } optimizer.setNumRestarts(getParams().getInt("numRestarts", 1)); return optimizer; } }




© 2015 - 2025 Weber Informatics LLC | Privacy Policy