// com.meliorbis.economics.model.ModelRunner (Maven / Gradle / Ivy)
// Show all versions of ModelSolver; show documentation.
package com.meliorbis.economics.model;
import java.io.File;
import java.io.IOException;
import java.io.PrintStream;
import java.util.logging.Level;
import java.util.logging.Logger;
import org.apache.commons.cli.BasicParser;
import org.apache.commons.cli.CommandLine;
import org.apache.commons.cli.HelpFormatter;
import org.apache.commons.cli.Options;
import org.apache.commons.cli.ParseException;
import org.apache.commons.lang.ArrayUtils;
import org.picocontainer.MutablePicoContainer;
import org.picocontainer.PicoBuilder;
import org.picocontainer.injectors.MultiInjection;
import com.meliorbis.economics.aggregate.AggregateProblemSolver;
import com.meliorbis.economics.individual.IndividualProblemSolver;
import com.meliorbis.economics.infrastructure.AbstractModel;
import com.meliorbis.economics.infrastructure.Base;
import com.meliorbis.economics.infrastructure.Solver;
import com.meliorbis.economics.infrastructure.SolverException;
import com.meliorbis.economics.infrastructure.simulation.AggregateSimulationObserver;
import com.meliorbis.economics.infrastructure.simulation.DensitySavingSimulationObserver;
import com.meliorbis.economics.infrastructure.simulation.DiscretisedDistribution;
import com.meliorbis.economics.infrastructure.simulation.DiscretisedDistributionSimulatorImpl;
import com.meliorbis.economics.infrastructure.simulation.PeriodAggregateState;
import com.meliorbis.economics.infrastructure.simulation.RASimulator;
import com.meliorbis.economics.infrastructure.simulation.RepresentativeAgentSimState;
import com.meliorbis.economics.infrastructure.simulation.SimState;
import com.meliorbis.economics.infrastructure.simulation.SimulationObserver;
import com.meliorbis.economics.infrastructure.simulation.SimulationResults;
import com.meliorbis.economics.infrastructure.simulation.Simulator;
import com.meliorbis.economics.infrastructure.simulation.SimulatorException;
import com.meliorbis.numerics.fixedpoint.FixedPointValueDelegate;
import com.meliorbis.numerics.fixedpoint.MultivariateBoundedFPFinder;
import com.meliorbis.numerics.function.MultiVariateVectorFunction;
import com.meliorbis.numerics.generic.impl.GenericBlockedArray;
import com.meliorbis.numerics.generic.primitives.DoubleArray;
import com.meliorbis.numerics.generic.primitives.impl.DoubleArrayFunctions;
import com.meliorbis.numerics.io.NumericsReader;
import com.meliorbis.numerics.io.NumericsWriter;
import com.meliorbis.numerics.io.NumericsWriterFactory;
import com.meliorbis.numerics.io.csv.CSVWriterFactory;
import com.meliorbis.numerics.io.matlab.MatlabWriterFactory;
import com.meliorbis.utils.Pair;
import com.meliorbis.utils.Timer;
import com.meliorbis.utils.Timer.Stoppable;
import com.meliorbis.utils.Utils;
/**
 * The abstract base class for objects that control configuration, solution and simulation of models.
 *
 * <p>Subclasses supply the model class ({@link #getModelClass()}) and the configuration
 * ({@link #createConfig(CommandLine)}). This class parses the command line, registers the solver,
 * simulators and writer factory in a PicoContainer, and then solves and/or simulates the model as
 * requested by the command-line options.
 *
 * @author Tobias Grasl
 *
 * @param <M> The class of Model this runner is for
 * @param <C> The configuration class this runner is for
 * @param <S> The state class this runner is for
 */
