com.meliorbis.economics.aggregate.ks.KrusellSmithSolver Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of ModelSolver Show documentation
A library for solving economic models, particularly
macroeconomic models with heterogeneous agents who have model-consistent
expectations.
/**
*
*/
package com.meliorbis.economics.aggregate.ks;
import static com.meliorbis.numerics.DoubleArrayFactories.createArrayOfSize;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.logging.Level;
import java.util.logging.Logger;
import java.util.stream.IntStream;
import org.apache.commons.lang.ArrayUtils;
import com.meliorbis.economics.aggregate.AggregateSolverBase;
import com.meliorbis.economics.infrastructure.simulation.SimState;
import com.meliorbis.economics.infrastructure.simulation.SimulationObserver;
import com.meliorbis.economics.infrastructure.simulation.SimulationResults;
import com.meliorbis.economics.infrastructure.simulation.Simulator;
import com.meliorbis.economics.model.Model;
import com.meliorbis.economics.model.ModelConfig;
import com.meliorbis.economics.model.ModelException;
import com.meliorbis.economics.model.State;
import com.meliorbis.numerics.DoubleArrayFactories;
import com.meliorbis.numerics.IntArrayFactories;
import com.meliorbis.numerics.generic.IntegerArray;
import com.meliorbis.numerics.generic.primitives.DoubleArray;
import com.meliorbis.numerics.generic.primitives.impl.DoubleArrayFunctions;
import com.meliorbis.numerics.index.impl.Index;
import com.meliorbis.utils.Pair;
import com.meliorbis.utils.Utils;
/**
 * An implementation of the Krusell-Smith (98) algorithm for updating the aggregate
 * forecasting function.
 *
 * <p>When called, the solver simulates the model for a number of periods using random shocks,
 * discards data from the initial periods and then estimates the forecasting function from the
 * time series of aggregate states and shocks obtained. For each combination of aggregate
 * exogenous shocks, a linear regression with a constant is run of next-period aggregate states
 * (and of current aggregate controls, if the model has any) on current aggregate states.
 *
 * <p>A number of features are configurable:
 * <ul>
 * <li>Number of simulation periods</li>
 * <li>Number of periods to discard</li>
 * <li>Log-linear or linear forecasting rule</li>
 * <li>Simulate from initial distribution vs. keeping the last distribution from prior iteration</li>
 * <li>Reusing shocks vs. creating new shocks each iteration</li>
 * </ul>
 *
 * @author Tobias Grasl
 *
 * @param <C> The Config type
 * @param <S> The State type
 * @param <M> The Model type
 * @param <T> The Simulation-state type
 */
public class KrusellSmithSolver,
        M extends Model,
        T extends SimState> extends AggregateSolverBase
{
    private static final Logger LOG = Logger.getLogger(KrusellSmithSolver.class.getName());

    private Simulator _simulator;
    private final boolean _hasControls;

    private int _simPeriods = 10000;
    private int _discardPeriods = 1000;
    private boolean _logs = false;
    private boolean _keepDist = false;
    private boolean _reuseShocks = true;

    // Final distribution and final shocks of the previous simulation; only recorded when _keepDist is set
    private T _startingDist = null;
    private IntegerArray> _startingShocks = null;

    // Shock sequence shared across iterations when _reuseShocks is set; created on first use
    private IntegerArray> _shocks = null;

    /**
     * Creates a solver for the given model and configuration.
     *
     * @param model_ The model whose aggregate forecasting rules are to be estimated
     * @param config_ The model configuration, supplying the aggregate grids and the initial
     *        simulation state and shocks
     * @param simulator_ The simulator used to generate the time series used for estimation
     */
    public KrusellSmithSolver(M model_, C config_,
            Simulator simulator_)
    {
        super(model_, config_);
        _simulator = simulator_;
        _hasControls = _config.getAggregateControlCount() > 0;
    }

    /**
     * Simulates the model given {@code state_} and estimates new aggregate forecasting rules from
     * the simulated series.
     *
     * <p>Observations from the first {@code _discardPeriods} periods are dropped. The remainder
     * are grouped by the combination of aggregate exogenous shocks in each period, and an OLS
     * regression with a constant is run per group. When log-linear rules are enabled, the
     * regression is run on logs and the resulting rules are exponentiated.
     *
     * @param state_ The current solution state, passed on to the simulator
     *
     * @return A pair of the aggregate state transition and the aggregate controls policy; the
     *         latter is {@code null} if the model has no aggregate controls
     *
     * @throws RuntimeException wrapping any {@link ModelException} raised during the estimation
     */
    @SuppressWarnings("unchecked")
    @Override
    protected Pair, DoubleArray>> calculateAggregatePolicies(S state_)
    {
        prepareAggregatePolicyCalculation(state_);

        // Start from the end of the previous simulation if so configured and available, otherwise
        // from the configured initial distribution and shocks
        T simState;
        IntegerArray> initialShocks;
        if(_keepDist && _startingDist != null) {
            simState = _startingDist;
            initialShocks = _startingShocks;
        } else {
            simState = (T) _config.getInitialSimState();
            initialShocks = (IntegerArray>) _config.getInitialExogenousStates();
        }

        IntegerArray> shocks;
        if(_reuseShocks) {
            if(_shocks == null) {
                _shocks = (IntegerArray>) _simulator.createShockSequence(initialShocks, _simPeriods, _model);
                // Make sure end shocks are equal to initial for next round
                _shocks.at(_shocks.size()[0]-1).fill(initialShocks);
            }
            shocks = _shocks;
        } else {
            shocks = (IntegerArray>) _simulator.createShockSequence(initialShocks, _simPeriods, _model);
        }

        try
        {
            LOG.info(String.format("Simulating %s periods for estimation",_simPeriods));
            SimulationResults results =
                    _simulator.simulateShocks(simState,
                            shocks,
                            _model,
                            state_,
                            SimulationObserver.silent());

            if(_keepDist) {
                _startingShocks = shocks.at(shocks.size()[0]-1).copy();
                _startingDist = results.getFinalState();
            }

            final DoubleArray> states = results.getStates();
            final DoubleArray> controls = results.getControls();

            // For a log-linear rule, map the simulated series to logs through their modifying views
            if(_logs) {
                states.modifying().map(DoubleArrayFunctions.log);
                if(controls != null) {
                    controls.modifying().map(DoubleArrayFunctions.log);
                }
            }

            List shockCounts = new ArrayList();
            Utils.addLengthsToList(shockCounts, _config.getAggregateExogenousStates());
            IntegerArray> countsArray = IntArrayFactories.createIntArrayOfSize(shockCounts);

            // NOTE(review): the exclusive upper bound size-2 means the last period used is size-3,
            // although states.at(period+1) may still be valid for period size-2 - confirm whether
            // dropping that observation is intended.

            // First, count how often each combination of aggregate shocks occurs
            IntStream.range(_discardPeriods, shocks.size()[0]-2).forEach(period -> {
                int[] currentShocks = aggregateShockKey(ArrayUtils.toPrimitive(shocks.at(period).toArray()));
                countsArray.set(countsArray.get(currentShocks)+1, currentShocks);
            });

            DoubleArray>[] inputStates = new DoubleArray[countsArray.numberOfElements()];
            DoubleArray>[] resultingStates = new DoubleArray[countsArray.numberOfElements()];
            DoubleArray>[] resultingControls =
                    _hasControls ? new DoubleArray[countsArray.numberOfElements()] : null;

            // Create arrays to hold inputs and outputs, sized by the number of observations of
            // each shock combination
            IntStream.range(0, inputStates.length).forEach(index -> {
                inputStates[index] =
                        DoubleArrayFactories.
                        createArrayOfSize(countsArray.get(index),
                                _config.getAggregateEndogenousStateCount());
                resultingStates[index] = DoubleArrayFactories.createArrayOfSize(
                        countsArray.get(index),
                        _config.getAggregateEndogenousStateCount());
                if(_hasControls) {
                    resultingControls[index] = DoubleArrayFactories.createArrayOfSize(countsArray.get(index), _config.getAggregateControlCount());
                }
            });

            Index index = new Index(countsArray.size());

            // Reset the counts; they now track how many rows of each group have been written
            countsArray.fill(0);

            // Distribute each period's observation into the group for its shock combination
            IntStream.range(_discardPeriods, shocks.size()[0]-2).forEach(period -> {
                int[] currentShocks = aggregateShockKey(ArrayUtils.toPrimitive(shocks.at(period).toArray()));
                int linearIndex = index.toLinearIndex(currentShocks);
                int nthInstance = countsArray.get(currentShocks);

                DoubleArray> currentStates = states.at(period);
                DoubleArray> futureStates = states.at(period+1);
                inputStates[linearIndex].at(nthInstance).fill(currentStates);
                resultingStates[linearIndex].at(nthInstance).fill(futureStates);

                if(_hasControls) {
                    DoubleArray> currentControls = controls.at(period);
                    resultingControls[linearIndex].at(nthInstance).fill(currentControls);
                }

                // Increment the count for that combination of shocks
                countsArray.set(nthInstance+1, currentShocks);
            });

            // The state transition has one dimension per aggregate endogenous state grid and per
            // aggregate exogenous state grid, plus a final dimension over the endogenous states
            List transitionDimensions = new ArrayList();
            Utils.addLengthsToList(transitionDimensions, _config.getAggregateEndogenousStates());
            Utils.addLengthsToList(transitionDimensions, _config.getAggregateExogenousStates());
            transitionDimensions.add(_config.getAggregateEndogenousStateCount());

            DoubleArray> newTransition = createArrayOfSize(transitionDimensions);

            // The controls policy has the same grid dimensions, but its final dimension ranges
            // over the aggregate controls. List.set takes the index first, then the new value.
            final DoubleArray> newControlsPolicy;
            if(_hasControls) {
                transitionDimensions.set(transitionDimensions.size()-1,
                        _config.getAggregateControlCount());
                newControlsPolicy = createArrayOfSize(transitionDimensions);
            } else {
                newControlsPolicy = null;
            }

            // Run the regressions, one per shock combination
            IntStream.range(0, inputStates.length).forEach(idx -> {
                int[] currentShocks = index.toLogicalIndex(idx);

                // Fix the exogenous dimensions at this shock combination; -1 is assumed to select
                // the whole dimension for the endogenous-state and final dimensions
                int[] selector = new int[newTransition.numberOfDimensions()];
                Arrays.fill(selector, -1);
                System.arraycopy(currentShocks, 0, selector, _config.getAggregateEndogenousStateCount(),
                        currentShocks.length);

                // Regressors: a column of ones for the constant, followed by the current states
                DoubleArray> X = DoubleArrayFactories.
                        createArrayOfSize(inputStates[idx].size()[0],1).
                        fill(1).stackFinal(inputStates[idx]);

                // (X'X)^-1, shared by the state and control regressions
                DoubleArray> predictor = X.transpose(0, 1).matrixMultiply(X).inverseMatrix();

                DoubleArray> stateEst = predictor.matrixMultiply(
                        X.transpose(0,1).matrixMultiply(resultingStates[idx]));
                LOG.log(Level.INFO, "State estimates:\n"+stateEst.toString());
                fillTransition(newTransition.at(selector),stateEst);

                if(_hasControls) {
                    DoubleArray> controlEst = predictor.matrixMultiply(
                            X.transpose(0,1).matrixMultiply(resultingControls[idx]));
                    LOG.log(Level.INFO, "Control estimates:\n"+controlEst.toString());
                    fillTransition(newControlsPolicy.at(selector),controlEst);
                }
            });

            // Expand the estimated rules onto the full aggregate variable grid. dimsToFill skips the
            // dimensions between the endogenous-state and the exogenous-state ones, which are
            // presumably the aggregate control grids; the estimated rules do not vary along them.
            DoubleArray> fullTransition =
                    _model.createAggregateVariableGrid(_config.getAggregateEndogenousStateCount());
            int[] dimsToFill = ArrayUtils.addAll(Utils.sequence(0, _config.getAggregateEndogenousStateCount()),
                    Utils.sequence(_config.getAggregateEndogenousStateCount()+_config.getAggregateControlCount(),_config.getAggregateEndogenousStateCount()+_config.getAggregateControlCount() + _config.getAggregateExogenousStateCount()+1));
            fullTransition.fillDimensions(newTransition, dimsToFill);
            if(_logs) {
                fullTransition.modifying().map(DoubleArrayFunctions.exp);
            }

            DoubleArray> fullControls;
            if(_hasControls) {
                fullControls = _model.createAggregateVariableGrid(_config.getAggregateControlCount());
                fullControls.fillDimensions(newControlsPolicy, dimsToFill);
                if(_logs) {
                    fullControls.modifying().map(DoubleArrayFunctions.exp);
                }

                // Keep each forecast control between the first and last points of its own grid;
                // controls whose grid has a single point are left as estimated
                for(int idx = 0; idx < _config.getAggregateControlCount(); idx++) {
                    if(_config.getAggregateControls().get(idx).numberOfElements() == 1) {
                        continue;
                    }
                    double min = _config.getAggregateControls().get(idx).first();
                    double max = _config.getAggregateControls().get(idx).last();
                    fullControls.lastDimSlice(idx).modifying().map(
                            DoubleArrayFunctions.cutToBounds(min, max));
                }
            } else {
                fullControls = null;
            }

            return new Pair, DoubleArray>>(fullTransition, fullControls);
        } catch (ModelException e)
        {
            throw new RuntimeException("Error updating policies", e);
        }
    }

    /**
     * Reduces the shock vector of one simulated period to the aggregate exogenous shock indexes
     * used to group observations for the regressions.
     *
     * <p>When the model has normalising states, the vector is assumed to carry additional entries
     * after the exogenous ones (TODO confirm against the simulator), so only the leading
     * {@code getAggregateExogenousStateCount()} entries are kept. That is the length required to
     * index the shock-count array, which has one dimension per aggregate exogenous state.
     *
     * @param periodShocks_ The shock indexes of a single period
     *
     * @return The exogenous shock indexes identifying the period's group
     */
    private int[] aggregateShockKey(int[] periodShocks_)
    {
        if(_config.getAggregateNormalisingStateCount() > 0) {
            return ArrayUtils.subarray(periodShocks_, 0, _config.getAggregateExogenousStateCount());
        }
        return periodShocks_;
    }

    /**
     * Fills one shock combination's slice of a forecasting rule from regression estimates.
     *
     * <p>Each entry is set to the intercept plus, for each aggregate endogenous state, that state's
     * grid value (its log, if log-linear rules are used) multiplied by the estimated coefficient.
     *
     * @param newTransition The slice to fill; presumably spans the endogenous state grids plus a
     *        final dimension over the estimated variables
     * @param estimates_ The coefficients: row 0 holds the intercepts and row {@code i+1} the
     *        coefficients on endogenous state {@code i}
     */
    private void fillTransition(DoubleArray> newTransition, DoubleArray> estimates_)
    {
        // Fill the transition with the intercepts
        newTransition.fillDimensions(estimates_.at(0), newTransition.numberOfDimensions() - 1);

        // Now add the part resulting from variation in the value of each current aggregate
        for (int currentAgg = 0; currentAgg < _config.getAggregateEndogenousStateCount();
                currentAgg++)
        {
            DoubleArray> aggContribution = createArrayOfSize(newTransition.size());

            DoubleArray> points;
            if(_logs) {
                points = _config.getAggregateEndogenousStates().get(currentAgg).map(DoubleArrayFunctions.log);
            } else {
                points = _config.getAggregateEndogenousStates().get(currentAgg);
            }

            // Spread the grid values of the current aggregate along its own dimension
            aggContribution.fillDimensions(points,currentAgg);

            // Now multiply by the gradients
            aggContribution.modifying().across(aggContribution.numberOfDimensions() - 1).
                    multiply(estimates_.at(currentAgg+1));

            newTransition.modifying().add(aggContribution);
        }
    }

    /**
     * Sets the number of periods to be simulated. Default 10000.
     *
     * @param simPeriods_ The number of periods to be simulated
     */
    protected void setSimPeriods(int simPeriods_)
    {
        _simPeriods = simPeriods_;
    }

    /**
     * Sets the number of periods to be discarded. Default 1000.
     *
     * @param discardPeriods_ The number of periods to be discarded
     */
    protected void setDiscardPeriods(int discardPeriods_)
    {
        _discardPeriods = discardPeriods_;
    }

    /**
     * Indicates whether a log-linear rule should be estimated. Defaults to
     * false.
     *
     * @param useLogs_ True for log-linear, false for linear.
     */
    protected void useLogs(boolean useLogs_)
    {
        _logs = useLogs_;
    }

    /**
     * Indicates whether the final distribution of each simulation run should
     * be used to start the next. Defaults to false.
     *
     * @param keepDist_ True causes the final distribution to be used to start the next iteration.
     */
    protected void keepDist(boolean keepDist_)
    {
        _keepDist = keepDist_;
    }

    /**
     * Indicates whether the same shock sequence should be used for each simulation run.
     * Defaults to true.
     *
     * @param reuseShocks_ True reuses shocks, false creates new ones each iteration.
     */
    protected void reuseShocks(boolean reuseShocks_)
    {
        _reuseShocks = reuseShocks_;
    }
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy