// org.integratedmodelling.engine.modelling.bayes.gn.GenieBayesianNetwork Maven / Gradle / Ivy
// The newest version!
/*******************************************************************************
* Copyright (C) 2007, 2015:
*
* - Ferdinando Villa
* - integratedmodelling.org
* - any other authors listed in @author annotations
*
* All rights reserved. This file is part of the k.LAB software suite,
* meant to enable modular, collaborative, integrated
* development of interoperable data and model components. For
* details, see http://integratedmodelling.org.
*
* This program is free software; you can redistribute it and/or
* modify it under the terms of the Affero General Public License
* Version 3 or any later version.
*
* This program is distributed in the hope that it will be useful,
* but without any warranty; without even the implied warranty of
* merchantability or fitness for a particular purpose. See the
* Affero General Public License for more details.
*
* You should have received a copy of the Affero General Public License
* along with this program; if not, write to the Free Software
* Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA.
* The license is also available at: https://www.gnu.org/licenses/agpl.html
*******************************************************************************/
package org.integratedmodelling.engine.modelling.bayes.gn;
import java.io.File;
import org.integratedmodelling.api.knowledge.IConcept;
import org.integratedmodelling.engine.modelling.bayes.IBayesianInference;
import org.integratedmodelling.engine.modelling.bayes.IBayesianNetwork;
import org.integratedmodelling.exceptions.KlabException;
import org.integratedmodelling.exceptions.KlabIOException;
import org.integratedmodelling.exceptions.KlabValidationException;
import smile.Network;
import smile.SMILEException;
import smile.learning.DataMatch;
import smile.learning.DataSet;
import smile.learning.EM;
/**
 * {@link IBayesianNetwork} implementation backed by a SMILE/GeNIe {@link Network}.
 * <p>
 * One network is kept as a "prototype": it answers all structural queries (node ids,
 * outcomes, parents, children, name) and is handed to the first inference object
 * created by {@link #getInference()}. When the prototype was read from a file, later
 * inference objects receive their own {@link Network} re-read from that file, so they
 * do not share the same instance.
 */
public class GenieBayesianNetwork implements IBayesianNetwork {

    /*
     * we keep one network as "prototype" and we use it for the first
     * inference object.
     */
    Network prototype;
    IConcept observable;
    /* set once the prototype has been handed out to an inference object. */
    boolean used = false;
    /*
     * path of the model file the network was read from, or the path supplied by the
     * caller of the Network-based constructor (which may be null).
     */
    String input = null;
    /*
     * true only when re-reading {@link #input} yields a network equivalent to the
     * prototype. Networks passed in directly (e.g. the result of train()) may carry
     * parameters that differ from what the file contains, so they are not reloadable
     * and every inference keeps using the prototype for them.
     */
    private final boolean reloadable;

    /**
     * Reads a network from a GeNIe/SMILE model file.
     *
     * @param in the model file; its path is remembered so that further copies of the
     *           network can be read for later inferences and for training
     * @throws KlabIOException if the file cannot be read by SMILE
     */
    public GenieBayesianNetwork(File in) throws KlabIOException {
        this.reloadable = true;
        this.prototype = new Network();
        try {
            this.input = in.toString();
            this.prototype.readFile(this.input);
        } catch (Exception e) {
            throw readError(in, e);
        }
    }

    /**
     * Wraps an already-built network.
     *
     * @param network    the network used as prototype
     * @param observable the observable associated with this network
     * @param input      path of the model file this network derives from; it is used by
     *                   {@link #train(File, String)} to read a fresh network to train
     */
    public GenieBayesianNetwork(Network network, IConcept observable, String input) {
        this.prototype = network;
        this.observable = observable;
        this.input = input;
        this.reloadable = false;
    }

    /**
     * Creates an inference object.
     * <p>
     * The first call uses the prototype network. Later calls read a fresh copy from
     * {@link #input} when the prototype came from that file. Otherwise, for networks
     * passed to the Network-based constructor, they keep using the prototype.
     * NOTE(review): {@code used} is not synchronized; concurrent callers could both
     * receive the prototype.
     *
     * @return a new inference object
     */
    @Override
    public IBayesianInference getInference() {
        if (!used || !reloadable) {
            used = true;
            return new GenieBayesianInference(prototype);
        }
        Network net = new Network();
        /*
         * one has been successfully read already, don't capture exceptions.
         */
        net.readFile(this.input);
        return new GenieBayesianInference(net);
    }

