All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.meliorbis.economics.model.ModelRunner Maven / Gradle / Ivy

Go to download

A library for solving economic models, particularly macroeconomic models with heterogeneous agents who have model-consistent expectations

There is a newer version: 1.1
Show newest version
package com.meliorbis.economics.model;

import java.io.File;
import java.io.IOException;
import java.io.PrintStream;
import java.util.logging.Level;
import java.util.logging.Logger;

import org.apache.commons.cli.BasicParser;
import org.apache.commons.cli.CommandLine;
import org.apache.commons.cli.HelpFormatter;
import org.apache.commons.cli.Options;
import org.apache.commons.cli.ParseException;
import org.apache.commons.lang.ArrayUtils;
import org.picocontainer.MutablePicoContainer;
import org.picocontainer.PicoBuilder;
import org.picocontainer.injectors.MultiInjection;

import com.meliorbis.economics.aggregate.AggregateProblemSolver;
import com.meliorbis.economics.individual.IndividualProblemSolver;
import com.meliorbis.economics.infrastructure.AbstractModel;
import com.meliorbis.economics.infrastructure.Base;
import com.meliorbis.economics.infrastructure.Solver;
import com.meliorbis.economics.infrastructure.SolverException;
import com.meliorbis.economics.infrastructure.simulation.AggregateSimulationObserver;
import com.meliorbis.economics.infrastructure.simulation.DensitySavingSimulationObserver;
import com.meliorbis.economics.infrastructure.simulation.DiscretisedDistribution;
import com.meliorbis.economics.infrastructure.simulation.DiscretisedDistributionSimulatorImpl;
import com.meliorbis.economics.infrastructure.simulation.PeriodAggregateState;
import com.meliorbis.economics.infrastructure.simulation.RASimulator;
import com.meliorbis.economics.infrastructure.simulation.RepresentativeAgentSimState;
import com.meliorbis.economics.infrastructure.simulation.SimState;
import com.meliorbis.economics.infrastructure.simulation.SimulationObserver;
import com.meliorbis.economics.infrastructure.simulation.SimulationResults;
import com.meliorbis.economics.infrastructure.simulation.Simulator;
import com.meliorbis.economics.infrastructure.simulation.SimulatorException;
import com.meliorbis.numerics.fixedpoint.FixedPointValueDelegate;
import com.meliorbis.numerics.fixedpoint.MultivariateBoundedFPFinder;
import com.meliorbis.numerics.function.MultiVariateVectorFunction;
import com.meliorbis.numerics.generic.impl.GenericBlockedArray;
import com.meliorbis.numerics.generic.primitives.DoubleArray;
import com.meliorbis.numerics.generic.primitives.impl.DoubleArrayFunctions;
import com.meliorbis.numerics.io.NumericsReader;
import com.meliorbis.numerics.io.NumericsWriter;
import com.meliorbis.numerics.io.NumericsWriterFactory;
import com.meliorbis.numerics.io.csv.CSVWriterFactory;
import com.meliorbis.numerics.io.matlab.MatlabWriterFactory;
import com.meliorbis.utils.Pair;
import com.meliorbis.utils.Timer;
import com.meliorbis.utils.Timer.Stoppable;
import com.meliorbis.utils.Utils;

/**
 * The abstract base class for objects that control configuration, solution and simulation of models
 * 
 * @author Tobias Grasl
 *
 * @param <M> The class of Model this runner is for
 * @param <C> The configuration class this runner is for
 * @param <S> The state class this runner is for
 */
public abstract class ModelRunner, C extends ModelConfig, S extends State> extends Base
{
	// Names of the command line options understood by this runner. The
	// user-facing description of each option is registered in createOptions().

	// Number of periods to simulate
	private static final String PERIODS_OPTION = "periods";
	// Directory to which state is written upon solution
	private static final String SOLUTION_DIR_OPTION = "stateDir";
	// Directory to which simulation results are written
	private static final String SIM_DIR_OPTION = "simDir";
	// Directory from which state is read to continue solving or to simulate
	private static final String CONTINUE_OPTION = "initialState";
	// Directory from which shocks are read
	private static final String SHOCK_DIR_OPTION = "useShocks";
	// Flag: simulate the model
	private static final String SIMULATE_OPTION = "simulate";
	// Flag: solve the model
	private static final String SOLVE_OPTION = "solve";
	// Class(es) to add to the container
	private static final String ADD_CLASS = "addClass";
	// Flag: only perform the aggregate simulation
	private static final String AGG_SIM_OPTION = "aggOnly";
	// Flag: save the distribution at each simulation step
	private static final String SAVE_DENSITIES = "saveSimDensities";
	// Flag: write state to a .mat file rather than CSVs
	private static final String WRITE_MATLAB = "writeMatlab";

	private static final Logger LOG = Logger.getLogger(ModelRunner.class.getName());
	// Resolved from the container in run(CommandLine); null before that
	private Solver _solver;
	// Container holding the numerics instance, solver, simulators and model
	final private MutablePicoContainer _container;
	// Resolved from the container in run(CommandLine); null before that
	private DiscretisedDistributionSimulatorImpl _simulator;
	// Created once by instantiateModel() and reused afterwards
	private M _modelInstance = null;

	// Directory state is read from; set via setInitialStatePath(), the
	// initialState option, or the result of solving in runConfig()
	private File _continueDir;
	// Directory solved state is written to; set in runConfig()
	private File _solutionDir;
	// Whether simulation densities are saved; set in run(CommandLine)
	private boolean _saveDensities;

	public ModelRunner()
	{
		_container = new PicoBuilder().withCaching().withComponentFactory(new MultiInjection()).build();;
		
		_container.addComponent(getNumerics());
		_container.addComponent(Solver.class);
		_container.addComponent(DiscretisedDistributionSimulatorImpl.class);
		_container.addComponent(RASimulator.class);
		_container.addComponent(AggregateSimulationObserver.class);
	}
	
	/**
	 * Configures the directory holding the state that solving or simulation
	 * should start from.
	 *
	 * @param path_ Location of the directory containing the starting state
	 */
	protected void setInitialStatePath(String path_)
	{
		final File initialStateDir = new File(path_);

		_continueDir = initialStateDir;
	}
	
	/**
	 * Solves the model and writes the resulting state to disk.
	 * <p>
	 * With aggregate uncertainty the model is solved directly. Without
	 * aggregate but with individual uncertainty a steady state is searched
	 * for, and the resulting distribution is written next to the solution as
	 * "ergodicDist". With neither kind of uncertainty the model is solved and
	 * the fixed point of the aggregate transition is printed.
	 * <p>
	 * Failures during solving are logged rather than propagated; whatever
	 * state exists at that point is still written.
	 *
	 * @param modelClass_ The class of the model to solve
	 * @param config_ The configuration to solve the model under
	 *
	 * @return The path of the directory the state was written to, or
	 * {@code null} if no state was available to write
	 */
	public String solveModel(Class modelClass_, C config_)
	{
		S state = null;
		String solvedDir = null;
		DiscretisedDistribution simState = null;
		M model = null;

		try
		{
			model = instantiateModel(modelClass_, config_);

			if (config_.hasAggUncertainty())
			{
				state = determineSolution(model, false);
			}
			else if (config_.hasIndUncertainty())
			{
				System.out.println("Model has no aggregate risk. Finding Steady State");

				Pair<S, DiscretisedDistribution> results = findNASteadyState(model, modelClass_, config_);

				state = results.getLeft();
				simState = results.getRight();
			}
			else
			{
				state = determineSolution(model, false);

				System.out.println("Found consistent solution");

				DoubleArray aggregateTransition = state.getAggregateTransition();

				// Select index 0 everywhere except along the aggregate
				// endogenous state dimensions, which are kept in full (-1)
				int[] at = Utils.repeatArray(0, aggregateTransition.numberOfDimensions() - 1);

				for (int i = 0; i < config_.getAggregateEndogenousStateCount(); i++)
				{
					at[i] = -1;
				}

				DoubleArray reducedTrans = aggregateTransition.at(at);

				// toArray() without an argument yields an Object[], which
				// cannot be cast to DoubleArray[]; request the typed array
				DoubleArray[] endogenousStates = config_.getAggregateEndogenousStates().toArray(new DoubleArray[0]);

				DoubleArray fp = DoubleArrayFunctions.interpolateFixedPoint(reducedTrans, endogenousStates);

				System.out.println("Fixed Point: " + fp);
			}
		} catch (Throwable e)
		{
			// Deliberately not rethrown: the finally block still persists
			// whatever state was obtained before the failure
			LOG.log(Level.SEVERE, "Error solving model", e);
		} finally
		{
			if (state != null)
			{
				if (_solutionDir != null)
				{
					_solutionDir.mkdirs();
					_solver.writeState(model, state, _solutionDir);
					solvedDir = _solutionDir.getAbsolutePath();
				} else
				{
					solvedDir = _solver.writeState(model, state);
				}
			}

			if (simState != null)
			{
				// Create a writer to write to the solution directory
				final NumericsWriterFactory writerFactory = _container.getComponent(NumericsWriterFactory.class);
				final NumericsWriter writer = writerFactory.create(new File(solvedDir, "ergodicDist"));

				try
				{
					simState.write(writer);
				} catch (Exception e)
				{
					LOG.log(Level.SEVERE, "Error writing ergodic distribution", e);
				} finally
				{
					try
					{
						writer.close();
					} catch (IOException e)
					{
						// Best effort: a failed close must not mask the result,
						// but should not go unnoticed either
						LOG.log(Level.WARNING, "Error closing ergodic distribution writer", e);
					}
				}
			}
		}

		return solvedDir;
	}

