All Downloads are FREE. Search and download functionalities are using the official Maven repository.

edu.cmu.tetradapp.model.DataWrapper Maven / Gradle / Ivy

There is a newer version: 7.6.6
Show newest version
///////////////////////////////////////////////////////////////////////////////
// For information as to what this class does, see the Javadoc, below.       //
// Copyright (C) 1998, 1999, 2000, 2001, 2002, 2003, 2004, 2005, 2006,       //
// 2007, 2008, 2009, 2010, 2014, 2015, 2022 by Peter Spirtes, Richard        //
// Scheines, Joseph Ramsey, and Clark Glymour.                               //
//                                                                           //
// This program is free software; you can redistribute it and/or modify      //
// it under the terms of the GNU General Public License as published by      //
// the Free Software Foundation; either version 2 of the License, or         //
// (at your option) any later version.                                       //
//                                                                           //
// This program is distributed in the hope that it will be useful,           //
// but WITHOUT ANY WARRANTY; without even the implied warranty of            //
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the             //
// GNU General Public License for more details.                              //
//                                                                           //
// You should have received a copy of the GNU General Public License         //
// along with this program; if not, write to the Free Software               //
// Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA //
///////////////////////////////////////////////////////////////////////////////
package edu.cmu.tetradapp.model;

import edu.cmu.tetrad.data.*;
import edu.cmu.tetrad.graph.EdgeListGraph;
import edu.cmu.tetrad.graph.Graph;
import edu.cmu.tetrad.graph.Node;
import edu.cmu.tetrad.graph.NodeType;
import edu.cmu.tetrad.regression.RegressionResult;
import edu.cmu.tetrad.session.DoNotAddOldModel;
import edu.cmu.tetrad.session.SimulationParamsSource;
import edu.cmu.tetrad.util.Parameters;

import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.Serial;
import java.util.*;

/**
 * Wraps a {@link DataModel} (held internally as a {@link DataModelList}) as a model class for a Session, providing
 * constructors for each of the session parents from which a data box can be built: simulations, other data wrappers,
 * graphs, regression results and MIMBuild results.
 *
 * @author josephramsey
 */
public class DataWrapper implements KnowledgeEditable, KnowledgeBoxInput,
        DoNotAddOldModel, SimulationParamsSource, MultipleDataSource {

    @Serial
    private static final long serialVersionUID = 23L;

    /**
     * Maps columns to discretization specs so that user's work is not forgotten from one editing of the same data set
     * to the next. This field is not read or written by any method of this class; it is kept as part of the
     * serialized form.
     *
     * @serial Cannot be null.
     */
    private final Map<Object, Object> discretizationSpecs = new HashMap<>();

    /**
     * The name of this wrapper.
     *
     * @serial Can be null.
     */
    private String name;

    /**
     * The wrapped data models. Every constructor in this class assigns it.
     */
    private DataModelList dataModelList;

    /**
     * Stores a reference to the source workbench, if there is one.
     *
     * @serial Can be null.
     */
    private Graph sourceGraph;

    /**
     * The parameters being edited.
     */
    private Parameters parameters;

    /**
     * The full parameter-setting map, as set via {@link #setAllParamSettings(Map)}. May be null.
     */
    private Map<String, String> allParamSettings;

    //==============================CONSTRUCTORS===========================//

    /**
     * Constructs a wrapper around an empty data set with fresh default parameters.
     */
    protected DataWrapper() {
        setDataModel(new BoxDataSet(new VerticalDoubleDataBox(0, 0), new LinkedList<>()));
        this.parameters = new Parameters();
    }

    /**
     * Constructs a data wrapper using a new, empty DataSet as data model.
     *
     * @param parameters the parameters to use (stored as given, not copied).
     */
    public DataWrapper(Parameters parameters) {
        setDataModel(new BoxDataSet(new VerticalDoubleDataBox(0, 0), new LinkedList<>()));
        this.parameters = parameters;
    }

    /**
     * Constructs a data wrapper holding copies of the data models of the given simulation.
     *
     * @param wrapper    the simulation whose data models are copied.
     * @param parameters the parameters to use (stored as given, not copied).
     * @throws IllegalArgumentException if the simulation holds a model that is not a data set, correlation matrix or
     *                                  covariance matrix.
     */
    public DataWrapper(Simulation wrapper, Parameters parameters) {
        this.name = wrapper.getName();

        // Copy rather than share the simulation's list, so that edits made through this wrapper do not write
        // through to the simulation's own data models.
        this.dataModelList = copyDataModelList(wrapper.getDataModelList());
        this.parameters = parameters;
    }

    /**
     * Copy constructor. Data models are deep-copied, the selected model is carried over by position, and the source
     * graph (if any) is copied.
     *
     * @param wrapper    the data wrapper to copy.
     * @param parameters the parameters to use (a copy is stored).
     * @throws IllegalArgumentException if the wrapper holds a model that is not a data set, correlation matrix or
     *                                  covariance matrix.
     */
    public DataWrapper(DataWrapper wrapper, Parameters parameters) {
        this.name = wrapper.name;
        this.parameters = new Parameters(parameters);
        this.dataModelList = copyDataModelList(wrapper.getDataModelList());

        if (wrapper.sourceGraph != null) {
            this.sourceGraph = new EdgeListGraph(wrapper.sourceGraph);
        }

        LogDataUtils.logDataModelList("Standalone data set.", getDataModelList());
    }

    /**
     * Constructs a data wrapper around the given data set. No parameters are set.
     *
     * @param dataSet the data set to use; if null, an empty data set is used instead.
     */
    public DataWrapper(DataSet dataSet) {
        setDataModel(dataSet);
    }

    /**
     * Constructs a data wrapper holding an empty (zero-row) continuous data set with one column per measured node of
     * the given graph.
     *
     * @param graph      the graph to use.
     * @param parameters the parameters to use (a copy is stored).
     * @throws NullPointerException if graph is null.
     */
    public DataWrapper(Graph graph, Parameters parameters) {
        if (graph == null) {
            throw new NullPointerException();
        }

        this.parameters = new Parameters(parameters);

        List<Node> variables = new LinkedList<>();

        for (Node node : graph.getNodes()) {
            if (node.getNodeType() == NodeType.MEASURED) {
                variables.add(new ContinuousVariable(node.getName()));
            }
        }

        DataSet dataSet = new BoxDataSet(new VerticalDoubleDataBox(0, variables.size()), variables);
        DataModelList dataModelList = new DataModelList();
        dataModelList.add(dataSet);
        this.dataModelList = dataModelList;
    }

    /**
     * Constructs an empty data set over the measured nodes of the given DAG.
     *
     * @param dagWrapper the DAG to use.
     * @param parameters the parameters to use.
     */
    public DataWrapper(DagWrapper dagWrapper, Parameters parameters) {
        this(dagWrapper.getDag(), parameters);
    }

    /**
     * Constructs an empty data set over the measured nodes of the given SEM graph.
     *
     * @param wrapper    the SEM graph to use.
     * @param parameters the parameters to use.
     */
    public DataWrapper(SemGraphWrapper wrapper, Parameters parameters) {
        this(wrapper.getGraph(), parameters);
    }

    /**
     * Constructs an empty data set over the measured nodes of the given graph.
     *
     * @param wrapper    the graph to use.
     * @param parameters the parameters to use.
     */
    public DataWrapper(GraphWrapper wrapper, Parameters parameters) {
        this(wrapper.getGraph(), parameters);
    }

    /**
     * Constructs a copy of the selected data set of the given data wrapper, augmented with a column of values
     * predicted by the given regression.
     *
     * @param regression the regression to use.
     * @param wrapper    the data model to use; its selected model must be a DataSet.
     * @param parameters the parameters to use.
     */
    public DataWrapper(RegressionRunner regression, DataWrapper wrapper, Parameters parameters) {
        this(regression.getResult(), (DataSet) Objects.requireNonNull(wrapper.getDataModelList().getSelectedModel()),
                parameters);
    }

    /**
     * Constructs a copy of the selected data set of the given simulation, augmented with a column of values predicted
     * by the given regression.
     *
     * @param regression the regression to use.
     * @param wrapper    the data model to use; its selected model must be a DataSet.
     * @param parameters the parameters to use.
     */
    public DataWrapper(RegressionRunner regression, Simulation wrapper, Parameters parameters) {
        this(regression.getResult(), (DataSet) Objects.requireNonNull(wrapper.getDataModelList().getSelectedModel()),
                parameters);
    }

    /**
     * Constructs a copy of the given data with one extra continuous column (named "Pred", "Pred1", ... -- whichever
     * is first unused) holding, for each row, the regression's predicted value from that row's regressor values.
     *
     * @param result     the regression result to use.
     * @param data       the data to use; it is not modified.
     * @param parameters the parameters to use (a copy is stored).
     * @throws NullPointerException     if a regressor does not exist in the data.
     * @throws IllegalArgumentException if a regressor is not a continuous variable.
     */
    public DataWrapper(RegressionResult result, DataSet data, Parameters parameters) {
        this.parameters = new Parameters(parameters);

        DataSet data2 = data.copy();
        String predictedVariable = nextVariableName("Pred", data);
        data2.addVariable(new ContinuousVariable(predictedVariable));

        // The predicted column's index is the same for every row, so look it up once.
        int predictedColumn = data2.getColumn(data2.getVariable(predictedVariable));

        String[] regressorNames = result.getRegressorNames();

        for (int i = 0; i < data.getNumRows(); i++) {
            double[] x = new double[regressorNames.length];

            for (int j = 0; j < regressorNames.length; j++) {
                Node variable = data.getVariable(regressorNames[j]);

                if (variable == null) {
                    // Report the missing name; the looked-up variable itself is null here.
                    throw new NullPointerException("Variable " + regressorNames[j] + " doesn't "
                            + "exist in the input data.");
                }

                if (!(variable instanceof ContinuousVariable)) {
                    throw new IllegalArgumentException("Expecting a continuous variable: " + variable);
                }

                x[j] = data.getDouble(i, data.getColumn(variable));
            }

            double yHat = result.getPredictedValue(x);
            data2.setDouble(i, predictedColumn, yHat);
        }

        DataModelList dataModelList = new DataModelList();
        dataModelList.add(data2);
        this.dataModelList = dataModelList;
    }

    /**
     * Constructs a data wrapper holding the covariance matrix produced by the given MIMBuild run.
     *
     * @param mimBuild   the mim build to use.
     * @param parameters the parameters to use (a copy is stored).
     */
    public DataWrapper(MimBuildRunner mimBuild, Parameters parameters) {
        this.parameters = new Parameters(parameters);

        ICovarianceMatrix cov = mimBuild.getCovMatrix();

        DataModelList dataModelList = new DataModelList();
        dataModelList.add(cov);
        this.dataModelList = dataModelList;
    }

    /**
     * Generates a simple exemplar to test serialization. Note that this delegates to, and returns, a
     * {@link PcRunner} exemplar rather than a DataWrapper.
     *
     * @return a PcRunner exemplar.
     */
    public static PcRunner serializableInstance() {
        return PcRunner.serializableInstance();
    }

    /**
     * Returns a deep copy of the given list: each model is copied via {@link #copyDataModel(DataModel)}, and if the
     * source's selected model is found in the list, the copy at the same position is selected.
     *
     * @param source the list to copy.
     * @return the copied list.
     * @throws IllegalArgumentException if the list holds a model of an unsupported type.
     */
    private static DataModelList copyDataModelList(DataModelList source) {
        DataModelList copy = new DataModelList();
        DataModel selectedModel = source.getSelectedModel();
        int selected = -1;

        for (int i = 0; i < source.size(); i++) {
            DataModel model = source.get(i);
            copy.add(copyDataModel(model));

            if (model.equals(selectedModel)) {
                selected = i;
            }
        }

        if (selected > -1) {
            copy.setSelectedModel(copy.get(selected));
        }

        return copy;
    }

    /**
     * Returns a copy of a single data model. Supports data sets, correlation matrices and covariance matrices.
     * CorrelationMatrix is tested before CovarianceMatrix so that the more specific copy constructor is used.
     *
     * @param model the model to copy.
     * @return the copy.
     * @throws IllegalArgumentException if the model is of any other type.
     */
    private static DataModel copyDataModel(DataModel model) {
        if (model instanceof DataSet) {
            return ((DataSet) model).copy();
        } else if (model instanceof CorrelationMatrix) {
            return new CorrelationMatrix((CorrelationMatrix) model);
        } else if (model instanceof CovarianceMatrix) {
            return new CovarianceMatrix((CovarianceMatrix) model);
        }

        throw new IllegalArgumentException("Cannot copy data model of type " + model.getClass().getName());
    }

    /**
     * Returns the first name in the sequence {@code base}, {@code base1}, {@code base2}, ... that is not already the
     * name of a variable in the given data set.
     *
     * @param base the base string.
     * @param data the data set whose variable names must be avoided.
     * @return the first name in the sequence not already in use.
     */
    private String nextVariableName(String base, DataSet data) {
        Set<String> usedNames = new HashSet<>();

        for (Node variable : data.getVariables()) {
            usedNames.add(variable.getName());
        }

        if (!usedNames.contains(base)) {
            return base;
        }

        int i = 1;

        while (usedNames.contains(base + i)) {
            i++;
        }

        return base + i;
    }

    /**
     * Returns the list of data models being wrapped (not a copy).
     *
     * @return the list of models.
     */
    public DataModelList getDataModelList() {
        return this.dataModelList;
    }

    /**
     * Set the data model list.
     *
     * @param dataModelList the data model list to set.
     * @throws NullPointerException if dataModelList is null.
     */
    public void setDataModelList(DataModelList dataModelList) {
        if (dataModelList == null) {
            throw new NullPointerException("Data model list not provided.");
        }
        this.dataModelList = dataModelList;
    }

    /**
     * Returns a new list containing the wrapped data models.
     *
     * @return the data models for this wrapper.
     */
    public List<DataModel> getDataModels() {
        return new ArrayList<>(this.dataModelList);
    }

    /**
     * Returns the selected model of the wrapped data model list.
     *
     * @return the selected data model for this wrapper.
     */
    public DataModel getSelectedDataModel() {
        DataModelList modelList = getDataModelList();
        return modelList.getSelectedModel();
    }

    /**
     * Sets the data model. A DataModelList is stored directly; any other model is wrapped in a new single-element
     * list.
     *
     * @param dataModel the data model to set; if null, an empty data set is used instead.
     */
    public void setDataModel(DataModel dataModel) {
        if (dataModel == null) {
            dataModel = new BoxDataSet(new VerticalDoubleDataBox(0, 0), new LinkedList<>());
        }

        if (dataModel instanceof DataModelList) {
            this.dataModelList = (DataModelList) dataModel;
        } else {
            DataModelList dataModelList = new DataModelList();
            dataModelList.add(dataModel);
            this.dataModelList = dataModelList;
        }
    }

    /**
     * Returns a copy of the knowledge of the selected data model.
     *
     * @return the knowledge for this wrapper.
     */
    public Knowledge getKnowledge() {
        return getSelectedDataModel().getKnowledge().copy();
    }

    /**
     * Sets the knowledge of the selected data model to a copy of the given object.
     *
     * @param knowledge the knowledge to set.
     */
    public void setKnowledge(Knowledge knowledge) {
        getSelectedDataModel().setKnowledge(knowledge.copy());
    }

    /**
     * Returns the variable names of the selected data model.
     *
     * @return the variable names of the selected data model.
     */
    public List<String> getVarNames() {
        return getSelectedDataModel().getVariableNames();
    }

    /**
     * Returns the source graph.
     *
     * @return the source graph, or null if there is none.
     */
    public Graph getSourceGraph() {
        return this.sourceGraph;
    }

    /**
     * Sets the source graph.
     *
     * @param sourceGraph the source graph to set.
     */
    protected void setSourceGraph(Graph sourceGraph) {
        this.sourceGraph = sourceGraph;
    }

    /**
     * Returns the result graph, which for a data wrapper is the source graph.
     *
     * @return the result graph.
     */
    public Graph getResultGraph() {
        return getSourceGraph();
    }

    /**
     * Returns the variables of the selected data model.
     *
     * @return the variables, in order.
     */
    public List<Node> getVariables() {
        return this.getSelectedDataModel().getVariables();
    }

    /**
     * Adds semantic checks to the default deserialization method. This method must have the standard signature for a
     * readObject method, and the body of the method must begin with "s.defaultReadObject();". Other than that, any
     * semantic checks can be specified and do not need to stay the same from version to version. A readObject method of
     * this form may be added to any class, even if Tetrad sessions were previously saved out using a version of the
     * class that didn't include it. (That's what the "s.defaultReadObject();" is for.) See J. Bloch, Effective Java, for
     * help.
     *
     * @param s the stream to read from.
     * @throws IOException            if the stream cannot be read.
     * @throws ClassNotFoundException if a serialized class cannot be found.
     */
    @Serial
    private void readObject(ObjectInputStream s)
            throws IOException, ClassNotFoundException {
        s.defaultReadObject();
    }

    /**
     * Returns the name of the data wrapper.
     *
     * @return the name of the data wrapper.
     */
    public String getName() {
        return this.name;
    }

    /**
     * Sets the name of the data wrapper.
     *
     * @param name the name to set.
     */
    public void setName(String name) {
        this.name = name;
    }

    /**
     * Returns the parameters being edited.
     *
     * @return the parameters being edited.
     */
    public Parameters getParams() {
        return this.parameters;
    }

    /**
     * Sets the parameters being edited.
     *
     * @param parameters the parameters to set.
     */
    public void setParameters(Parameters parameters) {
        this.parameters = parameters;
    }

    /**
     * Returns the names of the variables of the selected data model, in order.
     *
     * @return the variable names.
     */
    public List<String> getVariableNames() {
        List<String> variableNames = new ArrayList<>();
        for (Node n : getVariables()) {
            variableNames.add(n.getName());
        }
        return variableNames;
    }

    /**
     * Returns a summary of the wrapped data. With more than one model the summary holds "# Datasets". With a single
     * covariance matrix or data set it holds "# Vars" and "N". With no data it is empty.
     *
     * @return the parameter setting map.
     */
    @Override
    public Map<String, String> getParamSettings() {
        Map<String, String> paramSettings = new HashMap<>();

        // Nothing to summarize without data.
        if (this.dataModelList == null || this.dataModelList.isEmpty()) {
            return paramSettings;
        }

        if (this.dataModelList.size() > 1) {
            paramSettings.put("# Datasets", Integer.toString(this.dataModelList.size()));
        } else {
            DataModel dataModel = this.dataModelList.get(0);

            // Dispatch on the interface so that any covariance-matrix implementation (e.g. the one produced by
            // MIMBuild) is summarized rather than failing a DataSet cast.
            if (dataModel instanceof ICovarianceMatrix) {
                ICovarianceMatrix cov = (ICovarianceMatrix) dataModel;
                paramSettings.put("# Vars", Integer.toString(cov.getDimension()));
                paramSettings.put("N", Integer.toString(cov.getSampleSize()));
            } else if (dataModel instanceof DataSet) {
                DataSet dataSet = (DataSet) dataModel;
                paramSettings.put("# Vars", Integer.toString(dataSet.getNumColumns()));
                paramSettings.put("N", Integer.toString(dataSet.getNumRows()));
            }
        }

        return paramSettings;
    }

    /**
     * Returns the parameter setting map last set via {@link #setAllParamSettings(Map)}.
     *
     * @return the parameter setting map, or null if none has been set.
     */
    @Override
    public Map<String, String> getAllParamSettings() {
        return this.allParamSettings;
    }

    /**
     * Sets the parameter setting map (stored as given, not copied).
     *
     * @param paramSettings the parameter setting map to set.
     */
    @Override
    public void setAllParamSettings(Map<String, String> paramSettings) {
        this.allParamSettings = paramSettings;
    }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy