All downloads are free. Search and download functionality uses the official Maven repository.

edu.cmu.tetradapp.model.BuildPureClustersRunner Maven / Gradle / Ivy

There is a newer version: 7.6.6
Show newest version
///////////////////////////////////////////////////////////////////////////////
// For information as to what this class does, see the Javadoc, below.       //
// Copyright (C) 1998, 1999, 2000, 2001, 2002, 2003, 2004, 2005, 2006,       //
// 2007, 2008, 2009, 2010, 2014, 2015, 2022 by Peter Spirtes, Richard        //
// Scheines, Joseph Ramsey, and Clark Glymour.                               //
//                                                                           //
// This program is free software; you can redistribute it and/or modify      //
// it under the terms of the GNU General Public License as published by      //
// the Free Software Foundation; either version 2 of the License, or         //
// (at your option) any later version.                                       //
//                                                                           //
// This program is distributed in the hope that it will be useful,           //
// but WITHOUT ANY WARRANTY; without even the implied warranty of            //
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the             //
// GNU General Public License for more details.                              //
//                                                                           //
// You should have received a copy of the GNU General Public License         //
// along with this program; if not, write to the Free Software               //
// Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA //
///////////////////////////////////////////////////////////////////////////////

package edu.cmu.tetradapp.model;

import edu.cmu.tetrad.data.*;
import edu.cmu.tetrad.graph.Graph;
import edu.cmu.tetrad.graph.LayoutUtil;
import edu.cmu.tetrad.graph.Node;
import edu.cmu.tetrad.graph.NodeType;
import edu.cmu.tetrad.search.Bpc;
import edu.cmu.tetrad.search.utils.BpcAlgorithmType;
import edu.cmu.tetrad.search.utils.BpcTestType;
import edu.cmu.tetrad.search.utils.ClusterUtils;
import edu.cmu.tetrad.search.utils.MimUtils;
import edu.cmu.tetrad.search.work_in_progress.BpcTetradPurifyWashdown;
import edu.cmu.tetrad.search.work_in_progress.Washdown;
import edu.cmu.tetrad.sem.ReidentifyVariables;
import edu.cmu.tetrad.sem.SemIm;
import edu.cmu.tetrad.util.Parameters;
import edu.cmu.tetrad.util.TetradSerializableUtils;
import edu.cmu.tetrad.util.Unmarshallable;

import java.io.IOException;
import java.rmi.MarshalledObject;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;

/**
 * Extends AbstractAlgorithmRunner to produce a wrapper for the BuildPureClusters algorithm.
 *
 * @author Ricardo Silva
 */
