/**
 * Part of a library for solving economic models, particularly macroeconomic
 * models with heterogeneous agents who have model-consistent expectations.
 */
package com.meliorbis.economics.infrastructure.simulation;

import static com.meliorbis.numerics.DoubleArrayFactories.createArrayOfSize;
import static com.meliorbis.numerics.generic.primitives.impl.DoubleArrayFunctions.maximumRelativeDifferenceSpecial;
import static com.meliorbis.numerics.generic.primitives.impl.Interpolation.interp;
import static com.meliorbis.numerics.generic.primitives.impl.Interpolation.interpolateFunctionAcross;
import static com.meliorbis.numerics.generic.primitives.impl.Interpolation.spec;

import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Random;
import java.util.logging.Level;
import java.util.logging.Logger;

import org.apache.commons.lang.ArrayUtils;
import org.apache.commons.math3.util.Precision;

import com.meliorbis.economics.infrastructure.Solver;
import com.meliorbis.economics.model.Model;
import com.meliorbis.economics.model.ModelConfig;
import com.meliorbis.economics.model.ModelException;
import com.meliorbis.economics.model.ModelWithControls;
import com.meliorbis.economics.model.State;
import com.meliorbis.economics.model.StateWithControls;
import com.meliorbis.economics.model.lifecycle.ILifecycleModel;
import com.meliorbis.numerics.DoubleArrayFactories;
import com.meliorbis.numerics.Numerics;
import com.meliorbis.numerics.generic.MultiDimensionalArray;
import com.meliorbis.numerics.generic.impl.GenericBlockedArray;
import com.meliorbis.numerics.generic.impl.IntegerArray;
import com.meliorbis.numerics.generic.primitives.DoubleArray;
import com.meliorbis.numerics.generic.primitives.DoubleBinaryOp;
import com.meliorbis.numerics.generic.primitives.DoubleNaryOp;
import com.meliorbis.numerics.generic.primitives.impl.DoubleArrayFunctions;
import com.meliorbis.numerics.generic.primitives.impl.Interpolation.Params;
import com.meliorbis.numerics.generic.primitives.impl.Interpolation.Specification;
import com.meliorbis.numerics.io.NumericsWriter;
import com.meliorbis.numerics.io.NumericsWriterFactory;
import com.meliorbis.numerics.threading.ComputableRecursiveAction;
import com.meliorbis.utils.Timer;
import com.meliorbis.utils.Timer.Stoppable;
import com.meliorbis.utils.Utils;

/**
 * Heterogeneous Agent Model Simulator that represents the cross-sectional
 * distribution of agents as a discretised density and simulates its evolution
 * under the solved individual policies and the exogenous shock processes
 *
 * @author Tobias Grasl
 */
public final class DiscretisedDistributionSimulatorImpl
		extends AbstractSimulator
		implements DiscretisedDistributionSimulator
{
	static final Logger LOG = Logger.getLogger(Solver.class.getName());

	public DiscretisedDistributionSimulatorImpl()
	{
		super();
	}

	public DiscretisedDistributionSimulatorImpl(
			NumericsWriterFactory outputFactory_)
	{
		super(outputFactory_);
	}

	public DiscretisedDistributionSimulatorImpl(Numerics numerics_,
			NumericsWriterFactory outputFactory_,
			AggregateSimulationObserver observer_)
	{
		super(outputFactory_, observer_);
	}

	/*
	 * (non-Javadoc)
	 * 
	 * @see
	 * com.meliorbis.economics.infrastructure.ISimulator#createShockSequence(int
	 * [], int, com.meliorbis.economics.model.IModel)
	 */
	@Override
	public MultiDimensionalArray createShockSequence(
			MultiDimensionalArray initialShockStates_, int periods_,
			Model model_)
	{
		return createShockSequence(initialShockStates_, periods_, model_, null);
	}

	/*
	 * (non-Javadoc)
	 * 
	 * @see
	 * com.meliorbis.economics.infrastructure.ISimulator#createShockSequence(int
	 * [], int, com.meliorbis.economics.model.IModel, java.lang.Integer)
	 */
	@Override
	public GenericBlockedArray createShockSequence(
			MultiDimensionalArray initialShockStates_, int periods_,
			Model model_, Integer seed_)
	{
		final int nAggExoStates = model_.getConfig()
				.getAggregateExogenousStateCount();
		final int nAggNormStates = model_.getConfig()
				.getAggregateNormalisingStateCount();
		final int nIndExoStates = model_.getConfig()
				.getIndividualExogenousStateCount();

		DoubleArray exoStateTransition = model_.getConfig()
				.getExogenousStateTransition();

		int[] indShockIndex = Utils.repeatArray(-1,
				exoStateTransition.numberOfDimensions());

		// Just select the first individual shock state as a source
		for (int i = 0; i < nIndExoStates; i++)
		{
			indShockIndex[i] = 0;
		}

		/*
		 * Take the sum over all possible individual future states at each
		 * aggregate transition point; this will give us the aggregate
		 * transition probability
		 */

		DoubleArray aggTransition = exoStateTransition.at(indShockIndex)
				.across(Utils.sequence(nAggExoStates,
						nAggExoStates + nIndExoStates))
				.sum();

		Random random;

		if (seed_ == null)
		{
			random = new Random();
		} else
		{
			random = new Random(seed_);
		}
		IntegerArray allShocks = getNumerics().newIntArray(periods_,
				nAggExoStates + nAggNormStates);

		int[] currentAggShock = ArrayUtils.toPrimitive(
				((GenericBlockedArray) initialShockStates_)
						.toArray());
		currentAggShock = Arrays.copyOf(currentAggShock, nAggExoStates);

		int[] allInitShocks = new int[nAggExoStates];

		System.arraycopy(
				ArrayUtils.toPrimitive(
						((GenericBlockedArray) initialShockStates_)
								.toArray()),
				0, allInitShocks, 0, nAggExoStates);

		allShocks.at(0).fill(ArrayUtils.toObject(
				Arrays.copyOf(allInitShocks, nAggExoStates + nAggNormStates)));

		int shockPeriod = 1;

		while (shockPeriod < periods_)
		{

			/*
			 * Draw the next values for aggregate shocks at random
			 */
			int[] futureAggShock = Utils
					.drawRandomState(aggTransition.at(currentAggShock), random);

			allShocks.at(shockPeriod).fill(ArrayUtils.toObject(futureAggShock));

			currentAggShock = /* copy only the transient shocks */Arrays
					.copyOf(futureAggShock, nAggExoStates);
			shockPeriod++;
		}
		// allShocks.fillDimensions(ArrayUtils.toObject(allInitShocks), 1);
		return allShocks;
	}

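	/**
	 * Interpolates expected future aggregate values (states or controls) at
	 * the current aggregate state, given the prior and current aggregate
	 * shocks, using the expectations provided at construction
	 */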
	final class Forecaster
	{
		private final DoubleArray _expectations;
		private final Specification _dimSpecs[];
		private int[] _shockSelector;
		private int _priorShockCount;
		private int _currentShockCount;

		Forecaster(DoubleArray expectations_, ModelConfig config_)
		{
			_expectations = expectations_;

			final int nAggEndoStates = config_
					.getAggregateEndogenousStateCount();
			final int nAggExoStates = config_.getAggregateExogenousStateCount();
			final int nAggNormStates = config_
					.getAggregateNormalisingStateCount();

			_dimSpecs = new Specification[nAggEndoStates];

			for (int i = 0; i < _dimSpecs.length; i++)
			{
				_dimSpecs[i] = spec(i,
						config_.getAggregateEndogenousStates().get(i),
						Double.NaN);
			}

			_priorShockCount = nAggExoStates;
			_currentShockCount = _priorShockCount + nAggNormStates;

			_shockSelector = new int[_priorShockCount + _currentShockCount];
		}

		public DoubleArray forecast(IntegerArray currentShocks_,
				IntegerArray futureShocks_, DoubleArray currentStates_)
		{
			// Create an appropriate index to select across prior and current shocks.
			// IMPORTANT: Don't necessarily use all the prior shocks, because some of
			// them will be permanent ones, and they are irrelevant in the prior period
			for (int i = 0; i < _priorShockCount; i++)
			{
				_shockSelector[i] = currentShocks_.get(i);
			}

			for (int i = 0; i < _currentShockCount; i++)
			{
				_shockSelector[i + _priorShockCount] = futureShocks_.get(i);
			}

			// Now set the current aggregate state on the interpolation objects
			for (int i = 0; i < _dimSpecs.length; i++)
			{
				_dimSpecs[i].target = currentStates_.get(i);
			}

			// And interpolate to the future aggregate state!
			return interp(_expectations.at(_shockSelector), _dimSpecs);
		}
	}

	/*
	 * (non-Javadoc)
	 * 
	 * @see com.meliorbis.economics.infrastructure.ISimulator#simAggregate(com.
	 * meliorbis.numerics.generic.impl.GenericBlockedArray,
	 * com.meliorbis.economics.infrastructure.simulation.PeriodAggregateState,
	 * com.meliorbis.economics.model.IModel, S, java.io.File, java.lang.String)
	 */
	@Override
	@SuppressWarnings({ "rawtypes" })
	public <S extends State> SimulationResults simAggregate(
			MultiDimensionalArray shocks_,
			PeriodAggregateState initialState_, Model model_,
			S state_, File outDir_, String resultsPath_)
			throws ModelException, SimulatorException
	{
		// Create a new results object of the appropriate size
		SimulationResults results = new SimulationResults();

		// Get the forecasting rules
		DoubleArray expectedAggs = state_.getExpectedAggregateStates();

		Forecaster forecaster = new Forecaster(expectedAggs,
				model_.getConfig());

		// If the model has controls, they also need to be forecast
		Forecaster controlForecaster = null;
		if (model_.getConfig().getAggregateControlCount() > 0)
		{
			controlForecaster = new Forecaster(
					((StateWithControls) state_).getExpectedAggregateControls(),
					model_.getConfig());
		}

		DoubleArray currentStates = initialState_.getStates();
		IntegerArray currentShocks = (IntegerArray) initialState_.getShocks();
		DoubleArray currentControls = initialState_.getControls();

		// Add the first period, which is identical to the input
		results.addPeriod(currentShocks, currentStates, currentControls);

		int period = 0;

		// Iterate over the shock sequence
		final int periods = shocks_.size()[0];

		while (++period < periods)
		{
			IntegerArray futureShocks = (IntegerArray) shocks_.at(period);

			currentStates = forecaster.forecast(currentShocks, futureShocks,
					currentStates);

			if (controlForecaster != null)
			{
				currentControls = controlForecaster.forecast(currentShocks,
						futureShocks, currentStates);
			}

			model_.afterAggregateTransition(currentStates,
					currentShocks.toArray(), futureShocks.toArray(), state_);

			// Record the data for this period (aggregate state and shocks)
			results.addPeriod(futureShocks, currentStates, currentControls);

			// The future shocks are now the current ones
			currentShocks = futureShocks;
		}

		// Write the results to a dated directory
		writeSimResults(results, outDir_, resultsPath_, "Agg");

		// Return the results
		return results;
	}

	/*
	 * (non-Javadoc)
	 * 
	 * @see com.meliorbis.economics.infrastructure.ISimulator#
	 * simAggregatesForecastingControls(com.meliorbis.numerics.generic.impl.
	 * GenericBlockedArray,
	 * com.meliorbis.economics.infrastructure.simulation.PeriodAggregateState,
	 * MC, SC, java.io.File, java.lang.String)
	 */
	@Override
	public <SC extends StateWithControls, MC extends ModelWithControls> SimulationResults simAggregatesForecastingControls(
			MultiDimensionalArray shocks_,
			PeriodAggregateState initialState_, MC model_, SC state_,
			File outDir_, String resultsPath_)
			throws ModelException, SimulatorException
	{
		throw new UnsupportedOperationException("Deprecated");
	}

	/*
	 * (non-Javadoc)
	 * 
	 * @see com.meliorbis.economics.infrastructure.ISimulator#simulate(int, int,
	 * com.meliorbis.economics.infrastructure.simulation.SimState, int[], M, S,
	 * com.meliorbis.economics.infrastructure.ISimulationObserver, java.io.File,
	 * java.lang.String)
	 */
	@Override
	public <S extends State, M extends Model> SimulationResults simulate(
			int periods_, int burnIn_, DiscretisedDistribution initialState_,
			MultiDimensionalArray initialShockStates_, M model_,
			S calcState_, SimulationObserver observer_,
			File outputDir_, String resultsPath_)
			throws SimulatorException, ModelException
	{
		MultiDimensionalArray allShocks = createShockSequence(
				initialShockStates_, periods_ + burnIn_, model_);

		SimulationResults results = simulateShocks(
				initialState_, allShocks, model_, calcState_, observer_,
				outputDir_, resultsPath_);

		return results;
	}

	/*
	 * (non-Javadoc)
	 * 
	 * @see com.meliorbis.economics.infrastructure.ISimulator#simulate(int, int,
	 * com.meliorbis.economics.infrastructure.simulation.SimState, int[], M, S,
	 * com.meliorbis.economics.infrastructure.ISimulationObserver)
	 */
	@Override
	public <S extends State, M extends Model> SimulationResults simulate(
			int periods_, int burnIn_, DiscretisedDistribution initialState_,
			MultiDimensionalArray initialShockStates_, M model_,
			S calcState_, SimulationObserver observer_)
			throws SimulatorException, ModelException
	{
		MultiDimensionalArray allShocks = createShockSequence(
				initialShockStates_, periods_ + burnIn_, model_);

		return simulateShocks(initialState_, allShocks, model_, calcState_,
				observer_);
	}

	/*
	 * (non-Javadoc)
	 * 
	 * @see
	 * com.meliorbis.economics.infrastructure.ISimulator#simulateTransition(com.
	 * meliorbis.economics.infrastructure.simulation.SimState, M, S,
	 * java.lang.Integer[],
	 * com.meliorbis.numerics.generic.impl.GenericBlockedArray)
	 */
	@Override
	public <S extends State, M extends Model> TransitionRecord simulateTransition(
			final DiscretisedDistribution distribution_, M model_, S calcState_,
			MultiDimensionalArray priorAggShockIndices_,
			MultiDimensionalArray futureShocks_)
			throws ModelException
	{
		Timer timer = new Timer();

		final int nAggExoStates = model_.getConfig()
				.getAggregateExogenousStateCount();
		final int nAggNormStates = model_.getConfig()
				.getAggregateNormalisingStateCount();
		final int nIndExoStates = model_.getConfig()
				.getIndividualExogenousStateCount();

		TransitionRecord record = new TransitionRecord();
		record.setShocks(priorAggShockIndices_);
		record.setFutureShocks(futureShocks_);

		Stoppable t = timer.start("calcAggs");
		/*
		 * First, determine the aggregate states
		 */
		calcAggStates(distribution_, model_, calcState_, priorAggShockIndices_,
				futureShocks_, record);
		t.stop();
		t = timer.start("calcCtrls");

		/*
		 * Then, determine the appropriate aggregate control values
		 */
		calculateControls(distribution_, model_, calcState_,
				priorAggShockIndices_, record);
		t.stop();
		t = timer.start("getTrans");

		// First we need to interpolate the transition function to the level of
		// aggregate states
		record.setTransitionAtAggs(getTransitionAtAggregates(distribution_,
				model_, calcState_, record.getStates(), record.getControls(),
				priorAggShockIndices_, futureShocks_));
		t.stop();
		// Get the grid of individual end-of-period state values
		DoubleArray targetGrid = calcState_
				.getEndOfPeriodStatesForSimulation();

		final DoubleArray resultingDensity = createArrayOfSize(
				distribution_._density.size());

		// In the lifecycle model, the age-0 distribution is conditional
		// on the current aggregate state
		if (model_ instanceof ILifecycleModel)
		{
			resultingDensity.at(0).fill(
					((ILifecycleModel) model_).getZeroAgeDist(futureShocks_));
		}

		/*
		 * Now we need to get the individual transition probabilities given the
		 * aggregate state transition
		 */
		int[] transIndex = Utils.repeatArray(-1,
				model_.getConfig().getExogenousStateTransition().size().length);

		System.arraycopy(
				ArrayUtils.toPrimitive(
						((GenericBlockedArray) priorAggShockIndices_)
								.toArray()),
				0, transIndex, nIndExoStates, nAggExoStates);

		System.arraycopy(ArrayUtils.toPrimitive(
				((GenericBlockedArray) futureShocks_).toArray()), 0,
				transIndex, nIndExoStates * 2 + nAggExoStates,
				nAggExoStates + nAggNormStates);

		DoubleArray rawTransitionProbs = model_.getConfig()
				.getExogenousStateTransition().at(transIndex);

		final DoubleArray conditionalTransitionProbs = rawTransitionProbs
				.across(Utils.sequence(0, nIndExoStates))
				.divide(rawTransitionProbs.across(
						Utils.sequence(nIndExoStates, 2 * nIndExoStates))
						.sum());

		final DoubleArray overflowPropn = createArrayOfSize(
				distribution_._overflowProportions.size());
		final DoubleArray overflowAvgAmount = createArrayOfSize(
				distribution_._overflowAverages.size());

		record._resultingDist = new DiscretisedDistribution();
		record._resultingDist._density = resultingDensity;
		record._resultingDist._overflowProportions = overflowPropn;
		record._resultingDist._overflowAverages = overflowAvgAmount;

		model_.beforeSimInterpolation(distribution_, record, calcState_);

		t = timer.start("trans");
		/*
		 * First, distribute the transition function from the main grid onto the
		 * grid again
		 */
		transition(distribution_, record._resultingDist, targetGrid,
				record._transitionAtAggs, conditionalTransitionProbs, model_);
		t.stop();

		record.setExpectedPopulation(distribution_._density.sum()
				+ distribution_._overflowProportions.sum());

		model_.afterSimInterpolation(record, calcState_);

		// Make sure that our distribution still sums to 1
		if (!Precision.equals(
				record._resultingDist._density.sum()
						+ record._resultingDist._overflowProportions.sum(),
				(model_ instanceof ILifecycleModel)
						? ((ILifecycleModel) model_).getNumberOfGenerations()
						: record.getExpectedPopulation(),
				1e-10))
		{
			throw new RuntimeException(
					"Some people have emigrated. This is not acceptable!");
		}

		return record;
	}

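	/**
	 * Determines the aggregate control values implied by the current
	 * distribution and states, delegating to the model if it declares any
	 * aggregate controls, and stores them on the transition record
	 */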
	private <S extends State, M extends Model, SC extends StateWithControls, MC extends ModelWithControls> void calculateControls(
			DiscretisedDistribution distribution_, M model_, S calcState_,
			MultiDimensionalArray priorAggShockIndices_,
			TransitionRecord record_)
			throws ModelException
	{
		if (model_.getConfig().getAggregateControlCount() > 0)
		{
			// Need to cast both to related types
			@SuppressWarnings("unchecked")
			MC modelWithControls = (MC) model_;
			@SuppressWarnings("unchecked")
			SC stateWithControls = (SC) calcState_;

			// Let the model calculate the controls given the other state
			double[] impliedAggControls = modelWithControls
					.calculateAggregateControls(distribution_,
							record_._transitionAtAggs,
							record_.getStates().toArray(),
							priorAggShockIndices_, stateWithControls);

			// Also store the values of controls
			record_.setControls((DoubleArray) DoubleArrayFactories
					.createArray(impliedAggControls));

		} else
		{
			// No controls!
			record_.setControls((DoubleArray) createArrayOfSize(0));
		}

		if (LOG.isLoggable(Level.FINE))
		{
			LOG.fine("Aggregate Controls: " + record_.getControls());
		}
	}

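	/**
	 * Calculates the current aggregate states implied by the distribution and
	 * the prior aggregate shocks, and stores them on the transition record
	 */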
	private <S extends State> void calcAggStates(
			DiscretisedDistribution distribution_, Model model_,
			S calcState_,
			MultiDimensionalArray priorAggShockIndices_,
			MultiDimensionalArray futureShocks_,
			TransitionRecord record_)
			throws ModelException
	{
		/*
		 * Calculate the current aggregate states
		 */
		record_.setStates((DoubleArray) DoubleArrayFactories
				.createArray(model_.calculateAggregateStates(distribution_,
						priorAggShockIndices_, calcState_)));

		if (LOG.isLoggable(Level.FINE))
		{
			LOG.fine("Aggregate States: " + record_.getStates());
		}
	}

	/*
	 * (non-Javadoc)
	 * 
	 * @see com.meliorbis.economics.infrastructure.simulation.
	 * DiscretisedDistributionSimulator#transition(
	 * com.meliorbis.economics.infrastructure.simulation.
	 * DiscretisedDistribution,
	 * com.meliorbis.economics.infrastructure.simulation.
	 * DiscretisedDistribution,
	 * com.meliorbis.numerics.generic.primitives.IDoubleArray,
	 * com.meliorbis.numerics.generic.primitives.IDoubleArray,
	 * com.meliorbis.numerics.generic.primitives.IDoubleArray,
	 * com.meliorbis.economics.model.IModel )
	 */
	@Override
	public void transition(final DiscretisedDistribution sourceDist_,
			final DiscretisedDistribution targetDist_,
			final DoubleArray gridPoints_,
			final DoubleArray transitionFn_,
			final DoubleArray exoTransProbs_, final Model model_)
	{
		if (model_ instanceof ILifecycleModel)
		{
			List<ComputableRecursiveAction> callables = new ArrayList<>();

			for (int age = 0; age < ((ILifecycleModel) model_)
					.getNumberOfGenerations() - 1; age++)
			{
				final int ageForCall = age;

				callables.add(new ComputableRecursiveAction()
				{
					@Override
					public void compute()
					{
						transitionForAge(sourceDist_, targetDist_, gridPoints_,
								transitionFn_, exoTransProbs_, model_,
								ageForCall);
					}
				});
			}

			getNumerics().getExecutor().executeAndWait(callables);
		} else
		{
			transitionForAge(sourceDist_, targetDist_, gridPoints_,
					transitionFn_, exoTransProbs_, model_, -1);
		}
	}

	/*
	 * (non-Javadoc)
	 * 
	 * @see
	 * com.meliorbis.economics.infrastructure.ISimulator#transitionForAge(com.
	 * meliorbis.economics.infrastructure.simulation.SimState,
	 * com.meliorbis.economics.infrastructure.simulation.SimState,
	 * com.meliorbis.numerics.generic.primitives.IDoubleArray,
	 * com.meliorbis.numerics.generic.primitives.IDoubleArray,
	 * com.meliorbis.numerics.generic.primitives.IDoubleArray,
	 * com.meliorbis.economics.model.IModel, int)
	 */
	public void transitionForAge(final DiscretisedDistribution sourceDist_,
			DiscretisedDistribution targetDist_, DoubleArray gridPoints_,
			DoubleArray y_, final DoubleArray exoTransProbs_,
			Model model_, int age_)
	{
		distribute(sourceDist_._density, y_, exoTransProbs_, targetDist_,
				gridPoints_, model_.getConfig().isConstrained(), age_);

		// The simulation grid is only the size of a distribution, but we need
		// targets to have the extra potential multi-var dimension, so create a
		// new array and duplicate the overflow for each variable
		int[] ofSizeAllVars = Arrays.copyOf(
				sourceDist_._overflowAverages.at(age_).size(),
				sourceDist_._overflowAverages.at(age_).numberOfDimensions()
						+ 1);
		ofSizeAllVars[ofSizeAllVars.length - 1] = y_
				.size()[y_.numberOfDimensions() - 1];
		DoubleArray ofAllVars = createArrayOfSize(ofSizeAllVars);

		ofAllVars.fillDimensions(sourceDist_._overflowAverages.at(age_),
				Utils.sequence(0, ofSizeAllVars.length - 1));

		/*
		 * Now we need to deal with overflows: first, determine kp given k for
		 * those who have overflowed
		 */
		DoubleArray overflowTargets = interpolateFunctionAcross(gridPoints_,
				y_.at(age_), 0, ofAllVars,
				new Params().constrained(model_.getConfig().isConstrained()),
				y_.numberOfDimensions() - 1);

		/*
		 * Now, do exactly the same as above for the normal transition
		 */
		distribute(sourceDist_._overflowProportions, overflowTargets,
				exoTransProbs_, targetDist_, gridPoints_,
				model_.getConfig().isConstrained(), age_, true);
	}

	/**
	 * Determines the individual state transition function at the current
	 * aggregate state, either by delegating to the model if it provides a
	 * conditional transition or by interpolating the individual policy along
	 * the aggregate state (and, if necessary, control) dimensions
	 *
	 * @param distribution_ The current cross-sectional distribution
	 * @param model_ The model being simulated
	 * @param calcState_ The calculation state holding the solved policies
	 * @param currentAggStates_ The current aggregate endogenous state values
	 * @param currentAggControls_ The current aggregate control values
	 * @param priorAggShockIndices_ The indices of the prior-period aggregate shocks
	 * @param futureAggShockIndices_ The indices of the current-period aggregate shocks
	 *
	 * @return The individual transition function at the given aggregate state
	 */
	private <S extends State, M extends Model> DoubleArray getTransitionAtAggregates(
			DiscretisedDistribution distribution_, M model_, S calcState_,
			DoubleArray currentAggStates_,
			DoubleArray currentAggControls_,
			MultiDimensionalArray priorAggShockIndices_,
			MultiDimensionalArray futureAggShockIndices_)
	{
		// If the model has provided a conditional transition function, use it
		if (model_ instanceof HasConditionalTransition)
		{
			return ((HasConditionalTransition) model_)
					.getTransitionAtAggregateState(distribution_, calcState_,
							currentAggStates_, currentAggControls_,
							priorAggShockIndices_, futureAggShockIndices_);
		}

		// Here, the simulator performs the default mechanism
		DoubleArray individualStateTransitions = calcState_
				.getIndividualPolicyForSimulation();

		// Adjust for lifecycle if necessary
		int lifecycleOffset = (model_ instanceof ILifecycleModel) ? 1 : 0;

		int aggDetStateStart = model_.getAggDetStateStart() + lifecycleOffset;
		int aggStochStateStart = model_.getAggStochStateStart()
				+ lifecycleOffset;

		// This array is used to select the appropriate slices for shocks and,
		// when appropriate, controls which are constant
		int[] sliceSelector = Utils.repeatArray(-1,
				individualStateTransitions.size().length);

		// Copy current aggregate shocks to the selector
		System.arraycopy(
				ArrayUtils.toPrimitive(
						((GenericBlockedArray) priorAggShockIndices_)
								.toArray()),
				0, sliceSelector, aggStochStateStart,
				priorAggShockIndices_.numberOfElements());

		// If there is a normalising shock to be applied in the future period,
		// copy that also
		// TODO: Is this right? Seems wrong - since this is the current period
		// transition should be copying the current period shock, if anything
		if (model_.getConfig().getAggregateNormalisingStateCount() > 0)
		{
			sliceSelector[sliceSelector.length
					- 2] = ((GenericBlockedArray) futureAggShockIndices_)
							.last();
		}

		boolean interpControls = false;

		final int nAggStates = currentAggStates_.numberOfElements();
		final int nAggControls = currentAggControls_.numberOfElements();

		if (nAggControls > 0)
		{

			// If the control only has one possible value, then this must be a
			// no-agg-risk calc and
			// all controls will be such - don't interpolate but select
			if (model_.getConfig().getAggregateControls().get(0)
					.numberOfElements() == 1)
			{
				// Select the only available index, 0
				System.arraycopy(Utils.repeatArray(0, nAggControls), 0,
						sliceSelector, aggDetStateStart + nAggStates,
						nAggControls);
			} else
			{
				// Otherwise, controls also need to be interpolated
				interpControls = true;
			}
		}

		/*
		 * Construct the specifications for interpolating along aggregate state
		 * and, if necessary, control dimensions
		 */
		Specification[] dimSpecs = new Specification[nAggStates
				+ (interpControls ? nAggControls : 0)];

		int i = 0;

		while (i < nAggStates)
		{
			dimSpecs[i] = spec(aggDetStateStart + i,
					model_.getConfig().getAggregateEndogenousStates().get(i),
					currentAggStates_.get(i));
			i++;
		}

		if (interpControls)
		{
			i = 0;
			while (i < nAggControls)
			{
				dimSpecs[nAggStates + i] = spec(
						aggDetStateStart + nAggStates + i,
						model_.getConfig().getAggregateControls().get(i),
						currentAggControls_.get(i));
				i++;
			}
		}

		// Return an appropriately selected and interpolated slice of the
		// transition function
		DoubleArray result = interp(
				individualStateTransitions.at(sliceSelector), dimSpecs);

		return result;
	}

	/**
	 * Maps a density to a new density under a given transition function
	 *
	 * @param priorDensity_
	 *            The initial density
	 * @param transitionFunction_
	 *            The Y values of the transition function
	 * @param transitionProbs_
	 *            The state transition probabilities across the exogenous
	 *            dimension
	 * @param targetState_
	 *            The distribution to which the resulting density is added
	 * @param targetGrid_
	 *            The X values of the transition function, i.e. the grid to
	 *            distribute on to
	 * @param constrained_
	 *            Indicates whether the function is constrained to be positive
	 * @param age_
	 *            The age (generation) being transitioned, or -1 for
	 *            non-lifecycle models
	 */
	private void distribute(DoubleArray priorDensity_,
			DoubleArray transitionFunction_,
			final DoubleArray transitionProbs_,
			final DiscretisedDistribution targetState_,
			DoubleArray targetGrid_, boolean constrained_, final int age_)
	{
		distribute(priorDensity_, transitionFunction_, transitionProbs_,
				targetState_, targetGrid_, constrained_, age_, false);
	}

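	/**
	 * As {@link #distribute(DoubleArray, DoubleArray, DoubleArray, DiscretisedDistribution, DoubleArray, boolean, int)},
	 * with an additional flag indicating whether the supplied transition
	 * function is already specific to the given age
	 */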
	@SuppressWarnings("unchecked")
	void distribute(DoubleArray priorDensity_,
			DoubleArray transitionFunction_,
			final DoubleArray transitionProbs_,
			DiscretisedDistribution targetState_, DoubleArray targetGrid_,
			boolean constrained_, final int age_, boolean targetForAge_)
	{
		assert transitionProbs_
				.across(Utils.sequence(
						transitionProbs_.numberOfDimensions() / 2,
						transitionProbs_.numberOfDimensions()))
				.sum()
				.subtract(createArrayOfSize(
						ArrayUtils.subarray(transitionProbs_.size(),
								transitionProbs_.numberOfDimensions() / 2,
								transitionProbs_.numberOfDimensions()))
										.fill(1d))
				.map(DoubleArrayFunctions.abs)
				.max() < 1e-10 : "Transition probabilities from each state must sum to 1";
		/*
		 * Figure out whether this is byAge or not
		 */
		final boolean byAge = age_ != -1;

		final DoubleArray priorDensityArray = byAge ? priorDensity_.at(age_)
				: priorDensity_;

		if (byAge && !targetForAge_)
		{
			transitionFunction_ = transitionFunction_.at(age_);
		}

		/*
		 * THIS ASSUMES THAT THERE IS ONE STATE AND ONE SHOCK
		 */
		DiscretisedDistribution[] intermediateStates = new DiscretisedDistribution[targetState_._density
				.size()[1]];
		DoubleArray[] densities = new DoubleArray[targetState_._density
				.size()[1]];
		DoubleArray[] oas = new DoubleArray[targetState_._density.size()[1]];
		DoubleArray[] ops = new DoubleArray[targetState_._density.size()[1]];

		for (int i = 0; i < intermediateStates.length; i++)
		{
			intermediateStates[i] = targetState_.createSameSized();
			densities[i] = intermediateStates[i]._density;
			oas[i] = intermediateStates[i]._overflowAverages;
			ops[i] = intermediateStates[i]._overflowProportions;
		}

		final Timer timer = new Timer();
		/*
		 * Note that the function we are interpolating is an identity - and it
		 * does not matter since we are really just interested in the
		 * x-interpolation proportions
		 */
		interpolateFunctionAcross(targetGrid_, targetGrid_, 0,
				transitionFunction_,
				new Params().constrained(constrained_)
						.withCallback((targetValue_, index_, lowerTargetIndex_,
								lowerTargetProportion_) -> {
							Stoppable t = timer.start("SimCB");
							try
							{
								double priorDensity = priorDensityArray
										.get(index_);

								DiscretisedDistribution currentTargetState = intermediateStates[index_[1]];

								// Skip if prior density is precisely 0
								if (0d == priorDensity)
								{
									return;
								}

								/*
								 * Did we overflow the grid? (propn is between 0
								 * and 1 within the grid)
								 */
								DoubleArray probs = transitionProbs_
										.at(Arrays.copyOfRange(index_, 1,
												index_.length));

								if (lowerTargetProportion_ < 0)
								{
									// Yes! Synchronize on the overflow since
									// multiple threads could update it
									// simultaneously

									// Distribute across individual states
									DoubleArray newOverflow = probs
											.multiply(priorDensity);

									DoubleArray op = byAge
											? currentTargetState._overflowProportions
													.at(age_ + 1)
											: currentTargetState._overflowProportions;
									DoubleArray oa = byAge
											? currentTargetState._overflowAverages
													.at(age_ + 1)
											: currentTargetState._overflowAverages;

									// for each point...
									oa.modifying().with(op, newOverflow).map(
											(DoubleNaryOp) inputs_ -> {

												/*
												 * ...calculate the new mean
												 * wealth level of each super-rich
												 * agent by taking the weighted
												 * average of those already
												 * there and the nouveaux riches
												 */
												double existingWealth = inputs_[0];
												double existingWealthy = inputs_[1];
												double newWealthy = inputs_[2];

												double allWealthy = existingWealthy
														+ newWealthy;

												return allWealthy == 0d ? 0d
														: (/* overflowAvgAmount */existingWealth
																* /* overflowPropn */existingWealthy
																+ targetValue_
																		* /* newOverflow */newWealthy)
																/ allWealthy;
											});

									// Also add the new wealthy to the amount of
									// wealthy
									op.modifying().add(newOverflow);

									return;
								}

								int[] targetSubIndex = Utils.repeatArray(-1,
										index_.length + (byAge ? 1 : 0));

								int targetDimension;

								if (byAge)
								{
									targetSubIndex[0] = age_ + 1;
									targetDimension = 1;
								} else
								{
									targetDimension = 0;
								}

								if (lowerTargetProportion_ > 1d)
								{
									// LOG.warning("Grid Underflow!");
									lowerTargetProportion_ = 1d;
								}

								if (lowerTargetProportion_ != 0d)
								{
									final double lowerMultiplicant = lowerTargetProportion_
											* priorDensity;

									targetSubIndex[targetDimension] = lowerTargetIndex_;

									currentTargetState._density
											.at(targetSubIndex).modifying()
											.with(probs)
											.map((DoubleBinaryOp) (
													dens, probability) -> dens
															+ probability * lowerMultiplicant);
								}

								if (lowerTargetProportion_ != 1d)
								{
									final double upperMultiplicant = priorDensity
											- lowerTargetProportion_
													* priorDensity;

									targetSubIndex[targetDimension] = lowerTargetIndex_
											+ 1;

									currentTargetState._density
											.at(targetSubIndex).modifying()
											.with(probs)
											.map((DoubleBinaryOp) (
													left_, right_) -> left_
															+ right_ * upperMultiplicant);
								}
							} finally
							{
								t.stop();
							}

						}),
				transitionFunction_.numberOfDimensions() - 1);

		final DoubleNaryOp sumOp = (values) -> {
			double sum = 0d;

			for (double value : values)
			{
				sum += value;
			}

			return sum;
		};

		// The target density is the sum of all the densities - INCLUDING THE
		// TARGET, WHICH MAY ALREADY CONTAIN STUFF
		targetState_._density.modifying().with(densities).map(sumOp);

		// Prepare the OAs for the operation
		targetState_._overflowAverages.modifying()
				.multiply(targetState_._overflowProportions);

		// OverflowProportions are just sums
		targetState_._overflowProportions.modifying().with(ops).map(sumOp);

		// Multiply the overflow values by the proportions...
		for (int i = 0; i < oas.length; i++)
		{
			oas[i].modifying().multiply(ops[i]);
		}

		// ...make the sum and divide by the sum of the proportions
		targetState_._overflowAverages.modifying().with(oas).map(sumOp)
				.with(targetState_._overflowProportions)
				.map((DoubleBinaryOp) (oaAvg, op) -> op == 0d
						? 0d : oaAvg / op);
	}

	/*
	 * (non-Javadoc)
	 * 
	 * @see com.meliorbis.economics.infrastructure.ISimulator#createState()
	 */
	@Override
	public DiscretisedDistribution createState()
	{
		return new DiscretisedDistribution();
	}

	/*
	 * (non-Javadoc)
	 * 
	 * @see com.meliorbis.economics.infrastructure.ISimulator#findErgodicDist(M,
	 * S, com.meliorbis.economics.infrastructure.simulation.SimState)
	 */
	@Override
	public <S extends State, M extends Model> DiscretisedDistribution findErgodicDist(
			M model_, S state_, DiscretisedDistribution simState_)
			throws ModelException, SimulatorException
	{
		return steadySinglePointSim(model_, state_, simState_);
	}

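	/**
	 * Repeatedly simulates the distribution forward in blocks of periods until
	 * both the density and the implied aggregate values stop changing, i.e.
	 * until an ergodic distribution has been found
	 */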
	private <S extends State, M extends Model, SC extends StateWithControls, MC extends ModelWithControls> DiscretisedDistribution steadySinglePointSim(
			M model_, S state_, DiscretisedDistribution simState_)
			throws ModelException, SimulatorException
	{
		DoubleArray lastState = simState_._density.copy();

		DoubleArray lastAggs = null;

		double criterion = Double.POSITIVE_INFINITY;

		int periods = 0;

		int periodsPerStep = 20;

		IntegerArray shocks = getNumerics().newIntArray(periodsPerStep,
				model_.getConfig().getAggregateExogenousStateCount() + model_
						.getConfig().getAggregateNormalisingStateCount());

		System.out.println("\nFinding Ergodic Distribution\n");

		double minCriterion = criterion;
		double lastAggsCriterion = Double.POSITIVE_INFINITY;

		int minCritPeriod = 0;
		int aggDownTrend = 0;

		// If the simulation stages are to be traced, make sure the directory
		// exists
		if (_logSims != null)
		{
			new File(_logSims).mkdirs();
		}

		do
		{
			simState_ = simulateShocks(simState_, shocks, model_, state_,
					SimulationObserver.<DiscretisedDistribution> silent())
							.getFinalState();

			// Short-circuit if everything is overflowing...
			if (simState_._overflowProportions.sum() > 0.1)
			{
				break;
			}

			if (criterion < 1e-4)
			{
				double[] impliedAggs = model_.calculateAggregateStates(
						simState_, shocks.at(0), state_);

				@SuppressWarnings("unchecked")
				double[] impliedControls = model_.getConfig()
						.getAggregateControlCount() > 0
								? ((MC) model_).calculateAggregateControls(
										simState_,
										state_.getAggregateTransition(),
										impliedAggs, shocks.at(0), (SC) state_)
								: new double[0];

				DoubleArray newAggs = getNumerics().getArrayFactory()
						.newArray((double[]) ArrayUtils.addAll(impliedAggs,
								impliedControls));

				if (lastAggs == null)
				{
					lastAggs = newAggs;
					continue;
				}

				double aggsCriterion = maximumRelativeDifferenceSpecial(
						lastAggs, newAggs);

				// Need '<=' here to cope with continuing 0s (other repetitions
				// are unlikely)
				if (aggsCriterion <= lastAggsCriterion)
				{
					aggDownTrend++;
				} else
				{
					aggDownTrend = 0;
				}

				lastAggsCriterion = aggsCriterion;

				lastAggs = newAggs;
			}

			if (_logSims != null)
			{
				final NumericsWriter writer = getNumericsWriter(
						new File(_logSims, Integer.toString(periods)));

				try
				{
					simState_.write(writer);
				} catch (IOException e)
				{
					LOG.log(Level.WARNING, "Error writing sim state", e);
				}
			}

			criterion = maximumRelativeDifferenceSpecial(simState_._density,
					lastState);

			if (criterion < minCriterion)
			{
				minCriterion = criterion;
				minCritPeriod = periods;
			}

			lastState = simState_._density.copy();
			// lastOverflow = simState_._overflowProportions.copy();

			if ((periods += periodsPerStep) % 100 == 0)
			{
				LOG.info(String.format(
						"Periods: %s, Precision: %.2e, Aggregate Precision: %.2e",
						periods, criterion, lastAggsCriterion));

				System.out.print(".");

				if (periods % 1000 == 0)
				{
					System.out.print("\n");
				}
			}

			if (lastAggsCriterion < 1e-6
					&& (aggDownTrend >= 5 || lastAggsCriterion < 1e-10)
					&& criterion < 1e-6/* && periods > 1000 */)
			{
				System.out.println("\nFound ergodic distribution\n");
				break;
			} else if (aggDownTrend < 100 && periods > 3000
					|| ((periods - minCritPeriod) > 1000 && aggDownTrend < 5
							&& periods > 20000))
			{
				throw new RuntimeException(
						"\nFailed to converge to ergodic distribution after "
								+ periods + " periods");
			}
		} while (true);

		return simState_;

	}
}