All downloads are free. Search and download functionality uses the official Maven repository.

edu.cmu.tetradapp.model.SemEstimatorWrapper Maven / Gradle / Ivy

There is a newer version: 7.6.6
Show newest version
///////////////////////////////////////////////////////////////////////////////
// For information as to what this class does, see the Javadoc, below.       //
// Copyright (C) 1998, 1999, 2000, 2001, 2002, 2003, 2004, 2005, 2006,       //
// 2007, 2008, 2009, 2010, 2014, 2015, 2022 by Peter Spirtes, Richard        //
// Scheines, Joseph Ramsey, and Clark Glymour.                               //
//                                                                           //
// This program is free software; you can redistribute it and/or modify      //
// it under the terms of the GNU General Public License as published by      //
// the Free Software Foundation; either version 2 of the License, or         //
// (at your option) any later version.                                       //
//                                                                           //
// This program is distributed in the hope that it will be useful,           //
// but WITHOUT ANY WARRANTY; without even the implied warranty of            //
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the             //
// GNU General Public License for more details.                              //
//                                                                           //
// You should have received a copy of the GNU General Public License         //
// along with this program; if not, write to the Free Software               //
// Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA //
///////////////////////////////////////////////////////////////////////////////
package edu.cmu.tetradapp.model;

import edu.cmu.tetrad.data.*;
import edu.cmu.tetrad.graph.Dag;
import edu.cmu.tetrad.graph.Graph;
import edu.cmu.tetrad.graph.Node;
import edu.cmu.tetrad.graph.NodeType;
import edu.cmu.tetrad.sem.*;
import edu.cmu.tetrad.session.SessionModel;
import edu.cmu.tetrad.util.*;

import javax.swing.*;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.Serial;
import java.util.LinkedList;
import java.util.List;

/**
 * Wraps a SemEstimator for use in the Tetrad application.
 *
 * @author josephramsey
 */
public class SemEstimatorWrapper implements SessionModel {
    @Serial
    private static final long serialVersionUID = 23L;
    private final Parameters params;
    private final SemPm semPm;
    /**
     * @serial Can be null.
     */
    private String name;
    /**
     * @serial Cannot be null.
     */
    private SemEstimator semEstimator;

    //==============================CONSTRUCTORS==========================//

    /**
     * Private constructor for serialization only. Problem is, for the real constructors, I'd like to call the degrees
     * of freedom check, which pops up a dialog. This is irritating when running unit tests. jdramsey 8/29/07
     */
    public SemEstimatorWrapper(DataModel dataModel, SemPm semPm, Parameters params) {
        this.params = params;
        this.semPm = semPm;

        if (dataModel instanceof DataSet dataSet) {
            SemEstimator estimator = new SemEstimator(dataSet, semPm, getOptimizer());
            estimator.setNumRestarts(getParams().getInt("numRestarts", 1));
            estimator.setScoreType((ScoreType) getParams().get("scoreType", ScoreType.Fgls));
            estimator.estimate();
            if (!degreesOfFreedomCheck(semPm)) {
                throw new IllegalArgumentException("Cannot proceed.");
            }
            this.semEstimator = estimator;
        } else if (dataModel instanceof ICovarianceMatrix) {
            ICovarianceMatrix covMatrix = new CovarianceMatrix((ICovarianceMatrix) dataModel);
            SemEstimator estimator = new SemEstimator(covMatrix, semPm, getOptimizer());
            estimator.setNumRestarts(getParams().getInt("numRestarts", 1));
            estimator.setScoreType((ScoreType) getParams().get("scoreType", ScoreType.Fml));
            estimator.estimate();
            if (!degreesOfFreedomCheck(semPm)) {
                throw new IllegalArgumentException("Cannot proceed.");
            }
            this.semEstimator = estimator;
        } else {
            throw new IllegalArgumentException("Data must consist of continuous data sets or covariance matrices.");
        }
    }

    public SemEstimatorWrapper(DataWrapper dataWrapper,
                               SemPmWrapper semPmWrapper, Parameters params) {
        this(dataWrapper.getSelectedDataModel(), semPmWrapper.getSemPm(), params);
        log();
    }

    /**
     * Generates a simple exemplar of this class to test serialization.
     *
     * @see TetradSerializableUtils
     */
    public static SemEstimatorWrapper serializableInstance() {
        List variables = new LinkedList<>();
        ContinuousVariable x = new ContinuousVariable("X");
        variables.add(x);
        DataSet dataSet = new BoxDataSet(new VerticalDoubleDataBox(10, variables.size()), variables);

        for (int i = 0; i < dataSet.getNumRows(); i++) {
            for (int j = 0; j < dataSet.getNumColumns(); j++) {
                dataSet.setDouble(i, j, RandomUtil.getInstance().nextDouble());
            }
        }
        Dag dag = new Dag();
        dag.addNode(x);
        SemPm pm = new SemPm(dag);
        Parameters params1 = new Parameters();
        return new SemEstimatorWrapper(dataSet, pm, params1);
    }

    private static boolean containsCovarParam(SemPm semPm) {
        boolean containsCovarParam = false;
        List params = semPm.getParameters();

        for (Parameter param : params) {
            if (param.getType() == ParamType.COVAR) {
                containsCovarParam = true;
                break;
            }
        }
        return containsCovarParam;
    }

    private boolean degreesOfFreedomCheck(SemPm semPm) {
        if (semPm.getDof() < 1) {
            int ret = JOptionPane.showConfirmDialog(JOptionUtils.centeringComp(),
                    "This model has non-positive degrees of freedom (DOF = "
                            + semPm.getDof() + "). "
                            + "\nEstimation will be uninformative. Are you sure you want to proceed?",
                    "Please confirm", JOptionPane.YES_NO_OPTION, JOptionPane.WARNING_MESSAGE);

            return ret == JOptionPane.YES_OPTION;
        }

        return true;
    }

    //============================PUBLIC METHODS=========================//
    public SemEstimator getSemEstimator() {
        return this.semEstimator;
    }

    public void setSemEstimator(SemEstimator semEstimator) {
        this.semEstimator = semEstimator;
    }

    public SemIm getEstimatedSemIm() {
        return this.semEstimator.getEstimatedSem();
    }

    public String getSemOptimizerType() {
        return getParams().getString("semOptimizerType", "Regression");
    }

    public void setSemOptimizerType(String type) {
        getParams().set("semOptimizerType", type);
    }

    public Graph getGraph() {
        return this.semEstimator.getEstimatedSem().getSemPm().getGraph();
    }

    public String getName() {
        return this.name;
    }

    public void setName(String name) {
        this.name = name;
    }

    //=============================== Private methods =======================//
    private void log() {
        TetradLogger.getInstance().log("info", "SEM Estimator:");
        TetradLogger.getInstance().log("im", "" + getEstimatedSemIm());
        TetradLogger.getInstance().log("details", "ChiSq = " + getEstimatedSemIm().getChiSquare());
        TetradLogger.getInstance().log("details", "DOF = " + getEstimatedSemIm().getSemPm().getDof());
        TetradLogger.getInstance().log("details", "P = " + getEstimatedSemIm().getPValue());
    }

    /**
     * Adds semantic checks to the default deserialization method. This method must have the standard signature for a
     * readObject method, and the body of the method must begin with "s.defaultReadObject();". Other than that, any
     * semantic checks can be specified and do not need to stay the same from version to version. A readObject method of
     * this form may be added to any class, even if Tetrad sessions were previously saved out using a version of the
     * class that didn't include it. (That's what the "s.defaultReadObject();" is for). See J. Bloch, Effective Java, for
     * help.
     */
    @Serial
    private void readObject(ObjectInputStream s)
            throws IOException, ClassNotFoundException {
        s.defaultReadObject();
    }

    public Parameters getParams() {
        return this.params;
    }

    private SemOptimizer getOptimizer() {
        SemOptimizer optimizer;
        String type = getParams().getString("semOptimizerType", "Regression");

        if ("Regression".equals(type)) {
            SemOptimizer defaultOptimization = getDefaultOptimization();

            if (!(defaultOptimization instanceof SemOptimizerRegression)) {
                optimizer = defaultOptimization;
                type = getType(defaultOptimization);
                getParams().set("semOptimizerType", type);
            } else {
                optimizer = new SemOptimizerRegression();
            }
        } else if ("EM".equals(type)) {
            optimizer = new SemOptimizerEm();
        } else if ("Powell".equals(type)) {
            optimizer = new SemOptimizerPowell();
        } else if ("Random Search".equals(type)) {
            optimizer = new SemOptimizerScattershot();
        } else if ("RICF".equals(type)) {
            optimizer = new SemOptimizerRicf();
        } else {
            if (this.semPm != null) {
                optimizer = getDefaultOptimization();

                String _type = getType(optimizer);

                if (_type != null) {
                    getParams().set("semOptimizerType", _type);
                }
            } else {
                optimizer = null;
            }
        }

        return optimizer;
    }

    private String getType(SemOptimizer optimizer) {
        String _type = null;

        if (optimizer instanceof SemOptimizerRegression) {
            _type = "Regression";
        } else if (optimizer instanceof SemOptimizerEm) {
            _type = "EM";
        } else if (optimizer instanceof SemOptimizerPowell) {
            _type = "Powell";
        } else if (optimizer instanceof SemOptimizerScattershot) {
            _type = "Random Search";
        } else if (optimizer instanceof SemOptimizerRicf) {
            _type = "RICF";
        }

        return _type;
    }

    private boolean containsFixedParam(SemPm semPm) {
        return new SemIm(semPm).getNumFixedParams() > 0;
    }

    public ScoreType getScoreType() {
        return (ScoreType) this.params.get("scoreType", ScoreType.SemBic);
    }

    public void setScoreType(ScoreType scoreType) {
        this.params.set("scoreType", scoreType);
    }

    public int getNumRestarts() {
        return getParams().getInt("numRestarts", 1);
    }

    public void setNumRestarts(int numRestarts) {
        getParams().set("numRestarts", numRestarts);
    }

    private SemOptimizer getDefaultOptimization() {
        if (this.semPm == null) {
            throw new NullPointerException(
                    "Sorry, I didn't see a SEM PM as parent to the estimator; perhaps the parents are wrong.");
        }

        boolean containsLatent = false;

        for (Node node : this.semPm.getGraph().getNodes()) {
            if (node.getNodeType() == NodeType.LATENT) {
                containsLatent = true;
                break;
            }
        }

        SemOptimizer optimizer;

        if (containsFixedParam(this.semPm) || this.semPm.getGraph().paths().existsDirectedCycle()
                || SemEstimatorWrapper.containsCovarParam(this.semPm)) {
            optimizer = new SemOptimizerPowell();
        } else if (containsLatent) {
            optimizer = new SemOptimizerEm();
        } else {
            optimizer = new SemOptimizerRegression();
        }

        optimizer.setNumRestarts(getParams().getInt("numRestarts", 1));

        return optimizer;
    }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy