All downloads are FREE. Search and download functionality uses the official Maven repository.

com.meliorbis.economics.infrastructure.AbstractModelWithControls Maven / Gradle / Ivy

/**
 * 
 */
package com.meliorbis.economics.infrastructure;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.function.IntFunction;
import java.util.stream.IntStream;

import org.apache.commons.lang.ArrayUtils;
import org.apache.commons.lang.NotImplementedException;

import com.meliorbis.economics.infrastructure.simulation.SimState;
import com.meliorbis.economics.model.ModelConfig;
import com.meliorbis.economics.model.ModelException;
import com.meliorbis.economics.model.ModelWithControls;
import com.meliorbis.economics.model.StateWithControls;
import com.meliorbis.numerics.NumericsException;
import com.meliorbis.numerics.function.primitives.DoubleGridFunctionFactory;
import com.meliorbis.numerics.generic.MultiDimensionalArray;
import com.meliorbis.numerics.generic.impl.IntegerArray;
import com.meliorbis.numerics.generic.primitives.DoubleArray;
import com.meliorbis.numerics.generic.primitives.impl.DoubleArrayFunctions;
import com.meliorbis.numerics.io.NumericsWriter;
import com.meliorbis.utils.Timer;
import com.meliorbis.utils.Timer.Stoppable;

/**
 * Base class for models that have aggregate <em>controls</em> in addition to aggregate
 * states. In each period the control values are solved for so that the control
 * determinants, computed from the individual transition, match the configured control
 * targets.
 * <p>
 * Controls whose configured grid has more than one point are solved for by
 * root-finding (see {@link #calcControls}). Controls whose grid has a single point are
 * not root-found; their values are read from the determinants at the solved values of
 * the multi-point controls.
 *
 * @author Tobias Grasl
 * 
 * @param  The Config type
 * @param  The State type
 */
public abstract class AbstractModelWithControls> extends AbstractModel implements ModelWithControls
{
	/** Indices of the aggregate controls whose grid has more than one point; these are root-found. */
	private int[] _controlDims;
	
	/** Indices of the aggregate controls whose grid has a single point. */
	private int[] _expnDims;
	
	/**
	 * One entry per aggregate control: -1 for multi-point controls, 0 for single-point ones.
	 * Passed to {@code at(...)} to reduce the determinants to the multi-point control dimensions
	 * (presumably -1 keeps a whole dimension and 0 fixes it at its only index — confirm).
	 */
	private int[] _ctrlSelector;
	
	/** Control targets of the multi-point controls only, in {@link #_controlDims} order. */
	private DoubleArray[] _restrictedTargets;
	
	/** Grids of the multi-point controls only, in {@link #_controlDims} order. */
	private DoubleArray[] _restrictedControls;
	
	/** Generator used to collect streams of control arrays into {@code DoubleArray[]}. */
	private IntFunction[]> _arrayCreator;


	/**
	 * Initialises the model.
	 * <p>
	 * It then partitions the configured aggregate controls into multi-point controls and
	 * single-point controls. It also caches the grids and targets of the multi-point controls.
	 */
	@Override
	public void initialise()
	{
		super.initialise();
		
		List> aggregateControls = _config.getAggregateControls();
		
		final int controlCount = aggregateControls.size();
		final IntStream.Builder controlDims = IntStream.builder();
		final IntStream.Builder expnDims = IntStream.builder();
		final int[] selector = new int[controlCount];
		
		// Single pass: classify each control by the size of its grid
		for (int i = 0; i < controlCount; i++)
		{
			if (aggregateControls.get(i).numberOfElements() > 1)
			{
				controlDims.add(i);
				selector[i] = -1;
			}
			else
			{
				expnDims.add(i);
				selector[i] = 0;
			}
		}
		
		_controlDims = controlDims.build().toArray();
		_expnDims = expnDims.build().toArray();
		_ctrlSelector = selector;
		
		List> controlTargets = _config.getControlTargets();
		
		_arrayCreator = (size) -> new DoubleArray[size];
		
		_restrictedTargets = IntStream.of(_controlDims)
				.mapToObj((idx) -> controlTargets.get(idx))
				.toArray(_arrayCreator);
		
		_restrictedControls = IntStream.of(_controlDims)
				.mapToObj((idx) -> aggregateControls.get(idx))
				.toArray(_arrayCreator);
	}
	

