edu.cmu.tetrad.search.utils.ShiftSearch Maven / Gradle / Ivy
The newest version!
///////////////////////////////////////////////////////////////////////////////
// For information as to what this class does, see the Javadoc, below. //
// Copyright (C) 1998, 1999, 2000, 2001, 2002, 2003, 2004, 2005, 2006, //
// 2007, 2008, 2009, 2010, 2014, 2015, 2022 by Peter Spirtes, Richard //
// Scheines, Joseph Ramsey, and Clark Glymour. //
// //
// This program is free software; you can redistribute it and/or modify //
// it under the terms of the GNU General Public License as published by //
// the Free Software Foundation; either version 2 of the License, or //
// (at your option) any later version. //
// //
// This program is distributed in the hope that it will be useful, //
// but WITHOUT ANY WARRANTY; without even the implied warranty of //
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the //
// GNU General Public License for more details. //
// //
// You should have received a copy of the GNU General Public License //
// along with this program; if not, write to the Free Software //
// Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA //
///////////////////////////////////////////////////////////////////////////////
package edu.cmu.tetrad.search.utils;
import edu.cmu.tetrad.data.*;
import edu.cmu.tetrad.graph.Node;
import edu.cmu.tetrad.search.Fges;
import edu.cmu.tetrad.search.score.ImagesScore;
import edu.cmu.tetrad.search.score.Score;
import edu.cmu.tetrad.search.score.SemBicScore;
import edu.cmu.tetrad.util.Matrix;
import edu.cmu.tetrad.util.SublistGenerator;
import org.apache.commons.math3.util.FastMath;
import java.io.OutputStream;
import java.io.PrintStream;
import java.util.ArrayList;
import java.util.List;
/**
 * Tries to find a good shifting of variables to minimize average BIC for time-series data. The idea is that the data
 * one is presented with may have the variables temporally shifted with respect to one another. ShiftSearch attempts to
 * find a shifting of the variables that reduces this temporal shifting.
 * <p>
 * Candidate shift assignments are scored by running FGES with an {@link ImagesScore} built from one
 * {@link SemBicScore} per data set and averaging the negated model score over the data sets.
 *
 * @author josephramsey
 * @version $Id: $Id
 */
public class ShiftSearch {
    /** The data sets to be shifted; each element is cast to {@link DataSet} when it is used. */
    private final List<DataModel> dataSets;
    /** Largest shift magnitude tried for any single variable; must be at least 1 when searching. */
    private int maxShift = 2;
    /** Knowledge handed to FGES for every scoring run. */
    private Knowledge knowledge = new Knowledge();
    /** Stored via {@link #setC(int)}; not read by the search code in this class. */
    private int c = 4;
    /** Second argument passed to the {@link SublistGenerator} that enumerates which variables get shifted. */
    private int maxNumShifts;
    /** Additional stream that progress lines are echoed to; defaults to standard output. */
    private transient PrintStream out = System.out;
    /** Set by {@link #stop()}; volatile because stop() is expected to be called from a different thread. */
    private volatile boolean scheduleStop;
    /** When false (the default) the tried shifts are negated. */
    private boolean forwardSearch;
    /** Forwarded to the {@link SemBicScore} constructor. */
    private boolean precomputeCovariances = false;

    /**
     * Constructor for ShiftSearch.
     *
     * @param dataSets the data sets to search over; the list is kept by reference, not copied
     * @throws NullPointerException if {@code dataSets} is null
     */
    public ShiftSearch(List<DataModel> dataSets) {
        if (dataSets == null) {
            throw new NullPointerException("dataSets must not be null");
        }

        this.dataSets = dataSets;
    }

    /**
     * Runs the search. Starting from the unshifted data, every variable subset produced by the sublist generator is
     * tried with every combination of shift magnitudes in {@code 1..maxShift} (negated unless forward search is
     * enabled). The variable named "I", if present, is never shifted. A candidate replaces the current best only when
     * its average BIC is below 0.999 times the best value so far. Progress lines are printed as improvements are
     * found.
     *
     * @return the best shift found for each variable, indexed like the variables of the first data set; all zeros if
     * no candidate improved on the unshifted data
     * @throws IllegalStateException if the maximum shift is less than 1
     */
    public int[] search() {
        if (this.maxShift < 1) {
            throw new IllegalStateException("Max shift should be >= 1: " + this.maxShift);
        }

        DataModel first = this.dataSets.get(0);
        DataSet firstDataSet = (DataSet) first;

        int numVars = firstDataSet.getNumColumns();
        List<Node> nodes = first.getVariables();
        int[] bestshifts = new int[numVars];

        // Shifted data sets are truncated to this many rows before scoring.
        int maxNumRows = firstDataSet.getNumRows() - this.maxShift;

        double b = getAvgBic(this.dataSets);
        printShifts(bestshifts, b, nodes);

        // Position of the variable named "I" (-1 if absent); it does not depend on the current choice, so it is
        // looked up once rather than once per choice.
        int iIndex = nodes.indexOf(first.getVariable("I"));

        SublistGenerator generator = new SublistGenerator(nodes.size(), getMaxNumShifts());
        int[] choice;

        // Checking the stop flag here as well ends the enumeration promptly instead of walking every remaining choice.
        while (!this.scheduleStop && (choice = generator.next()) != null) {
            int[] shifts = new int[nodes.size()];
            double zSize = FastMath.pow(getMaxShift(), choice.length);

            for (int z = 0; z < zSize; z++) {
                if (this.scheduleStop) {
                    break;
                }

                // Decode z as a base-maxShift number, one digit per chosen variable.
                int _z = z;

                for (int j : choice) {
                    if (j == iIndex) {
                        continue;
                    }

                    shifts[j] = (_z % getMaxShift()) + 1;

                    if (!this.forwardSearch) {
                        shifts[j] = -shifts[j];
                    }

                    _z /= getMaxShift();
                }

                List<DataModel> _shiftedDataSets = getShiftedDataSets(shifts, maxNumRows);
                double _b = getAvgBic(_shiftedDataSets);

                if (_b < 0.999 * b) {
                    b = _b;
                    printShifts(shifts, b, nodes);
                    System.arraycopy(shifts, 0, bestshifts, 0, shifts.length);
                }
            }
        }

        println("\nShifts with the lowest BIC score: ");
        printShifts(bestshifts, b, nodes);
        return bestshifts;
    }

    /**
     * Getter for the field <code>maxShift</code>.
     *
     * @return the largest shift magnitude tried for a variable
     */
    public int getMaxShift() {
        return this.maxShift;
    }

    /**
     * Setter for the field <code>maxShift</code>.
     *
     * @param maxShift the largest shift magnitude to try; {@link #search()} rejects values below 1
     */
    public void setMaxShift(int maxShift) {
        this.maxShift = maxShift;
    }

    /**
     * Getter for the field <code>knowledge</code>.
     *
     * @return a {@link edu.cmu.tetrad.data.Knowledge} object
     */
    public Knowledge getKnowledge() {
        return this.knowledge;
    }

    /**
     * Setter for the field <code>knowledge</code>. A copy of the argument is stored.
     *
     * @param knowledge a {@link edu.cmu.tetrad.data.Knowledge} object
     */
    public void setKnowledge(Knowledge knowledge) {
        this.knowledge = new Knowledge(knowledge);
    }

    /**
     * Getter for the field <code>c</code>.
     *
     * @return the stored value of c
     */
    public int getC() {
        return this.c;
    }

    /**
     * Setter for the field <code>c</code>.
     *
     * @param c the value to store
     */
    public void setC(int c) {
        this.c = c;
    }

    /**
     * Getter for the field <code>maxNumShifts</code>.
     *
     * @return the value passed to the sublist generator that picks the variables to shift
     */
    public int getMaxNumShifts() {
        return this.maxNumShifts;
    }

    /**
     * Setter for the field <code>maxNumShifts</code>.
     *
     * @param maxNumShifts the value passed to the sublist generator that picks the variables to shift
     */
    public void setMaxNumShifts(int maxNumShifts) {
        this.maxNumShifts = maxNumShifts;
    }

    /**
     * Setter for the field <code>out</code>. Progress lines are always written to standard output and are
     * additionally echoed to this stream when it is a different stream.
     *
     * @param out a {@link java.io.OutputStream} object
     */
    public void setOut(OutputStream out) {
        // Reuse an existing PrintStream instead of wrapping it again, so that passing System.out is recognised
        // as the console by println and lines are not printed twice.
        this.out = (out instanceof PrintStream) ? (PrintStream) out : new PrintStream(out);
    }

    /**
     * Requests that a running {@link #search()} stop; the search checks the flag before each candidate.
     * The flag is not reset by this class.
     */
    public void stop() {
        this.scheduleStop = true;
    }

    /**
     * Setter for the field <code>forwardSearch</code>.
     *
     * @param forwardSearch true to try positive shifts, false to try their negations
     */
    public void setForwardSearch(boolean forwardSearch) {
        this.forwardSearch = forwardSearch;
    }

    /**
     * Setter for the field <code>precomputeCovariances</code>.
     *
     * @param precomputeCovariances the flag forwarded to each {@link SemBicScore} that is constructed
     */
    public void setPrecomputeCovariances(boolean precomputeCovariances) {
        this.precomputeCovariances = precomputeCovariances;
    }

    /** Prints one line of the form {@code node=shift ... score}. */
    private void printShifts(int[] shifts, double b, List<Node> nodes) {
        StringBuilder buf = new StringBuilder();

        for (int i = 0; i < shifts.length; i++) {
            buf.append(nodes.get(i)).append("=").append(shifts[i]).append(" ");
        }

        buf.append(b);
        println(buf.toString());
    }

    /** Writes a line to standard output and, if configured, to the additional output stream. */
    private void println(String s) {
        System.out.println(s);

        // The additional stream defaults to System.out; skip the echo in that case so each line appears once.
        if (this.out != null && this.out != System.out) {
            this.out.println(s);
            this.out.flush();
        }
    }

    /** Applies the given shifts to every data set and truncates each result to {@code maxNumRows} rows. */
    private List<DataModel> getShiftedDataSets(int[] shifts, int maxNumRows) {
        List<DataModel> shiftedDataSets2 = new ArrayList<>();

        for (DataModel dataSet : this.dataSets) {
            DataSet shiftedData = TsUtils.createShiftedData((DataSet) dataSet, shifts);
            shiftedDataSets2.add(shiftedData);
        }

        return ensureNumRows(shiftedDataSets2, maxNumRows);
    }

    /** Returns new data sets holding rows {@code 0..numRows-1} of each input, keeping all columns and variables. */
    private List<DataModel> ensureNumRows(List<DataModel> dataSets, int numRows) {
        List<DataModel> truncatedData = new ArrayList<>();

        for (DataModel _dataSet : dataSets) {
            DataSet dataSet = (DataSet) _dataSet;
            Matrix mat = dataSet.getDoubleData();
            Matrix mat2 = mat.getPart(0, numRows - 1, 0, mat.getNumColumns() - 1);
            truncatedData.add(new BoxDataSet(new DoubleDataBox(mat2.toArray()), dataSet.getVariables()));
        }

        return truncatedData;
    }

    /**
     * Runs FGES over an {@link ImagesScore} built from one {@link SemBicScore} per data set and returns the negated
     * model score divided by the number of data sets.
     */
    private double getAvgBic(List<DataModel> dataSets) {
        List<Score> scores = new ArrayList<>();

        for (DataModel dataSet : dataSets) {
            SemBicScore _score = new SemBicScore((DataSet) dataSet, precomputeCovariances);
            scores.add(_score);
        }

        ImagesScore imagesScore = new ImagesScore(scores);
        Fges images = new Fges(imagesScore);
        images.setKnowledge(this.knowledge);
        images.search();
        return -images.getModelScore() / dataSets.size();
    }
}