public abstract class ModelRunner, C extends ModelConfig, S extends State> extends Base
{
    private static final String PERIODS_OPTION = "periods";
    private static final String SOLUTION_DIR_OPTION = "stateDir";
    private static final String SIM_DIR_OPTION = "simDir";
    private static final String CONTINUE_OPTION = "initialState";
    private static final String SHOCK_DIR_OPTION = "useShocks";
    private static final String SIMULATE_OPTION = "simulate";
    private static final String SOLVE_OPTION = "solve";
    private static final String ADD_CLASS = "addClass";
    private static final String AGG_SIM_OPTION = "aggOnly";
    private static final String SAVE_DENSITIES = "saveSimDensities";
    private static final String WRITE_MATLAB = "writeMatlab";

    /** Number of periods simulated when the periods option is not given. */
    private static final String DEFAULT_PERIODS = "10000";

    /** Burn-in period count passed to the simulator for command-line simulations. */
    private static final int DEFAULT_BURN_IN = 1000;

    private static final Logger LOG = Logger.getLogger(ModelRunner.class.getName());

    private Solver _solver;
    final private MutablePicoContainer _container;
    private DiscretisedDistributionSimulatorImpl _simulator;

    /** The single model instance, created lazily by instantiateModel and reused afterwards. */
    private M _modelInstance = null;

    /** Directory from which state is read to continue solving or to simulate; may be null. */
    private File _continueDir;

    /** Directory to which solved state is written. */
    private File _solutionDir;

    /** Whether distributions at each simulation step should be saved. */
    private boolean _saveDensities;

    /**
     * Creates the runner and registers the numerics, solver, simulators and aggregate observer in a
     * caching container.
     */
    public ModelRunner()
    {
        _container = new PicoBuilder().withCaching().withComponentFactory(new MultiInjection()).build();
        _container.addComponent(getNumerics());
        _container.addComponent(Solver.class);
        _container.addComponent(DiscretisedDistributionSimulatorImpl.class);
        _container.addComponent(RASimulator.class);
        _container.addComponent(AggregateSimulationObserver.class);
    }

    /**
     * Sets the path from which initial state will be read
     *
     * @param path_ The path from which to read starting state
     */
    protected void setInitialStatePath(String path_)
    {
        _continueDir = new File(path_);
    }

    /**
     * Solves the model and writes the resulting state to disk.
     *
     * <p>Without aggregate and without individual uncertainty, the model is solved directly and the
     * fixed point of the aggregate transition is printed. Without aggregate uncertainty but with
     * individual uncertainty, a steady state is searched for, and the resulting ergodic distribution
     * is written to an {@code ergodicDist} subdirectory of the solution directory. With aggregate
     * uncertainty, the model is solved directly. Errors during solution are logged, not propagated.
     *
     * @param modelClass_ The class of model to solve
     * @param config_ The configuration to solve the model under
     *
     * @return The directory the solved state was written to, or {@code null} if no state was produced
     */
    public String solveModel(Class extends M> modelClass_, C config_)
    {
        S state = null;
        String solvedDir = null;
        DiscretisedDistribution simState = null;
        M model = null;
        try
        {
            model = instantiateModel(modelClass_, config_);
            if (!config_.hasAggUncertainty())
            {
                if (!config_.hasIndUncertainty())
                {
                    state = determineSolution(model, false);
                    System.out.println("Found consistent solution");
                    DoubleArray> aggregateTransition = state.getAggregateTransition();

                    // Index vector with one entry per dimension except the last: -1 for each aggregate
                    // endogenous state dimension, 0 elsewhere. NOTE(review): -1 presumably selects a
                    // whole dimension in DoubleArray.at - confirm
                    int[] at = Utils.repeatArray(0, aggregateTransition.numberOfDimensions() - 1);
                    for (int i = 0; i < config_.getAggregateEndogenousStateCount(); i++)
                    {
                        at[i] = -1;
                    }
                    DoubleArray> reducedTrans = aggregateTransition.at(at);

                    // toArray() without an argument returns an Object[], which cannot be cast to a
                    // DoubleArray array; pass a typed array so the runtime type is correct
                    DoubleArray> fp = DoubleArrayFunctions.interpolateFixedPoint(reducedTrans,
                            config_.getAggregateEndogenousStates().toArray(new DoubleArray[0]));
                    System.out.println("Fixed Point: " + fp);
                }
                else
                {
                    System.out.println("Model has no aggregate risk. Finding Steady State");
                    Pair results = findNASteadyState(model, modelClass_, config_);
                    state = results.getLeft();
                    simState = results.getRight();
                }
            } else
            {
                state = determineSolution(model, false);
            }
        } catch (Throwable e)
        {
            LOG.log(Level.SEVERE, "Error solving model", e);
        } finally
        {
            // Write whatever state was produced, even if a later step failed
            if (state != null)
            {
                if (_solutionDir != null)
                {
                    _solutionDir.mkdirs();
                    _solver.writeState(model, state, _solutionDir);
                    solvedDir = _solutionDir.getAbsolutePath();
                } else
                {
                    solvedDir = _solver.writeState(model, state);
                }
            }

            if (simState != null && solvedDir != null)
            {
                // Create a writer to write to the solution directory
                final NumericsWriterFactory writerFactory = _container.getComponent(NumericsWriterFactory.class);
                final NumericsWriter writer = writerFactory.create(new File(solvedDir, "ergodicDist"));
                try
                {
                    simState.write(writer);
                } catch (Exception e)
                {
                    LOG.log(Level.SEVERE, "Error writing ergodic distribution", e);
                } finally
                {
                    try
                    {
                        writer.close();
                    } catch (IOException e)
                    {
                        LOG.log(Level.WARNING, "Error closing ergodic distribution writer", e);
                    }
                }
            }
        }
        return solvedDir;
    }

    /**
     * Finds a steady state of a model without aggregate risk.
     *
     * <p>The fixed-point finder proposes input vectors. For each one, the inputs are passed to the
     * model's fixed-point delegate, the model is solved, and the state is written to the solution
     * directory. When there is individual risk, the ergodic distribution is also computed and
     * written to {@code lastSS}. The delegate's outputs for the result are returned to the finder.
     * If the delegate has no inputs, a single evaluation is performed.
     *
     * <p>Assumes {@code _solutionDir} has been set (runConfig always sets it).
     *
     * @param model_ The model providing the fixed-point delegate and writer factory
     * @param modelClass_ The model class, used to obtain the model instance for each evaluation
     * @param config_ The configuration the model is solved under
     *
     * @return The calculation state and distribution of the last evaluation
     */
    private Pair findNASteadyState(M model_, final Class extends M> modelClass_, final C config_)
            throws SecurityException
    {
        Timer timer = new Timer();
        Stoppable stoppable = timer.start("solveNA");

        final AggregateFixedPointState state = new AggregateFixedPointState(null, null, null);
        final FixedPointValueDelegate>> delegate = model_.getFixedPointDelegate();
        final boolean indRisk = config_.getIndividualExogenousStates().get(0).numberOfElements() > 1;

        System.out.println("Finding Solution with consistent aggregates\n");

        final File stateFile = _solutionDir;
        stateFile.mkdirs();

        setWriterFactory(model_.getWriterFactory());

        // Count of completed evaluations; an array so the anonymous class can update it
        final int[] calcCount = new int[] {0};

        final MultiVariateVectorFunction fn = new MultiVariateVectorFunction()
        {
            @Override
            public Double[] call(Double... args_)
            {
                try
                {
                    /*
                     * Solve the model under the current calibration
                     */
                    // Set the inputs on the delegate (which will pass them to the config)
                    delegate.setInputs(ArrayUtils.toPrimitive(args_));

                    // instantiateModel returns the cached instance after the first call
                    M model = instantiateModel(modelClass_, config_);

                    if (calcCount[0] > 0)
                    {
                        // Re-initialise for new config
                        model.initialise();
                    }

                    S calcState = model.initialState();

                    // Calculate a solution for the current config
                    determineSolution(model, calcState, false);
                    _solver.writeState(model, calcState, stateFile);

                    /*
                     * Find the steady-state distribution of the model under the current calibration,
                     * starting from the configured initial distribution.
                     */
                    DiscretisedDistribution simState = (DiscretisedDistribution) config_.getInitialSimState();

                    // NOTE: If there are no individual shocks the initial state should be the steady state
                    if (indRisk)
                    {
                        simState = _simulator.findErgodicDist(model, calcState, simState);
                        try
                        {
                            final NumericsWriter ssWriter = getNumericsWriter(new File(stateFile, "lastSS"));
                            try
                            {
                                simState.write(ssWriter);
                            } finally
                            {
                                // Previously the writer was never closed
                                ssWriter.close();
                            }
                        } catch (IOException e)
                        {
                            LOG.log(Level.WARNING, "Unable to write sim state", e);
                        }
                    }

                    state.setSimState(simState);
                    state.setCalcState(calcState);
                    state.setModel(model);

                    calcCount[0]++;

                    return ArrayUtils.toObject(delegate.getOutputs(state));
                } catch (SolverException e)
                {
                    throw new RuntimeException(e);
                }
            }
        };

        double[] initialInputs = delegate.getInitialInputs();

        // If there are no inputs, there are no outputs, and no fixed point needs to be found -
        // i.e. the steady state is found in one pass
        if (initialInputs.length == 0)
        {
            fn.call(new Double[0]);
        } else
        {
            // NOTE(review): the two constructor arguments are presumably convergence tolerances - confirm
            MultivariateBoundedFPFinder fpFinder = new MultivariateBoundedFPFinder(1e-6, 1e-6);
            double[][] bounds = delegate.getBounds();
            fpFinder.setBounds(bounds);
            fpFinder.findFixedPoint(fn, initialInputs);

            // Notify the delegate that the solution was found
            delegate.solutionFound(state);

            _solver.writeState(state.getModel(), state.getCalcState(), stateFile);
        }

        stoppable.stop();

        return new Pair((S) state.getCalcState(), state.getSimState());
    }

