All Downloads are FREE. Search and download functionalities are using the official Maven repository.

edu.cmu.tetrad.simulation.HsimContinuous Maven / Gradle / Ivy

The newest version!
package edu.cmu.tetrad.simulation;

import edu.cmu.tetrad.data.DataSet;
import edu.cmu.tetrad.graph.Dag;
import edu.cmu.tetrad.graph.Graph;
import edu.cmu.tetrad.graph.Node;
import edu.cmu.tetrad.sem.*;

import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;

/**
 * Created by Erich on 3/27/2016.
 *
 * @author josephramsey
 * @version $Id: $Id
 */
public class HsimContinuous {
    //Dag mydag, Set simnodes, DataSet data
    private boolean verbose;
    private Dag mydag;
    private Set simnodes;
    private DataSet data;

    //************Constructors***************//

    /**
     * 

Constructor for HsimContinuous.

* * @param thedag a {@link edu.cmu.tetrad.graph.Dag} object * @param thesimnodes a {@link java.util.Set} object * @param thedata a {@link edu.cmu.tetrad.data.DataSet} object */ public HsimContinuous(Dag thedag, Set thesimnodes, DataSet thedata) { if (thedata.isDiscrete()) { throw new IllegalArgumentException( "HsimContinuous only accepts continuous data."); } if (thedag == null) { throw new IllegalArgumentException( "Hsim needs a Dag."); } if (thesimnodes == null) { throw new IllegalArgumentException( "Please specify the nodes Hsim will resimulate."); } // (Optional: Eventually want options for search methods for picking out the DAG) setVerbose(); setDag(thedag); setData(thedata); setSimnodes(thesimnodes); } //**************Public methods***********************// // Calculates the Markov blanket of a node in a graph. private static Set mb(Graph graph, Node z) { Set mb = new HashSet<>(graph.getAdjacentNodes(z)); for (Node c : graph.getChildren(z)) { for (Node p : graph.getParents(c)) { //make sure you don't add z itslef to the markov blanket if (p != z) { mb.add(p); } } } return mb; } //========================================PRIVATE METHODS====================================// /** *

hybridsimulate.

* * @return a {@link edu.cmu.tetrad.data.DataSet} object */ public DataSet hybridsimulate() { // Find Markov Blankets for resimulated variables**/ // this needs to be made general, rather than only for two specific names nodes**/ if (this.verbose) System.out.println("Finding a Markov blanket for resimulated nodes"); Set mbAll = new HashSet<>(); //initialize an empty set of nodes; Set mbAdd = new HashSet<>(); //init set for adding for (Node node : this.simnodes) { mbAdd = HsimContinuous.mb(this.mydag, node); //find mb for that node mbAll.addAll(mbAdd); //use .addAll to add this mb to the set } //make sure all the simnodes are in mbAll! a disconnected node could cause errors later otherwise mbAll.addAll(this.simnodes); if (this.verbose) System.out.println("The Markov Blanket is " + mbAll); // Find the subgraph for the resimulated variables and their markov blanket**/ if (this.verbose) System.out.println("Finding a subgraph over the Markov Blanket and Resimulated Nodes"); //need a List as input for subgraph method, but mbAll is a Set List mbListAll = new ArrayList<>(mbAll); Graph subgraph = this.mydag.subgraph(mbListAll); // Learn an instantiated model over the subgraph**/ if (this.verbose) System.out.println("Learning an instantiated model for the subgraph"); //Do this step continuous instead of discrete: //learn a dirichlet IM for the subgraph using dataSet SemPm subgraphPM = new SemPm(subgraph); SemEstimator subgraphEstimator = new SemEstimator(this.data, subgraphPM); SemIm subgraphIM = subgraphEstimator.estimate(); //if (verbose) System.out.println(fittedsubgraphIM.getVariable()); // Use the learned instantiated subgraph model to create the resimulated data**/ if (this.verbose) System.out.println("Starting resimulation loop"); //Use the BayesIM to learn the conditional marginal distribution of X given mbAll //first construct the updater, using RowSummingExactUpdater(BayesIm bayesIm, Evidence evidence) //To use that, need to make an Evidence, which in this case is the values of the MB //will need to make a new Evidence, and perform the updater, for every row of data. //In order to make an Evidence, need to make a Proposition //then modify the proposition to fix the values of all conditioned variables using: //setCategory(int variable, int category) //since setCategory only takes the int values of variable and category, need to figure out what those are //can figure out those from the proposition's variable source using: //public int getNodeIndex(String name) //and //public int getCategoryIndex(String nodeName, String category) //want the causal ordering of the subgraph: //List subgraphOrdering = GraphUtils.getCausalOrdering(subgraph); //loop through each row of the data set, conditioning and drawing values each time. for (int row = 0; row < this.data.getNumRows(); row++) { //create a new evidence object SemEvidence evidence = new SemEvidence(subgraphIM); //need to define the set of variables being conditioned upon. Start with the outer set of MB //need to remove the whole set of starters, not just some X and Y... how do? loop a .remove? for (Node node : this.simnodes) { mbAll.remove(node); } //loop through all the nodes being conditioned upon, and set their values in the evidence prop for (Node i : mbAll) { //int nodeIndex = evidence.getNodeIndex(i.getName()); int nodeColumn = this.data.getColumn(i); evidence.getProposition().setValue(i, this.data.getDouble(row, nodeColumn)); } //use the new Evidence object to create the updater SemUpdater conditionUpdate = new SemUpdater(subgraphIM); conditionUpdate.setEvidence(evidence); SemIm updatedIM = conditionUpdate.getUpdatedSemIm(); //draw values for the node we're resimming DataSet newValues = updatedIM.simulateData(1, false); //DataSet newValues = updatedIM.simulateDataRecursive(1,false); //take these new simnodes values and replace the old values in the data set with them for (Node node : this.simnodes) { //if (verbose) System.out.println(data.getInt(row,data.getColumn(nodeX)) + " old vs new " + newXvalue); this.data.setDouble(row, this.data.getColumn(node), newValues.getDouble(0, newValues.getColumn(node))); //if (verbose) System.out.println(" and again?: " + data.getInt(row,data.getColumn(nodeX)) + " old vs new " + newXvalue); } } return this.data; } /// *********Private methods for setting private variables***********// private void setVerbose() { this.verbose = false; } private void setDag(Dag thedag) { this.mydag = thedag; } private void setSimnodes(Set thenodes) { this.simnodes = thenodes; } private void setData(DataSet thedata) { this.data = thedata; } }




© 2015 - 2024 Weber Informatics LLC | Privacy Policy