	/**
	 * Finds the steady state of a model without aggregate risk.
	 * <p>
	 * A function is built that, for a given vector of inputs, passes the
	 * inputs to the model's fixed point delegate, solves the model, finds the
	 * ergodic distribution (only when there is individual risk) and returns
	 * the delegate's outputs. If the delegate has no initial inputs the
	 * function is evaluated once; otherwise a bounded fixed point finder is
	 * run on it and the delegate is notified of the solution.
	 * <p>
	 * NOTE(review): several declarations in this method (the delegate, the
	 * state holder, the anonymous function, the returned Pair) look like
	 * their generic type arguments were lost - confirm against the original
	 * source before compiling.
	 *
	 * @param model_ The model from which the fixed point delegate and writer
	 * factory are taken
	 * @param modelClass_ The model class, passed to instantiateModel on each evaluation
	 * @param config_ The configuration in use
	 *
	 * @return A pair of the calculation state and the simulation state from
	 * the last evaluation
	 */
	private Pair findNASteadyState(M model_, final Class modelClass_, final C config_)
			throws SecurityException
	{
		Timer timer = new Timer();

		Stoppable stoppable = timer.start("solveNA");

		// Holder updated on every evaluation of the function below
		final AggregateFixedPointState state = new AggregateFixedPointState(null, null, null);

		final FixedPointValueDelegate>> delegate = model_.getFixedPointDelegate();

		// Individual risk is present when the first individual exogenous
		// state has more than one element
		final boolean indRisk = config_.getIndividualExogenousStates().get(0).numberOfElements() > 1;

		System.out.println("Finding Solution with consistent aggregates\n");

		// NOTE(review): _solutionDir is only assigned in runConfig(); if this
		// is reached without it the next line throws a NullPointerException
		final File stateFile = _solutionDir;
		stateFile.mkdirs();
		
		setWriterFactory(model_.getWriterFactory());
		// Single-element array so the anonymous class below can update the counter
		int calcCount[] = new int[] {0};
		
		final MultiVariateVectorFunction fn = new MultiVariateVectorFunction()
		{

			@Override
			public Double[] call(Double... args_)
			{
				try
				{
					/*
					 * Solve the model under the current calibration
					 */
					// Set the inputs on the delegate (which will pass them to
					// the config)
					delegate.setInputs(ArrayUtils.toPrimitive(args_));

					// Instantiate the model; instantiateModel() returns the
					// cached instance once one exists
					M model = instantiateModel(modelClass_, config_);

					if( calcCount[0] > 0) {
						// Re-initialise for new config
						model.initialise();
					}
					
					
					S calcState = model.initialState();
					
					
					// Calculate a solution for the current config
					determineSolution(model, calcState, false);

					_solver.writeState(model, calcState, stateFile);

					/*
					 * Find the Steady-State distribution of the model under the
					 * current calibration
					 */
					DiscretisedDistribution simState;

					// Disabled alternative: restart from the distribution saved
					// by the previous evaluation instead of the initial one
//					if(calcCount[0] > 0) {
//						try
//						{
//							simState = new DiscretisedDistribution(new File(_solutionDir, "lastSS"));
//						} catch (IOException e)
//						{
//							LOG.warning("Unable to read sim state");
//							simState = (DiscretisedDistribution) config_.getInitialSimState();
//						}
//					}
//					else {
						simState = (DiscretisedDistribution) config_.getInitialSimState();
				//	}	
					
					// NOTE: If there are no individual shocks the initial state
					// should be the steady state
					if (indRisk)
					{
						simState = _simulator.findErgodicDist(model, calcState, simState);
						
						try
						{
							// NOTE(review): the writer created here is not closed
							// in this method - confirm whether write() closes it
							simState.write(getNumericsWriter(new File(_solutionDir, "lastSS")));
						} catch (IOException e)
						{
							LOG.warning("Unable to write sim state");
						}
					}

					state.setSimState(simState);
					state.setCalcState(calcState);
					state.setModel(model);

					calcCount[0]++;
					
					return ArrayUtils.toObject(delegate.getOutputs(state));
				} catch (SolverException e)
				{
					// call() cannot declare the checked exception
					throw new RuntimeException(e);
				}
			}
		};

		double[] initialInputs = delegate.getInitialInputs();

		// If there are no inputs, there are no outputs, and no fixed point
		// needs to be found - i.e. the steady
		// state is found in one pass
		if (initialInputs.length == 0)
		{
			fn.call(new Double[0]);
		} else
		{
			MultivariateBoundedFPFinder fpFinder;

			// presumably both arguments are tolerances - TODO confirm
			fpFinder = new MultivariateBoundedFPFinder(1e-6, 1e-6);

			double[][] bounds = delegate.getBounds();
			
			fpFinder.setBounds(bounds);
			
			fpFinder.findFixedPoint(fn, initialInputs);
			
			// Notify the delegate that the solution was found
			delegate.solutionFound(state);
			
			_solver.writeState(state.getModel(), state.getCalcState(), stateFile);
		}

		// Not reached if anything above throws
		stoppable.stop();

		return new Pair((S) state.getCalcState(), state.getSimState());
	}

	/**
	 * Solves the model, starting from the state stored in the continue
	 * directory if one is set and from the model's initial state otherwise.
	 * <p>
	 * Any exception is logged and terminates the JVM with exit code -1.
	 *
	 * @param modelInstance_ The model to solve
	 * @param writeState_ Passed through to the solver
	 *
	 * @return The state the solver worked on
	 */
	private S determineSolution(M modelInstance_, boolean writeState_)
	{
		S state = null;
		try
		{
			if (_continueDir != null)
			{
				state = _solver.readState(modelInstance_, _continueDir);
			} else
			{
				state = modelInstance_.initialState();
			}

			determineSolution(modelInstance_, state, writeState_);

		} catch (Exception e)
		{
			// This path covers reading state and solving, not instantiation,
			// so the message says so; the logger already prints the stack trace
			LOG.log(Level.SEVERE,
					String.format("Error solving %s", modelInstance_.getClass().getName()), e);
			System.exit(-1);
		}
		return state;
	}

	/**
	 * Runs the solver on the given model and starting state, timing the call
	 * under the label "Steady State".
	 *
	 * @param modelInstance_ The model to solve
	 * @param startingState_ The state to start solving from
	 * @param writeState_ Passed through to the solver
	 *
	 * @throws SolverException If the solver fails
	 */
	private void determineSolution(M modelInstance_, S startingState_, boolean writeState_) throws SolverException
	{
		Stoppable timer = new Timer().start("Steady State");

		try
		{
			_solver.solveModel(modelInstance_, startingState_, writeState_);
		} finally
		{
			// Stop the timer even when the solver throws
			timer.stop();
		}
	}

	/**
	 * Returns the model instance of this runner, creating and wiring it the
	 * first time this is called. Later calls return the same instance,
	 * whatever arguments are passed.
	 * <p>
	 * On first use the model class and config are registered with the
	 * container, the model is resolved and initialised, and the individual
	 * and (if configured) aggregate problem solvers are registered, resolved
	 * and handed to the model.
	 * <p>
	 * Any exception during this wiring is logged and terminates the JVM with
	 * exit code -1. That includes a {@code null} config, which is skipped at
	 * registration but dereferenced when the solvers are looked up.
	 *
	 * @param modelClass_ The class of the model to instantiate
	 * @param config_ The configuration for the model
	 *
	 * @return The model instance
	 */
	@SuppressWarnings("unchecked")
	protected M instantiateModel(Class modelClass_, C config_) throws SecurityException
	{
		// Only do this once, even if solving and simulating.
		if (_modelInstance != null)
		{
			return _modelInstance;
		}

		_container.addComponent(modelClass_);

		if (config_ != null)
		{
			_container.addComponent(config_);
		}

		try
		{
			// Unchecked: the container is asked for the model class it was just given
			_modelInstance = (M) _container.getComponent(modelClass_);
			_modelInstance.initialise();

			_container.addComponent(config_.getIndividualSolver());
			Class<?> aggregateProblemSolverClass = config_.getAggregateProblemSolver();

			if (aggregateProblemSolverClass != null)
			{
				_container.addComponent(aggregateProblemSolverClass);
			}

			// Both solvers are registered before either is resolved
			_modelInstance.initIndividualSolverInstance((IndividualProblemSolver) _container.getComponent(config_.getIndividualSolver()));

			if (aggregateProblemSolverClass != null)
			{
				// Because Pico does not support cyclical dependencies we have to do this
				_modelInstance.initAggregateSolverInstance((AggregateProblemSolver) _container.getComponent(aggregateProblemSolverClass));
			}
		} catch (Exception e_)
		{
			// The logger already prints the stack trace
			LOG.log(Level.SEVERE, String.format("Error instantiating %s",
					modelClass_.getName()), e_);
			System.exit(-1);
		}

		return _modelInstance;
	}

	/**
	 * Reads a solved state from the given directory and simulates the model
	 * from it. When the model has individual risk (more than one element in
	 * its first individual exogenous state) and aggregate uncertainty, the
	 * aggregate-only simulation is run afterwards on the same results.
	 *
	 * @param modelClass_ The class of the model to simulate
	 * @param config_ The configuration for the model
	 * @param stateDir_ Path of the directory holding the solved state
	 * @param resultsPath_ Where simulation results are to be written
	 * @param periods_ The number of periods to simulate
	 * @param burnIn_ The number of burn-in periods
	 *
	 * @throws ModelException If there is an error caused by the model
	 * @throws SimulatorException If there is an error during simulation
	 */
	private void simulateModel(Class modelClass_, C config_, String stateDir_, String resultsPath_, int periods_,
			int burnIn_) throws ModelException, SimulatorException
	{
		final M model = instantiateModel(modelClass_, config_);
		final File solutionLocation = new File(stateDir_);
		final S solvedState = _solver.readState(model, solutionLocation);

		// Determined before simulating, as the simulation runs either way
		final boolean hasIndividualRisk =
				model.getConfig().getIndividualExogenousStates().get(0).numberOfElements() > 1;

		final SimulationResults outcome =
				simulateModel(model, solvedState, periods_, burnIn_, solutionLocation, resultsPath_);

		if (hasIndividualRisk && config_.hasAggUncertainty())
		{
			simAgg(model, solutionLocation, solvedState, outcome, resultsPath_);
		}
	}

	
	/**
	 * Runs the aggregate-only simulation, reusing the shocks and the first
	 * period of a previous set of simulation results.
	 *
	 * @param model The model to simulate
	 * @param stateDir The directory holding the solved state
	 * @param state The solved state
	 * @param results The results supplying the shocks and the initial period
	 * @param resultsPath_ Where simulation results are to be written
	 *
	 * @throws ModelException If there is an error caused by the model
	 * @throws SimulatorException If there is an error during simulation
	 */
	@SuppressWarnings("unchecked")
	private void simAgg(M model, File stateDir, S state, SimulationResults results, String resultsPath_) throws ModelException,
			SimulatorException
	{
		final GenericBlockedArray shockPath = (GenericBlockedArray) results.getShocks();
		final PeriodAggregateState firstPeriod = (PeriodAggregateState) results.getPeriod(0);

		// Simulate using only the aggregate rules
		_simulator.simAggregate(shockPath, firstPeriod, model, state, stateDir, resultsPath_);
	}

	/**
	 * Loads the "shocks", "states" and "controls" arrays of an earlier
	 * simulation from a directory.
	 *
	 * @param model_ Not used by this method
	 * @param dir_ The directory to read from
	 *
	 * @return The simulation results assembled from the three arrays
	 *
	 * @throws RuntimeException Wrapping any IOException raised while reading
	 */
	@SuppressWarnings({ "rawtypes", "unchecked" })
	private SimulationResults readSimResults(M model_, File dir_)
	{
		final NumericsReader source = _simulator.getNumericsReader(dir_);

		try
		{
			final GenericBlockedArray shockPath = (GenericBlockedArray) source.getArray("shocks");
			final DoubleArray statePath = (DoubleArray) source.getArray("states");
			final DoubleArray controlPath = (DoubleArray) source.getArray("controls");

			return new SimulationResults(shockPath, statePath, controlPath);
		} catch (IOException readFailure)
		{
			throw new RuntimeException(readFailure);
		}
	}

	/**
	 * Simulates the model from the config's initial simulation state and
	 * initial exogenous states, using a previously solved state. If saving
	 * of densities was requested, an observer that saves them is attached
	 * first.
	 *
	 * @param model_ The model to simulate
	 * @param state_ The solved state to simulate with
	 * @param periods_ The number of periods to simulate
	 * @param burnIn_ The number of burn-in periods
	 * @param stateDir_ The directory holding the solved state
	 * @param resultsPath_ Where simulation results are to be written
	 *
	 * @return The results of the simulation
	 *
	 * @throws ModelException If there is an error caused by the model
	 * @throws SimulatorException If there is an error during simulation
	 */
	@SuppressWarnings({ "rawtypes", "unchecked" })
	protected SimulationResults simulateModel(M model_, S state_, int periods_, int burnIn_, File stateDir_, String resultsPath_)
			throws ModelException, SimulatorException
	{
		final AggregateSimulationObserver aggregateObserver = getSimObserver();

		final DiscretisedDistribution startingDistribution =
				(DiscretisedDistribution) model_.getConfig().getInitialSimState();

		if (_saveDensities)
		{
			final File densityDir = new File(stateDir_, resultsPath_);
			final NumericsWriterFactory writerFactory = _container.getComponent(NumericsWriterFactory.class);

			aggregateObserver.addObserver(new DensitySavingSimulationObserver(densityDir, writerFactory));
		}

		// Always simulates with the individual solution previously solved;
		// an earlier alternative that only searched for the ergodic
		// distribution is no longer used
		return _simulator.simulate(periods_, burnIn_, startingDistribution,
				model_.getConfig().getInitialExogenousStates(), model_, state_,
				aggregateObserver, stateDir_, resultsPath_);
	}

	/**
	 * Looks up the aggregate simulation observer in the container.
	 *
	 * @return The observer registered under {@link AggregateSimulationObserver}
	 */
	@SuppressWarnings("rawtypes")
	private AggregateSimulationObserver getSimObserver()
	{
		return _container.getComponent(AggregateSimulationObserver.class);
	}

	/**
	 * Reads the state for a model from a directory via the solver.
	 *
	 * @param directory_ The directory to read from
	 * @param model_ The model the state belongs to
	 *
	 * @return The state read
	 */
	private S readStateFromDir(File directory_, M model_)
	{
		final S restored = _solver.readState(model_, directory_);

		return restored;
	}

	/**
	 * Performs the solve and/or simulate steps requested on the command line
	 * for the given configuration.
	 * <p>
	 * The solution directory is taken from the command line or created by
	 * the solver, and set on the config. If solving is requested, the
	 * directory written by the solve becomes the directory state is read
	 * from for simulation. Simulation then either replays shocks from a
	 * directory (fully, or aggregate-only) or runs for the requested number
	 * of periods (default 10000) after a burn-in of 1000.
	 *
	 * @param config_ The configuration to run; must also be a SettableModelConfig
	 * @param commandLine_ The parsed command line
	 *
	 * @throws ModelException If there is an error caused by the model
	 * @throws SimulatorException If there is an error during simulation
	 * @throws IllegalStateException If solving produced no solution, or if
	 * simulation is requested with no state directory to read from
	 */
	protected void runConfig(C config_, CommandLine commandLine_) throws ModelException, SimulatorException
	{
		final int burnInPeriods = 1000;

		boolean simulate = commandLine_.hasOption(SIMULATE_OPTION);

		if (commandLine_.hasOption(CONTINUE_OPTION))
		{
			_continueDir = new File(commandLine_.getOptionValue(CONTINUE_OPTION));
		}

		if (commandLine_.hasOption(SOLUTION_DIR_OPTION))
		{
			_solutionDir = new File(commandLine_.getOptionValue(SOLUTION_DIR_OPTION));
		} else
		{
			_solutionDir = _solver.createSolutionDirectory();
		}

		((SettableModelConfig) config_).setSolutionDirectory(_solutionDir);

		if (commandLine_.hasOption(SOLVE_OPTION))
		{
			final String solvedPath = solveModel(getModelClass(), config_);

			// solveModel() returns null when it could not write any state
			if (solvedPath == null)
			{
				throw new IllegalStateException("Solving the model did not produce a solution; see the log for details");
			}

			_continueDir = new File(solvedPath);
		}

		if (!simulate)
		{
			return;
		}

		// Simulation reads the solved state from _continueDir, which is only
		// set by solving or by the initial-state option
		if (_continueDir == null)
		{
			throw new IllegalStateException(String.format(
					"Option '%s' needs a solved state: also specify '%s' or '%s'",
					SIMULATE_OPTION, SOLVE_OPTION, CONTINUE_OPTION));
		}

		String simResultsPath = null;
		if (commandLine_.hasOption(SIM_DIR_OPTION))
		{
			simResultsPath = commandLine_.getOptionValue(SIM_DIR_OPTION);
		}

		if (commandLine_.hasOption(SHOCK_DIR_OPTION))
		{
			final File shocksDir = new File(commandLine_.getOptionValue(SHOCK_DIR_OPTION));

			if (commandLine_.hasOption(AGG_SIM_OPTION))
			{
				M model = instantiateModel(getModelClass(), config_);

				simAgg(model, _continueDir, _solver.readState(model, _continueDir),
						readSimResults(model, shocksDir), simResultsPath);
			} else
			{
				simulateModel(getModelClass(), config_, shocksDir, _continueDir.getAbsolutePath(),
						simResultsPath);
			}
		} else
		{
			final int periods = Integer.parseInt(commandLine_.getOptionValue(PERIODS_OPTION, "10000"));

			simulateModel(getModelClass(), config_, _continueDir.getAbsolutePath(), simResultsPath,
					periods, burnInPeriods);
		}
	}