	/**
	 * Writes the superclass's additional output, then three arrays taken from {@code state_}:
	 * the current aggregate controls policy, the individual controls policy, and the
	 * individual controls policy used for simulation.
	 *
	 * @param state_ The state whose policies are written
	 * @param writer_ The writer to write to
	 *
	 * @throws IOException If writing fails
	 */
	@Override
	public void writeAdditional(S state_, NumericsWriter writer_) throws IOException
	{
		super.writeAdditional(state_, writer_);
		
		writer_.writeArray("aggCtrlsPolicy", state_.getCurrentControlsPolicy());
		writer_.writeArray("indCtrlsPolicy", state_.getIndividualControlsPolicy());
		writer_.writeArray("indCtrlsPolicySim", state_.getIndividualControlsPolicyForSimulation());
	}


	/**
	 * Calculates the aggregate control values implied by the individual transition at the
	 * current aggregate state.
	 * <p>
	 * The control determinants are first obtained from {@code calculateControlDeterminants}.
	 * What happens next depends on their size:
	 * <ul>
	 * <li>If they contain exactly one value per aggregate control, they are returned directly
	 * as the controls.</li>
	 * <li>Otherwise the multi-point controls are found by root-finding the determinants against
	 * the configured control targets over the control grids. Any single-point controls are then
	 * read from the determinants at the solved values, and {@link #afterControlsCalculated} is
	 * invoked with the result.</li>
	 * </ul>
	 *
	 * @param simState_ The simulation state
	 * @param individualTransitionByAggregateControl_
	 *            The individual transition function, conditional on the grid
	 *            values of aggregate controls
	 * @param currentAggStates_ The current aggregate states
	 * @param priorAggShockIndices_ The prior aggregate shock indices; must be a
	 *            {@code DoubleArray} (continuous shocks) or an {@code IntegerArray} (discrete shocks)
	 * @param calcState_
	 *            The state of the calculation
	 *
	 * @return The implied value of each aggregate control, one entry per aggregate control
	 *
	 * @throws com.meliorbis.economics.model.ModelException If there is an error in the calculation
	 * 
	 * @param  The type of shock, should be Double for continuous shocks or Integer for discrete 
	 * shocks
	 */
	@Override
	final public  double[] calculateAggregateControls(SimState simState_, DoubleArray individualTransitionByAggregateControl_,
			double[] currentAggStates_, MultiDimensionalArray priorAggShockIndices_, S calcState_) throws ModelException
	{
		final Stoppable timer = new Timer().start("calculateAggregateControls");
		
		final double[] controls;
		
		// The finally block stops the timer on every exit path: normal completion,
		// the direct-return shortcut, and exceptions. The hook below stays untimed, as before.
		try
		{
			DoubleArray determinants = calculateControlDeterminants(simState_, individualTransitionByAggregateControl_, currentAggStates_, priorAggShockIndices_, calcState_);
			
			// The controls can be calculated directly: the determinants are the controls.
			// NOTE(review): afterControlsCalculated is not invoked on this path — confirm intended
			if (determinants.numberOfElements() == _config.getAggregateControlCount())
			{
				return determinants.toArray();
			}
			
			controls = _expnDims.length == 0
					? solveAllControls(determinants)
					: solveMixedControls(determinants);
		}
		finally
		{
			timer.stop();
		}
		
		afterControlsCalculated(controls, 
				currentAggStates_, 
				priorAggShockIndices_, 
				calcState_);
		
		return controls;
	}
	
	/**
	 * Solves for every aggregate control by root-finding over the full control grids.
	 * Used when every control grid has more than one point.
	 */
	private double[] solveAllControls(DoubleArray determinants_) throws ModelException
	{
		final int controlCount = _config.getAggregateControlCount();
		
		return calcControls(determinants_,
				_config.getAggregateControls().toArray(new DoubleArray[controlCount]),
				_config.getControlTargets().toArray(new DoubleArray[controlCount]));
	}
	
	/**
	 * Handles the case where some controls are multi-point and some are single-point.
	 * <p>
	 * The multi-point controls are root-found, the single-point controls are evaluated at
	 * that solution, and both are merged into one array indexed like the configured controls.
	 */
	private double[] solveMixedControls(DoubleArray determinants_) throws ModelException
	{
		// Split the determinants' last (per-control) dimension by kind of control
		DoubleArray[] ctrlSlices = IntStream.of(_controlDims).mapToObj(
				(idx) -> determinants_.lastDimSlice(idx)).toArray(_arrayCreator);
		
		DoubleArray[] expnSlices = IntStream.of(_expnDims).mapToObj(
				(idx) -> determinants_.lastDimSlice(idx)).toArray(_arrayCreator);
		
		// Re-stack each group and reduce it to the multi-point control dimensions
		DoubleArray restrictedDet = ctrlSlices[0]
				.stack(Arrays.copyOfRange(ctrlSlices, 1, ctrlSlices.length))
				.at(_ctrlSelector);
		
		DoubleArray expnDet = expnSlices[0]
				.stack(Arrays.copyOfRange(expnSlices, 1, expnSlices.length))
				.at(_ctrlSelector);
		
		double[] controlsOnly = calcControls(restrictedDet, _restrictedControls, _restrictedTargets);
		
		// Evaluate the single-point controls' determinants at the solved multi-point controls
		DoubleArray expns = new DoubleGridFunctionFactory().createFunction(
				Arrays.asList(_restrictedControls),
				expnDet).callWithDouble(controlsOnly);
		
		// Merge both groups back into configuration order (_controlDims is ascending)
		double[] controls = new double[_config.getAggregateControlCount()];
		int ctrlsIndex = 0;
		int expnIndex = 0;
		for (int index = 0; index < controls.length; index++)
		{
			if (ctrlsIndex < _controlDims.length && _controlDims[ctrlsIndex] == index)
			{
				controls[index] = controlsOnly[ctrlsIndex++];
			}
			else
			{
				controls[index] = expns.get(expnIndex++);
			}
		}
		return controls;
	}

	/**
	 * Finds the control values at which the determinants equal their targets.
	 * <p>
	 * Each target is subtracted from the corresponding part of a copy of the determinants,
	 * which turns the problem into root-finding over the given control grids.
	 *
	 * @param determinants_ The control determinants, conditional on the control grid values;
	 *            only a copy is modified
	 * @param controlGrid_ The grid of values for each control being solved for
	 * @param targets_ The target for each control being solved for
	 *
	 * @return The control values at the root
	 *
	 * @throws NumericsException If a numerical operation fails
	 */
	protected double[] calcControls(DoubleArray determinants_, DoubleArray[] controlGrid_, DoubleArray[] targets_) throws NumericsException
	{
		DoubleArray detAdjusted = determinants_.copy();
		
		for (int i = 0; i < targets_.length; i++)
		{
			// Subtract the appropriate inputs from each output to turn the problem into a rootfinding one.
			// A one-dimensional array is adjusted as a whole rather than sliced.
			(detAdjusted.numberOfDimensions() == 1 ? detAdjusted : detAdjusted.lastDimSlice(i)).modifying().across(i).subtract(targets_[i]);
		}
		
		return DoubleArrayFunctions.findRoot(detAdjusted, controlGrid_).toArray();
	}
	
	/**
	 * Hook invoked by {@link #calculateAggregateControls} after the controls have been
	 * solved for by root-finding. It is not invoked when the determinants directly yield the
	 * controls. The default implementation does nothing.
	 *
	 * @param controls_ The calculated aggregate controls
	 * @param currentAggStates_ The current aggregate states
	 * @param priorAggShockIndices_ The prior aggregate shock indices
	 * @param calcState_ The calculation state
	 *
	 * @throws ModelException If an overriding implementation fails
	 */
	protected  void afterControlsCalculated(double[] controls_, 
			double[] currentAggStates_, 
			MultiDimensionalArray priorAggShockIndices_, 
			S calcState_) throws ModelException
	{
		// Intentionally empty: extension point for subclasses
	}
	