    /**
     * Solves the model, starting from the state in the continue directory if one is set, or from the
     * model's initial state otherwise. Logs the error and terminates the JVM if solving fails.
     *
     * @param modelInstance_ The model to solve
     * @param writeState_ Passed through to the solver
     *
     * @return The state passed to the solver; the callers treat it as the solution afterwards
     */
    private S determineSolution(M modelInstance_, boolean writeState_)
    {
        S state = null;
        try
        {
            if (_continueDir != null)
            {
                state = _solver.readState(modelInstance_, _continueDir);
            } else
            {
                state = modelInstance_.initialState();
            }

            determineSolution(modelInstance_, state, writeState_);
        } catch (Exception e)
        {
            LOG.log(Level.SEVERE,
                    String.format("Error solving %s", modelInstance_.getClass().getName()), e);
            System.exit(-1);
        }
        return state;
    }

    /**
     * Solves the model from the given starting state, timing the run under the label "Steady State".
     *
     * @param modelInstance_ The model to solve
     * @param startingState_ The state to start from
     * @param writeState_ Passed through to the solver
     *
     * @throws SolverException If the solver fails
     */
    private void determineSolution(M modelInstance_, S startingState_, boolean writeState_) throws SolverException
    {
        Stoppable timer = new Timer().start("Steady State");
        _solver.solveModel(modelInstance_, startingState_, writeState_);
        timer.stop();
    }

    /**
     * Returns the model instance, creating it through the container on the first call.
     *
     * <p>On the first call, the model class, the config, the individual solver class and the
     * aggregate solver class (if any) are registered. The model is initialised and the solver
     * instances are attached to it. Later calls return the cached instance regardless of the
     * arguments. Logs the error and terminates the JVM if creation fails.
     *
     * @param modelClass_ The model class to instantiate
     * @param config_ The configuration to register with the container
     *
     * @return The (cached) model instance
     */
    @SuppressWarnings("unchecked")
    protected M instantiateModel(Class extends M> modelClass_, C config_) throws SecurityException
    {
        // Only do this once, even if solving and simulating.
        if (_modelInstance == null)
        {
            _container.addComponent(modelClass_);

            if (config_ != null)
            {
                _container.addComponent(config_);
            }

            try
            {
                _modelInstance = _container.getComponent(modelClass_);
                _modelInstance.initialise();

                _container.addComponent(config_.getIndividualSolver());
                Class extends AggregateProblemSolver>> aggregateProblemSolverClass = config_.getAggregateProblemSolver();
                if (aggregateProblemSolverClass != null)
                {
                    _container.addComponent(aggregateProblemSolverClass);
                }

                _modelInstance.initIndividualSolverInstance((IndividualProblemSolver) _container.getComponent(config_.getIndividualSolver()));

                if (aggregateProblemSolverClass != null)
                {
                    // Because Pico does not support cyclical dependencies we have to do this
                    _modelInstance.initAggregateSolverInstance((AggregateProblemSolver) _container.getComponent(aggregateProblemSolverClass));
                }
            } catch (Exception e_)
            {
                LOG.log(Level.SEVERE, String.format("Error instantiating %s",
                        modelClass_.getName()), e_);
                System.exit(-1);
            }
        }

        return _modelInstance;
    }

    /**
     * Simulates the model using state read from a directory. When the model has individual risk and
     * the config has aggregate uncertainty, an aggregate-only simulation is run afterwards on the
     * same shocks.
     */
    private void simulateModel(Class extends M> modelClass_, C config_, String stateDir_, String resultsPath_, int periods_,
            int burnIn_) throws ModelException, SimulatorException
    {
        // Instantiate a model from the named class and provided model
        M model = instantiateModel(modelClass_, config_);
        File stateDir = new File(stateDir_);

        // Read the state from the specified state directory
        S state = _solver.readState(model, stateDir);

        final boolean indRisk = model.getConfig().getIndividualExogenousStates().get(0).numberOfElements() > 1;

        // Simulate the model given the state
        SimulationResults, ?> results = simulateModel(model, state, periods_, burnIn_, stateDir, resultsPath_);

        if (indRisk && config_.hasAggUncertainty())
        {
            simAgg(model, stateDir, state, results, resultsPath_);
        }
    }

    /**
     * Simulates using only the aggregate rules, driven by the shocks and first-period aggregate state
     * of previous simulation results.
     */
    @SuppressWarnings("unchecked")
    private void simAgg(M model, File stateDir, S state, SimulationResults, ?> results, String resultsPath_) throws ModelException,
            SimulatorException
    {
        _simulator.simAggregate((GenericBlockedArray) results.getShocks(), (PeriodAggregateState) results.getPeriod(0),
                model, state, stateDir, resultsPath_);
    }

    /**
     * Reads the "shocks", "states" and "controls" arrays of a previous simulation from a directory.
     *
     * @throws RuntimeException wrapping any IOException from reading
     */
    @SuppressWarnings({ "rawtypes", "unchecked" })
    private SimulationResults, ?> readSimResults(M model_, File dir_)
    {
        final NumericsReader reader = _simulator.getNumericsReader(dir_);

        try
        {
            GenericBlockedArray shocks = (GenericBlockedArray) reader. getArray("shocks");
            DoubleArray> states = (DoubleArray>) reader. getArray("states");
            DoubleArray> controls = (DoubleArray>) reader. getArray("controls");

            return new SimulationResults(shocks, states, controls);
        } catch (IOException e)
        {
            throw new RuntimeException(e);
        }
    }

    /**
     * Simulates the model from the configured initial distribution and exogenous states, using the
     * previously solved individual solution in {@code state_}.
     *
     * @return The simulation results produced by the simulator
     */
    @SuppressWarnings({ "rawtypes", "unchecked" })
    protected SimulationResults, ?> simulateModel(M model_, S state_, int periods_, int burnIn_, File stateDir_, String resultsPath_)
            throws ModelException, SimulatorException
    {
        AggregateSimulationObserver observer = getSimObserver();

        DiscretisedDistribution simState = (DiscretisedDistribution) model_.getConfig().getInitialSimState();

        if (_saveDensities)
        {
            observer.addObserver(new DensitySavingSimulationObserver(new File(stateDir_, resultsPath_), _container
                    .getComponent(NumericsWriterFactory.class)));
        }

        return _simulator.simulate(periods_, burnIn_, simState, model_.getConfig().getInitialExogenousStates(), model_, state_,
                observer, stateDir_, resultsPath_);
    }

    /**
     * @return The aggregate simulation observer registered in the container
     */
    @SuppressWarnings("rawtypes")
    private AggregateSimulationObserver getSimObserver()
    {
        return _container.getComponent(AggregateSimulationObserver.class);
    }

    /**
     * Reads the solver state for {@code model_} from {@code directory_}.
     */
    private S readStateFromDir(File directory_, M model_)
    {
        return _solver.readState(model_, directory_);
    }

    /**
     * Solves and/or simulates the model according to the command line.
     *
     * @param config_ The configuration to use; must be a {@link SettableModelConfig}
     * @param commandLine_ The parsed command line
     *
     * @throws ModelException If there is an error caused by the model
     * @throws SimulatorException If there is an error during simulation
     * @throws IllegalStateException If solving produced no state, or simulation is requested without
     *         a solved state to simulate from
     */
    protected void runConfig(C config_, CommandLine commandLine_) throws ModelException, SimulatorException
    {
        final boolean simulate = commandLine_.hasOption(SIMULATE_OPTION);

        if (commandLine_.hasOption(CONTINUE_OPTION))
        {
            _continueDir = new File(commandLine_.getOptionValue(CONTINUE_OPTION));
        }

        if (commandLine_.hasOption(SOLUTION_DIR_OPTION))
        {
            _solutionDir = new File(commandLine_.getOptionValue(SOLUTION_DIR_OPTION));
        } else
        {
            _solutionDir = _solver.createSolutionDirectory();
        }

        ((SettableModelConfig) config_).setSolutionDirectory(_solutionDir);

        if (commandLine_.hasOption(SOLVE_OPTION))
        {
            final String solvedDir = solveModel(getModelClass(), config_);
            if (solvedDir == null)
            {
                // solveModel logs the failure and returns null; new File(null) would throw an NPE
                throw new IllegalStateException("Solving the model did not produce a state; see the log for details");
            }
            _continueDir = new File(solvedDir);
        }

        String simResultsPath = null;

        if (commandLine_.hasOption(SIM_DIR_OPTION))
        {
            simResultsPath = commandLine_.getOptionValue(SIM_DIR_OPTION);
        }

        if (!simulate)
        {
            return;
        }

        if (_continueDir == null)
        {
            throw new IllegalStateException(String.format(
                    "Simulation requires a solved state: use -%s or -%s <dir>", SOLVE_OPTION, CONTINUE_OPTION));
        }

        if (commandLine_.hasOption(SHOCK_DIR_OPTION))
        {
            final File shockDir = new File(commandLine_.getOptionValue(SHOCK_DIR_OPTION));

            if (commandLine_.hasOption(AGG_SIM_OPTION))
            {
                M model = instantiateModel(getModelClass(), config_);
                simAgg(model, _continueDir, _solver.readState(model, _continueDir),
                        readSimResults(model, shockDir), simResultsPath);
            } else
            {
                simulateModel(getModelClass(), config_, shockDir, _continueDir.getAbsolutePath(),
                        simResultsPath);
            }
        } else
        {
            simulateModel(getModelClass(), config_, _continueDir.getAbsolutePath(), simResultsPath,
                    Integer.parseInt(commandLine_.getOptionValue(PERIODS_OPTION, DEFAULT_PERIODS)), DEFAULT_BURN_IN);
        }
    }

    /**
     * Indicates which model class is to be solved
     *
     * @return The model class to be used
     */
    abstract protected Class extends M> getModelClass();

    /**
     * Creates the config object for this execution given the command line passed
     *
     * @param commandLine_ The command line used to execute the runner
     *
     * @return The appropriately initialised config object
     */
    abstract protected C createConfig(CommandLine commandLine_);

    /**
     * Runs the model based on the provided command line arguments. Prints usage and exits on a parse
     * error. Other exceptions are logged rather than propagated, and numerics are always cleaned up
     * afterwards.
     *
     * @param args_ The arguments to be used
     *
     * @throws ModelException If there is an error caused by the model
     * @throws SimulatorException If there is an error during simulation
     */
    public void run(String[] args_) throws ModelException, SimulatorException
    {
        try
        {
            Options options = createOptions();
            try
            {
                CommandLine commandLine = new BasicParser().parse(options, args_);
                run(commandLine);
            } catch (ParseException e)
            {
                HelpFormatter formatter = new HelpFormatter();
                formatter.printHelp("", options);
                System.exit(-1);
            }
        } catch (Exception e)
        {
            LOG.log(Level.SEVERE, "Error running model", e);
        } finally
        {
            cleanUp();
        }
    }

    /**
     * Destroys the numerics instance.
     */
    void cleanUp()
    {
        getNumerics().destroy();
    }

    /**
     * Creates the options object needed to parse the command line. Subclasses
     * can override to extend the options but should leave existing options
     * untouched.
     *
     * @return The options object to be used to parse the command line
     */
    protected Options createOptions()
    {
        Options options = new Options();

        options.addOption(SIMULATE_OPTION, false, "Indicates that the model should be simulated. ");
        options.addOption(SOLVE_OPTION, false, "Indicates that the model should be solved. ");
        options.addOption(CONTINUE_OPTION, true, "The directory from which state should be read to continue solving or simulate");
        options.addOption(SOLUTION_DIR_OPTION, true, "The directory to which state should be written upon solution");
        options.addOption(SIM_DIR_OPTION, true, "The directory to which the simulation results should be written. The "
                + "corresponding aggregate results will be written to a directory named with '_Agg' appended.");
        options.addOption(SHOCK_DIR_OPTION, true, "The directory from which shocks should be read");
        options.addOption(ADD_CLASS, true, "Class(es) to be added to the container. ");
        options.addOption(PERIODS_OPTION, true, "The number of periods to simulate. ");
        options.addOption(AGG_SIM_OPTION, false, "Only perform aggregate simulation. ");
        options.addOption(SAVE_DENSITIES, false, "Causes the distributions at each simulation step to be saved. ");
        options.addOption(WRITE_MATLAB, false, "Write state in a .mat file, rather than CSVs");

        return options;
    }

    /**
     * Runs the model based on the provided command line arguments
     *
     * @param commandLine_ The arguments to be used
     *
     * @throws ModelException If there is an error caused by the model
     * @throws SimulatorException If there is an error during simulation
     */
    public void run(CommandLine commandLine_) throws ModelException, SimulatorException
    {
        if (commandLine_.hasOption(ADD_CLASS))
        {
            String[] classNames = commandLine_.getOptionValues(ADD_CLASS);

            try
            {
                for (String className : classNames)
                {
                    _container.addComponent(Class.forName(className));
                }
            } catch (ClassNotFoundException e)
            {
                LOG.log(Level.SEVERE, "Error adding extra classes", e);
                System.exit(-1);
            }
        }

        _saveDensities = commandLine_.hasOption(SAVE_DENSITIES);

        // Mirror System.out lines into the logging framework. Install the wrapper only once, so that
        // repeated runs in the same JVM do not stack wrappers and log every line several times.
        if (!(System.out instanceof CallerLoggingPrintStream))
        {
            System.setOut(new CallerLoggingPrintStream(System.out));
        }

        // Set the solver and simulator to use MATLAB output if that option has been set
        if (commandLine_.hasOption(WRITE_MATLAB))
        {
            _container.addComponent(MatlabWriterFactory.class);
        } else
        {
            _container.addComponent(CSVWriterFactory.class);
        }

        setWriterFactory(_container.getComponent(NumericsWriterFactory.class));

        _solver = _container.getComponent(Solver.class);
        _simulator = _container.getComponent(DiscretisedDistributionSimulatorImpl.class);

        runConfig(createConfig(commandLine_), commandLine_);
    }

    /**
     * Simulates the model along the shocks of a previous simulation read from {@code shocksDir_},
     * using the state read from {@code absolutePath_}. No simulation observer is attached.
     */
    @SuppressWarnings("unchecked")
    private void simulateModel(Class extends M> class_, C config_, File shocksDir_, String absolutePath_, String simResultsPath_) throws SimulatorException, ModelException
    {
        M model = instantiateModel(class_, config_);
        File stateDir = new File(absolutePath_);
        S calcState = readStateFromDir(stateDir, model);

        _simulator.simulateShocks((DiscretisedDistribution) config_.getInitialSimState(), ((SimulationResults, Integer>) readSimResults(model, shocksDir_)).getShocks(),
                model, calcState, SimulationObserver.silent(), stateDir, simResultsPath_);
    }

    /**
     * Currently, only two simulators are supported, determining the types of the input pair accepted:
     * <ul>
     * <li>A heterogeneous agent simulator (use {@link DiscretisedDistribution}) with a single
     * {@link Integer} shock.</li>
     * <li>A representative agent simulator (use {@link RepresentativeAgentSimState}) with a single
     * {@link Double} shock.</li>
     * </ul>
     * Other combinations, including null elements, will cause an {@link IllegalArgumentException}.
     *
     * <p>Type parameters: the first is the type holding simulation state; {@code V} is the numeric
     * type of aggregate shocks.
     *
     * TODO: V should not extend number, e.g. for multi-exo-state models
     *
     * @param startingState_ The initial sim state and exogenous states for the simulation
     *
     * @return A simulator appropriate to the provided initial state
     */
    @SuppressWarnings("unchecked")
    protected Simulator
    getSimulator(Pair startingState_)
    {
        final Object left = startingState_.getLeft();
        final Object right = startingState_.getRight();

        if (left instanceof DiscretisedDistribution && right instanceof Integer)
        {
            return (Simulator) _container.getComponent(
                    DiscretisedDistributionSimulatorImpl.class);
        }

        if (left instanceof RepresentativeAgentSimState && right instanceof Double)
        {
            return (Simulator) _container.getComponent(RASimulator.class);
        }

        // Null-safe: a null element is reported as "null" instead of causing an NPE here
        throw new IllegalArgumentException(
                String.format(" The combination of endogenous state type "
                        + "'%s' and exogenous state type '%s' "
                        + "is not yet supported!",
                        left == null ? null : left.getClass(),
                        right == null ? null : right.getClass()));
    }

    /**
     * Provides access to the simulator instance used by this ModelRunner
     *
     * @return The simulator instance
     */
    protected DiscretisedDistributionSimulatorImpl getSimulator()
    {
        return _simulator;
    }

    /**
     * Returns the Solver used by this instance
     *
     * @return The Solver
     */
    protected Solver getSolver()
    {
        return _solver;
    }

    /**
     * A {@link PrintStream} that, in addition to printing, logs each {@link #println(String)} line at
     * {@link Level#INFO}, attributed to the calling class and method. Only {@code println(String)}
     * is intercepted; other print/println overloads pass through unlogged.
     */
    private static final class CallerLoggingPrintStream extends PrintStream
    {
        private final Logger _logger = Logger.getAnonymousLogger();

        CallerLoggingPrintStream(PrintStream delegate_)
        {
            super(delegate_);
        }

        @Override
        public void println(String line)
        {
            StackTraceElement[] stack = Thread.currentThread().getStackTrace();

            // Element 0 is getStackTrace, element 1 is this println, element 2 is the caller
            StackTraceElement caller = stack[2];

            _logger.logp(Level.INFO, caller.getClassName(), caller.getMethodName(), line);
            super.println(line);
        }
    }
}