	/**
	 * Indicates which model class is to be solved and simulated. The class
	 * returned is registered with the container and instantiated through it.
	 * 
	 * @return The model class to be used
	 */
	abstract protected Class getModelClass();

	/**
	 * Creates the config object for this execution given the command line passed.
	 * The result is handed to {@code runConfig}, which casts it to
	 * {@code SettableModelConfig} to set the solution directory.
	 * 
	 * @param commandLine_ The command line used to execute the runner
	 * 
	 * @return The appropriately initialised config object
	 */
	abstract protected C createConfig(CommandLine commandLine_);

	/**
	 * Runs the model based on the provided command line arguments.
	 * <p>
	 * If the arguments cannot be parsed, the parse error and the usage help
	 * are printed and the JVM exits with code -1. Any other exception is
	 * logged and not propagated. Resources are cleaned up on return.
	 * 
	 * @param args_ The arguments to be used
	 * 
	 * @throws ModelException If there is an error caused by the model
	 * @throws SimulatorException If there is an error during simulation
	 */
	public void run(String[] args_) throws ModelException, SimulatorException
	{
		try
		{
			Options options = createOptions();

			try
			{
				CommandLine commandLine = new BasicParser().parse(options, args_);

				run(commandLine);

			} catch (ParseException e)
			{
				System.err.println(e.getMessage());

				// HelpFormatter rejects an empty command line syntax with an
				// IllegalArgumentException, so the runner's class name is used
				HelpFormatter formatter = new HelpFormatter();
				formatter.printHelp(getClass().getName(), options);
				System.exit(-1);
			}
		} catch (Exception e)
		{
			LOG.log(Level.SEVERE, "Error running model", e);
		} finally
		{
			cleanUp();
		}
	}
	
	/**
	 * Destroys the numerics instance used by this runner. Called from the
	 * finally block of {@code run(String[])}.
	 */
	void cleanUp()
	{
		getNumerics().destroy();
	}

	/**
	 * Creates the options object needed to parse the command line. Subclasses
	 * can override to extend the options but should leave existing options
	 * untouched.
	 *
	 * @return The options object to be used to parse the command line
	 */
	protected Options createOptions()
	{
		Options options = new Options();

		// Flags selecting what to do
		options.addOption(SIMULATE_OPTION, false, "Indicates that the model should be simulated. ");
		options.addOption(SOLVE_OPTION, false, "Indicates that the model should be solved. ");

		// Options taking a value
		options.addOption(CONTINUE_OPTION, true, "The directory from which state should be read to continue solving or simulate");
		options.addOption(SOLUTION_DIR_OPTION, true, "The directory to which state should be written upon solution");
		options.addOption(SIM_DIR_OPTION, true, "The directory to which the simulation results should be written. The "
				+ "corresponding aggregate results will be written to a directory named with '_Agg' appended.");
		options.addOption(SHOCK_DIR_OPTION, true, "The directory from which shocks should be read");
		options.addOption(ADD_CLASS, true, "Class(es) to be added to the container. ");
		options.addOption(PERIODS_OPTION, true, "The number of periods to simulate. ");

		// Flags modifying simulation and output
		options.addOption(AGG_SIM_OPTION, false, "Only perform aggregate simulation. ");
		options.addOption(SAVE_DENSITIES, false, "Causes the distributions at each simulation step to be saved. ");
		options.addOption(WRITE_MATLAB, false, "Write state in a .mat file, rather than CSVs");

		return options;
	}