	/**
	 * Calculates the control determinants given the simulation state.
	 * <p>
	 * Dispatches on the runtime type of {@code priorAggShockIndices_}:
	 * <ul>
	 * <li>a {@code DoubleArray} selects the continuous-shock overload;</li>
	 * <li>anything else is cast to {@code IntegerArray} and passed to the discrete-shock
	 * overload.</li>
	 * </ul>
	 * 
	 * @param simState_ The simulation state
	 * @param individualTransitionByAggregateControl_ The individual transition for that state, but conditional on the control
	 * @param currentAggStates_ The current aggregate states
	 * @param priorAggShockIndices_ The current shock indexes
	 * @param calcState_ The calculation state
	 * 
	 * @return An array which yields the control determinants conditional on assumption of each value of the controls
	 * 
	 * @throws ModelException  If there are errors performing the calculation
	 * 
	 * @param  The numeric type of shocks used in this model
	 */
	final protected   DoubleArray 
				calculateControlDeterminants(SimState simState_, 
						DoubleArray individualTransitionByAggregateControl_,
						double[] currentAggStates_, 
						MultiDimensionalArray priorAggShockIndices_, 
						S calcState_) throws ModelException
	{
		if(priorAggShockIndices_ instanceof DoubleArray) {
			return calculateControlDeterminants(simState_, individualTransitionByAggregateControl_, currentAggStates_, (DoubleArray)priorAggShockIndices_, calcState_);
		}
		return calculateControlDeterminants(simState_, individualTransitionByAggregateControl_, currentAggStates_, (IntegerArray)priorAggShockIndices_, calcState_);
	}
	
	/**
	 * Calculates the control determinants given the simulation state when the model is simulated with continuous shocks.
	 * 
	 * This implementation will throw an exception and must be overridden by models simulated with continuous shocks.
	 * 
	 * @param simState_ The simulation state
	 * @param individualTransitionByAggregateControl_ The individual transition for that state, but conditional on the control
	 * @param currentAggStates_ The current aggregate states
	 * @param priorAggShockIndices_ The current shock indexes
	 * @param calcState_ The calculation state
	 * 
	 * @return An array which yields the control determinants conditional on assumption of each value of the controls
	 * 
	 * @throws ModelException  If there are errors performing the calculation
	 */
	protected DoubleArray calculateControlDeterminants(SimState simState_, 
			DoubleArray individualTransitionByAggregateControl_,
			double[] currentAggStates_, DoubleArray priorAggShockIndices_, S calcState_) throws ModelException
	{
		throw new NotImplementedException("This method must be implemented for models that are simulated with continuous shocks.");
	}
	
	/**
	 * Calculates the control determinants given the simulation state when the model is simulated with discrete shocks.
	 * 
	 * This implementation will throw an exception and must be overridden by models simulated with discrete shocks.
	 * 
	 * @param simState_ The simulation state
	 * @param individualTransitionByAggregateControl_ The individual transition for that state, but conditional on the control
	 * @param currentAggStates_ The current aggregate states
	 * @param priorAggShockIndices_ The current shock indexes
	 * @param calcState_ The calculation state
	 * 
	 * @return An array which yields the control determinants conditional on assumption of each value of the controls
	 * 
	 * @throws ModelException  If there are errors performing the calculation
	 */
	protected DoubleArray calculateControlDeterminants(SimState simState_, 
			DoubleArray individualTransitionByAggregateControl_,
			double[] currentAggStates_, IntegerArray priorAggShockIndices_, S calcState_) throws ModelException
	{
		throw new NotImplementedException("This method must be implemented for models that are simulated with discrete shocks.");
	}


	/**
	 * Updates the expected aggregates in {@code state_}.
	 * <p>
	 * The superclass first handles the states. The expected aggregate controls are then
	 * recomputed as the conditional expectation of the current controls policy. Each control
	 * is cut to the first and last points of its configured grid.
	 *
	 * @param state_ The state to update
	 */
	@Override
	public void adjustExpectedAggregates(S state_)
	{
		// The superclass can still handle the states
		super.adjustExpectedAggregates(state_);
		
		DoubleArray currentControlsPolicy = state_.getCurrentControlsPolicy();
	
		// NOTE: this is a bit circular because it uses the old expected controls. But since
		// the controls policy does not vary in the controls dimension it is not a problem
		DoubleArray newExpectedControls = conditionalExpectation(currentControlsPolicy, state_, false);
		
		for (int i = 0; i < _config.getAggregateControlCount(); i++)
		{
			newExpectedControls.lastDimSlice(i).modifying().map(
					DoubleArrayFunctions.cutToBounds(_config.getAggregateControls().get(i).first(), _config.getAggregateControls().get(i).last()));
		}
		((StateWithControls) state_).setExpectedAggregateControls(newExpectedControls);
	}
	
	
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy