All downloads are free. The search and download functionality uses the official Maven repository.

com.meliorbis.economics.aggregate.AggregateSolverBase Maven / Gradle / Ivy

package com.meliorbis.economics.aggregate;

import com.meliorbis.economics.infrastructure.SolverBase;
import com.meliorbis.economics.infrastructure.notifications.ArrayObserver;
import com.meliorbis.economics.infrastructure.notifications.Notifier;
import com.meliorbis.economics.model.Model;
import com.meliorbis.economics.model.ModelConfig;
import com.meliorbis.economics.model.State;
import com.meliorbis.economics.model.StateWithControls;
import com.meliorbis.numerics.generic.primitives.DoubleArray;
import com.meliorbis.numerics.generic.primitives.DoubleBinaryOp;
import com.meliorbis.numerics.generic.primitives.impl.DoubleArrayFunctions;
import com.meliorbis.utils.Pair;

/**
 * Base class for aggregate solvers that handles some of the boilerplate for updating the aggregate transition, leaving the
 * primary calculation to subclasses
 * 
 * @author Tobias Grasl
 *
 * @param  The configuration class in use
 * @param  The state class in use
 * @param  The model class in use
 */
public abstract class AggregateSolverBase, M extends Model> extends
		SolverBase implements AggregateProblemSolver
{
	private double _newWeight = 1d;
	private boolean _constrainToGrid = true;
	private Notifier _transitionNotifier = new Notifier();
	
	public AggregateSolverBase(M model_, C config_)
	{
		super(model_, config_);
	}
	
	@Override
	public void initialise(S state_)
	{
		// Nothing to do by default
		state_.setAggregateError(Double.POSITIVE_INFINITY);
	}

	/**
	 * Register a listener that is notified once the aggregate transition has been updated
	 * 
	 * @param listener_ The listener to register
	 */
	public void addTransitionListener(ArrayObserver listener_)
	{
		_transitionNotifier.registerListener(listener_);
	}

	/**
	 * Updates the aggregate transition and control policy based on the provided state. The actual calculation is
	 * delegated to the subclass, but this method will apply the {@code this.newWeight} set on the instance and also constrain the
	 * transition to the grid if {@code this.constrainToGrid} is true. Listeners are notified after the state has been updated.
	 * 
	 * @param state_ The current calculation state, which will be updated with the newly calculated values before listeners are notified
	 */
	@SuppressWarnings("unchecked")
	@Override
	final public void updateAggregateTransition(S state_)
	{
		// get a copy of the old one
		DoubleArray oldTransition = state_.getAggregateTransition();
				
		// Derive the new transition
		Pair, DoubleArray> results = calculateAggregatePolicies(state_);
		
		DoubleArray newTransition = results.getLeft();
		
		// if constrainToGrid is true, make sure all the values are within range of the aggregate
		if(_constrainToGrid) 
		{
			for(int i = 0; i < _config.getAggregateEndogenousStateCount(); i++) 
			{
				newTransition.lastDimSlice(i).modifying().map(
						DoubleArrayFunctions.cutToBounds(
								_config.getAggregateEndogenousStates().get(i).min(), 
								_config.getAggregateEndogenousStates().get(i).max()));
			}
		}
		
		updateError(oldTransition, newTransition, state_);
		
		final DoubleBinaryOp weightedMean = (DoubleBinaryOp)(newVal, old)-> newVal *_newWeight + old*(1d-_newWeight);
		newTransition.modifying().with(oldTransition).map(weightedMean);
		
		// Update the state
		state_.setAggregateTransition(newTransition);
		
		
		if(results.getRight() != null) {
			DoubleArray newPolicy = results.getRight();
			DoubleArray oldPolicy = ((StateWithControls)state_).getCurrentControlsPolicy();
			
			newPolicy.modifying().with(oldPolicy).map(weightedMean);
			
			((StateWithControls)state_).setCurrentControlsPolicy(newPolicy);
		}
		
		// call the hook for post-processing
		_transitionNotifier.changed(oldTransition, newTransition, state_);
	}



	/**
	 * Updates the aggregate solution error on the state given the old and new transition.
	 * 
	 * @param oldTransition_ The transition as it was prior to this iteration
	 * @param newTransition_ The updated transition after this iteration
	 * @param state_ The calculation state
	 */
	protected void updateError(DoubleArray oldTransition_, DoubleArray newTransition_, S state_)
	{
		state_.setAggregateError(DoubleArrayFunctions.maximumRelativeDifference(newTransition_, oldTransition_));
	}

	/**
	 * The newWeight determines the damping factor applied to aggregate transition updates. The transition is updated as a weighted mean
	 * of the pior and newly derived one, using the formula (1-newWeight)*prior + newWeight*new
	 * 
	 * @param newWeight_ The weight of the newly derived rule in aggregate transition update
	 */
	final protected void setNewWeight(double newWeight_)
	{
		_newWeight = newWeight_;
	}

	/**
	 * Causes the solver to truncate values in the aggregate transition calculated to be on the grid of the aggregates if true
	 * 
	 * @param constrainToGrid_ Indicates whether or not to truncate
	 */
	final protected void setConstrainToGrid(boolean constrainToGrid_)
	{
		_constrainToGrid = constrainToGrid_;
	}

	/**
	 * This method calculates the new aggregate transition and control policy and returns them as the left and right
	 * members of the pair respectively.
	 * 
	 * @param state_ The current calculation state
	 * 
	 * @return A pair of arrays where the left member is the new aggregate state transition and the right member the 
	 * current controls policy, which may be null.
	 */
	protected abstract Pair, DoubleArray> calculateAggregatePolicies(final S state_);
}