	/**
	 * Runs the model based on an already parsed command line.
	 * <p>
	 * Registers any extra classes with the container, redirects
	 * {@code System.out} so that printed lines are also logged, selects the
	 * output format, resolves solver and simulator from the container and
	 * finally runs the configuration created from the command line.
	 * 
	 * @param commandLine_ The arguments to be used
	 * 
	 * @throws ModelException If there is an error caused by the model
	 * @throws SimulatorException If there is an error during simulation
	 */
	public void run(CommandLine commandLine_) throws ModelException, SimulatorException
	{
		if (commandLine_.hasOption(ADD_CLASS))
		{
			for (String extraClassName : commandLine_.getOptionValues(ADD_CLASS))
			{
				try
				{
					_container.addComponent(Class.forName(extraClassName));
				} catch (ClassNotFoundException notFound)
				{
					LOG.log(Level.SEVERE, "Error adding extra classes", notFound);
					System.exit(-1);
				}
			}
		}

		_saveDensities = commandLine_.hasOption(SAVE_DENSITIES);

		final PrintStream console = System.out;

		System.setOut(new PrintStream(console)
		{
			// Only println(String) is intercepted; other print methods are
			// passed through without logging
			@Override
			public void println(String text)
			{
				// Frame 0 is getStackTrace, frame 1 is this println, frame 2
				// is whoever called println
				final StackTraceElement origin = Thread.currentThread().getStackTrace()[2];

				Logger.getAnonymousLogger().logp(Level.INFO, origin.getClassName(), origin.getMethodName(), text);

				super.println(text);
			}
		});

		// Use MATLAB output for solver and simulator if that option has been
		// set, CSV output otherwise
		final Class<?> writerFactoryClass =
				commandLine_.hasOption(WRITE_MATLAB) ? MatlabWriterFactory.class : CSVWriterFactory.class;

		_container.addComponent(writerFactoryClass);

		setWriterFactory(_container.getComponent(NumericsWriterFactory.class));

		_solver = _container.getComponent(Solver.class);
		_simulator = _container.getComponent(DiscretisedDistributionSimulatorImpl.class);

		final C config = createConfig(commandLine_);

		runConfig(config, commandLine_);
	}

	
	/**
	 * Simulates the model using shocks read from an earlier simulation,
	 * starting from the config's initial simulation state and using the
	 * solved state found in the given directory. No observer output is
	 * produced (a silent observer is used).
	 *
	 * @param class_ The class of the model to simulate
	 * @param config_ The configuration for the model
	 * @param shocksDir_ The directory holding the earlier simulation results
	 * @param absolutePath_ Path of the directory holding the solved state
	 * @param simResultsPath_ Where simulation results are to be written
	 *
	 * @throws SimulatorException If there is an error during simulation
	 * @throws ModelException If there is an error caused by the model
	 */
	@SuppressWarnings("unchecked")
	private void simulateModel(Class class_, C config_, File shocksDir_, String absolutePath_, String simResultsPath_) throws SimulatorException, ModelException
	{
		final M model = instantiateModel(class_, config_);
		final File solutionLocation = new File(absolutePath_);
		final S solvedState = readStateFromDir(solutionLocation, model);

		final DiscretisedDistribution startingDistribution = (DiscretisedDistribution) config_.getInitialSimState();
		final SimulationResults earlierRun = (SimulationResults) readSimResults(model, shocksDir_);

		_simulator.simulateShocks(startingDistribution, earlierRun.getShocks(),
				model, solvedState, SimulationObserver.silent(), solutionLocation, simResultsPath_);
	}
	
	/**
	 * Currently, only two simulators are supported, determining the types of the input pair accepted:
	 * <ol>
	 * <li>A heterogeneous agent simulator (use {@link DiscretisedDistribution}) with a single {@link Integer} shock.</li>
	 * <li>A representative agent simulator (use {@link RepresentativeAgentSimState}) with a single {@link Double} shock.</li>
	 * </ol>
	 * Other combinations will cause an {@link IllegalArgumentException}.
	 * <p>
	 * The left element of the pair is the type holding simulation state, the
	 * right element is the numeric type of aggregate shocks.
	 * <p>
	 * TODO: V should not extend number, e.g. for multi-exo-state models
	 *
	 * @param startingState_ The initial sim state and exogenous states for the simulation
	 *
	 * @return A simulator appropriate to the provided initial state
	 */
	@SuppressWarnings("unchecked")
	protected Simulator getSimulator(Pair startingState_)
	{
		final Object endogenousState = startingState_.getLeft();
		final Object exogenousState = startingState_.getRight();

		if (endogenousState instanceof DiscretisedDistribution && exogenousState instanceof Integer)
		{
			return (Simulator) _container.getComponent(DiscretisedDistributionSimulatorImpl.class);
		}

		if (endogenousState instanceof RepresentativeAgentSimState && exogenousState instanceof Double)
		{
			return (Simulator) _container.getComponent(RASimulator.class);
		}

		throw new IllegalArgumentException(
				String.format(" The combination of endogenous state type "
						+ "'%s' and exogenous state type '%s' "
						+ "is not yet supported!",
						endogenousState.getClass(),
						exogenousState.getClass()));
	}

	/**
	 * Provides access to the simulator instance used by this ModelRunner
	 *
	 * @return The simulator instance
	 */
	protected DiscretisedDistributionSimulatorImpl getSimulator()
	{
		return _simulator;
	}

	/**
	 * Returns the Solver used by this instance
	 *
	 * @return The Solver
	 */
	protected Solver getSolver()
	{
		return _solver;
	}
}