    /** @return the number of nodes in the prototype network. */
    @Override
    public int getNodeCount() {
        return prototype.getNodeCount();
    }

    /** @return the ids of all nodes in the prototype network. */
    @Override
    public String[] getAllNodeIds() {
        return prototype.getAllNodeIds();
    }

    /** @return the number of outcomes of the given node in the prototype network. */
    @Override
    public int getOutcomeCount(String nodeId) {
        return prototype.getOutcomeCount(nodeId);
    }

    /** @return the id of the outcome at {@code outcomeIndex} for the given node. */
    @Override
    public String getOutcomeId(String nodeId, int outcomeIndex) {
        return prototype.getOutcomeId(nodeId, outcomeIndex);
    }

    /** @return the ids of the parents of the given node in the prototype network. */
    @Override
    public String[] getParentIds(String nodeId) {
        return prototype.getParentIds(nodeId);
    }

    /** @return the ids of the children of the given node in the prototype network. */
    @Override
    public String[] getChildIds(String nodeId) {
        return prototype.getChildIds(nodeId);
    }

    /** @return the ids of all outcomes of the given node in the prototype network. */
    @Override
    public String[] getOutcomeIds(String nodeId) {
        return prototype.getOutcomeIds(nodeId);
    }

    /** @return the name of the prototype network. */
    @Override
    public String getName() {
        return prototype.getName();
    }

    /**
     * Learns network parameters from a file of observations.
     * <p>
     * A fresh network is read from {@link #input} and trained; the prototype is left
     * untouched. Only {@code "EM"} is currently supported. Any other method (or
     * {@code null}) returns an untrained copy of the network read from {@link #input}.
     * NOTE(review): training always starts from the input file, so training a network
     * that was itself returned by this method starts again from the original parameters.
     *
     * @param observations data file readable by SMILE's {@link DataSet}
     * @param method       learning method name
     * @return a new network wrapping the trained SMILE network
     * @throws KlabIOException         if the network or the observations cannot be read
     * @throws KlabValidationException if matching the data to the network or learning fails
     */
    @Override
    public IBayesianNetwork train(File observations, String method) throws KlabException {
        Network network = new Network();
        try {
            network.readFile(this.input);
        } catch (Exception e) {
            throw readError(input, e);
        }

        DataSet dset = new DataSet();
        try {
            // NOTE(review): "*" is presumably the missing-value token - confirm with SMILE docs
            dset.readFile(observations.toString(), "*");
        } catch (SMILEException e) {
            throw new KlabIOException(e);
        }

        try {
            dset.matchNetwork(network);
            int nvars = dset.getVariableCount();
            DataMatch[] dm = new DataMatch[nvars];
            for (int i = 0; i < nvars; i++) {
                int node = network.getNode(dset.getVariableId(i));
                // TODO check this slice parameter: presumably the time slice for dynamic
                // networks, 0 for static ones - confirm against SMILE docs
                dm[i] = new DataMatch(i, node, 0);
            }
            if ("EM".equals(method)) {
                EM em = new EM();
                em.learn(dset, network, dm);
            } // TODO remaining methods
        } catch (SMILEException e) {
            throw new KlabValidationException(e);
        }
        return new GenieBayesianNetwork(network, observable, input);
    }

    /**
     * Writes the prototype network to the given file through SMILE.
     *
     * @param modelFile destination file
     * @throws KlabIOException if SMILE fails to write the file
     */
    @Override
    public void write(File modelFile) throws KlabIOException {
        try {
            this.prototype.writeFile(modelFile.toString());
        } catch (SMILEException e) {
            throw new KlabIOException(e);
        }
    }

    /**
     * Tells whether a node has no parents.
     * NOTE(review): this checks for absence of parents (a root in graph terms), not of
     * children; confirm this is the meaning of "leaf" intended by IBayesianNetwork.
     *
     * @param nodeId node to check
     * @return true if the node has no parents
     */
    @Override
    public boolean isLeaf(String nodeId) {
        String[] parents = getParentIds(nodeId);
        return parents == null || parents.length == 0;
    }

    /*
     * builds the exception reported when a network file cannot be read, keeping the
     * original failure as cause so its stack trace is not lost.
     */
    private static KlabIOException readError(Object source, Exception cause) {
        KlabIOException ex =
                new KlabIOException("GENIE import: reading " + source + ": " + cause.getMessage());
        try {
            ex.initCause(cause);
        } catch (IllegalStateException ignored) {
            // the KlabIOException constructor already fixed a cause; keep the message only
        }
        return ex;
    }
}