public class BuildPureClustersRunner extends AbstractMimRunner
        implements GraphSource, Unmarshallable {
    private static final long serialVersionUID = 23L;

    /**
     * To reidentify variables.
     */
    private SemIm semIm;
    private Graph trueGraph;

    //============================CONSTRUCTORS============================//

    /**
     * Constructs a wrapper for the given DataWrapper.
     */

    public BuildPureClustersRunner(DataWrapper dataWrapper,
                                   Parameters pureClustersParams) {
        super(dataWrapper, (Clusters) pureClustersParams.get("clusters", null), pureClustersParams);

    }

    public BuildPureClustersRunner(DataWrapper dataWrapper, SemImWrapper semImWrapper,
                                   Parameters pureClustersParams) {
        super(dataWrapper, (Clusters) pureClustersParams.get("clusters", null), pureClustersParams);
        this.semIm = semImWrapper.getSemIm();
        this.trueGraph = this.semIm.getSemPm().getGraph();
    }

    public BuildPureClustersRunner(DataWrapper dataWrapper, GraphWrapper graphWrapper,
                                   Parameters pureClustersParams) {
        super(dataWrapper, (Clusters) pureClustersParams.get("clusters", null), pureClustersParams);
        this.trueGraph = graphWrapper.getGraph();
    }

    /**
     * Generates a simple exemplar of this class to test serialization.
     *
     * @see TetradSerializableUtils
     */
    public static PcRunner serializableInstance() {
        return PcRunner.serializableInstance();
    }

    //===================PUBLIC METHODS OVERRIDING ABSTRACT================//

    /**
     * Executes the algorithm, producing (at least) a result workbench. Must be implemented in the extending class.
     */
    public void execute() {
        boolean rKey = getParams().getBoolean("BPCrDown", false);

        BpcAlgorithmType algorithm = (BpcAlgorithmType) getParams().get("bpcAlgorithmthmType", BpcAlgorithmType.FIND_ONE_FACTOR_CLUSTERS);

        Graph searchGraph;

        if (rKey) {
            Washdown washdown;
            Object source = getData();

            if (source instanceof DataSet) {
                washdown = new Washdown((DataSet) source, getParams().getDouble("alpha", 0.001));
            } else {
                washdown = new Washdown((CovarianceMatrix) source, getParams().getDouble("alpha", 0.001));
            }

            searchGraph = washdown.search();
        } else {
            BpcTestType tetradTestType = (BpcTestType) getParams().get("tetradTestType", BpcTestType.TETRAD_WISHART);

            if (algorithm == BpcAlgorithmType.TETRAD_PURIFY_WASHDOWN) {
                BpcTetradPurifyWashdown bpc;
                Object source = getData();

                if (source instanceof DataSet) {
                    bpc = new BpcTetradPurifyWashdown(
                            (DataSet) source,
                            tetradTestType,
                            getParams().getDouble("alpha", 0.001));

                } else {
                    bpc = new BpcTetradPurifyWashdown((ICovarianceMatrix) source,
                            tetradTestType, getParams().getDouble("alpha", 0.001));

                }

                searchGraph = bpc.search();
            } else if (algorithm == BpcAlgorithmType.BUILD_PURE_CLUSTERS) {
                Bpc bpc;
                DataModel source = getData();

                BpcTestType testType = (BpcTestType) getParams().get("tetradTestType", BpcTestType.TETRAD_WISHART);

                if (source instanceof ICovarianceMatrix) {
                    bpc = new Bpc((ICovarianceMatrix) source,
                            getParams().getDouble("alpha", 0.001),
                            testType
                    );
                } else if (source instanceof DataSet) {
                    bpc = new Bpc(
                            (DataSet) source, getParams().getDouble("alpha", 0.001),
                            testType
                    );
                } else {
                    throw new IllegalArgumentException();
                }

                searchGraph = bpc.search();

            }
//            else if (algorithm == BpcAlgorithmType.FIND_ONE_FACTOR_CLUSTERS) {
////                FindOneFactorClusters bpc;
////                Object source = getContinuousData();
////
////                if (source instanceof DataSet) {
////                    bpc = new FindOneFactorClusters(
////                            (DataSet) source,
////                            tetradTestType,
////                            getParameters().getAlternativePenalty());
////                } else {
////                    bpc = new FindOneFactorClusters((ICovarianceMatrix) source,
////                            tetradTestType, getParameters().getAlternativePenalty());
////                }
////
////                searchGraph = bpc.search();
//
//                FindOneFactorClusters2 bpc;
//                Object source = getContinuousData();
//                FindOneFactorClusters2.Algorithm sag = FindOneFactorClusters2.Algorithm.SAG;
//
//                if (source instanceof DataSet) {
//                    bpc = new FindOneFactorClusters2(
//                            (DataSet) source,
//                            tetradTestType, sag,
//                            getParameters().getAlternativePenalty());
//
////                    bpc = new FindTwoFactorClusters4(
////                            (DataSet) source,
////                            getParameters().getAlternativePenalty());
//                } else {
//                    bpc = new FindOneFactorClusters2((ICovarianceMatrix) source,
//                            tetradTestType, sag, getParameters().getAlternativePenalty());
////
////                    bpc = new FindTwoFactorClusters4((ICovarianceMatrix) source,
////                            getParameters().getAlternativePenalty());
//                }
//
//                searchGraph = bpc.search();
//
//            }
//            else if (algorithm == BpcAlgorithmType.FIND_TWO_FACTOR_CLUSTERS) {
//                FindTwoFactorClusters2 bpc;
//                Object source = getContinuousData();
//
//                if (source instanceof DataSet) {
//                    bpc = new FindTwoFactorClusters2(
//                            (DataSet) source,
//                            tetradTestType,
//                            getParameters().getAlternativePenalty());
//
////                    bpc = new FindTwoFactorClusters4(
////                            (DataSet) source,
////                            getParameters().getAlternativePenalty());
//                } else {
//                    bpc = new FindTwoFactorClusters2((ICovarianceMatrix) source,
//                            tetradTestType, getParameters().getAlternativePenalty());
////
////                    bpc = new FindTwoFactorClusters4((ICovarianceMatrix) source,
////                            getParameters().getAlternativePenalty());
//                }
//
//                searchGraph = bpc.search();
//            }
            else {
                throw new IllegalStateException();
            }
        }

        if (this.semIm != null) {
            List> partition = MimUtils.convertToClusters2(searchGraph);

            List variableNames = ReidentifyVariables.reidentifyVariables2(partition, this.trueGraph, (DataSet) getData());
            rename(searchGraph, partition, variableNames);
//            searchGraph = reidentifyVariables2(searchGraph, semIm);
        } else if (this.trueGraph != null) {
            List> partition = MimUtils.convertToClusters2(searchGraph);
            List variableNames = ReidentifyVariables.reidentifyVariables1(partition, this.trueGraph);
            rename(searchGraph, partition, variableNames);
//            searchGraph = reidentifyVariables(searchGraph, trueGraph);
        }

        System.out.println("Search Graph " + searchGraph);

        try {
            Graph graph = new MarshalledObject<>(searchGraph).get();
            LayoutUtil.defaultLayout(graph);
            LayoutUtil.fruchtermanReingoldLayout(graph);
            setResultGraph(graph);
            setClusters(MimUtils.convertToClusters(graph, getData().getVariables()));
        } catch (Exception e) {
            e.printStackTrace();
            throw new RuntimeException(e);
        }
    }

    private void rename(Graph searchGraph, List> partition, List variableNames) {
        for (Node node : searchGraph.getNodes()) {
            if (!(node.getNodeType() == NodeType.LATENT)) {
                continue;
            }

            List children = searchGraph.getChildren(node);
            ReidentifyVariables.getLatents(searchGraph).forEach(children::remove);

            for (int i = 0; i < partition.size(); i++) {
                if (new HashSet<>(partition.get(i)).equals(new HashSet<>(children))) {
                    node.setName(variableNames.get(i));
                }
            }
        }
    }

    public Graph getGraph() {
        return getResultGraph();
    }

    public java.util.List getVariables() {
        List latents = new ArrayList<>();

        for (String name : getVariableNames()) {
            Node node = new ContinuousVariable(name);
            node.setNodeType(NodeType.LATENT);
            latents.add(node);
        }

        return latents;
    }

    public List getVariableNames() {
        List> partition = ClusterUtils.clustersToPartition(getClusters(),
                getData().getVariables());
        return ClusterUtils.generateLatentNames(partition.size());
    }
}









© 2015 - 2025 Weber Informatics LLC | Privacy Policy