All Downloads are FREE. Search and download functionalities are using the official Maven repository.

repicea.math.ProductFunctionWrapper Maven / Gradle / Ivy

There is a newer version: 1.4.3
Show newest version
package repicea.math;

import java.security.InvalidParameterException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import repicea.math.utility.MatrixUtility;

@SuppressWarnings("serial")
public class ProductFunctionWrapper extends AbstractMathematicalFunction {

	private final List originalFunctions;
	private final Map> parameterMap;
	private final Map> variableMap;
	
	/**
	 * Constructor.
	 * @param wrappedOriginalFunctions a series of InternalMathematicalFunctionWrapper instances 
	 */
	public ProductFunctionWrapper(InternalMathematicalFunctionWrapper... wrappedOriginalFunctions) {
		this();
		initialize(wrappedOriginalFunctions);
	}

	protected void initialize(InternalMathematicalFunctionWrapper... wrappedOriginalFunctions) {
		if (wrappedOriginalFunctions == null || wrappedOriginalFunctions.length < 2) {
			throw new InvalidParameterException("There must be at least two instances of InternalMathematicalFunctionWrapper in the arguments of the constructor!");
		}
		List newParameterIndex = new ArrayList();
		List newVariableIndex = new ArrayList();

		for (InternalMathematicalFunctionWrapper originalFunction : wrappedOriginalFunctions) {
			originalFunctions.add(originalFunction);
		}

		for (InternalMathematicalFunctionWrapper w : originalFunctions) {
			for (Integer i : w.getNewParameterIndices()) {
				if (!newParameterIndex.contains(i)) {
					newParameterIndex.add(i);
				}
			}
			for (Integer i : w.getNewVariableIndices()) {
				if (!newVariableIndex.contains(i)) {
					newVariableIndex.add(i);
				}
			}
		}
		
		Collections.sort(newParameterIndex);
		if (newParameterIndex.get(newParameterIndex.size() - 1) != newParameterIndex.size() - 1) {
			throw new InvalidParameterException("The new parameter indices are inconsistent!");
		}
		Collections.sort(newVariableIndex);
		if (newVariableIndex.get(newVariableIndex.size() - 1) != newVariableIndex.size() - 1) {
			throw new InvalidParameterException("The new variable indices are inconsistent!");
		}
		for (Integer i : newParameterIndex) {
			for (InternalMathematicalFunctionWrapper f : originalFunctions) {
				if (f.getNewParameterIndices().contains(i)) {
					if (!parameterMap.containsKey(i)) {
						parameterMap.put(i, new ArrayList());
					}
					parameterMap.get(i).add(f);
				}
			}
		}

		for (Integer i : newVariableIndex) {
			for (InternalMathematicalFunctionWrapper f : originalFunctions) {
				if (f.getNewVariableIndices().contains(i)) {
					if (!variableMap.containsKey(i)) {
						variableMap.put(i, new ArrayList());
					}
					variableMap.get(i).add(f);
				}
			}
		}
	}
	
	/**
	 * Constructor for derived classes.
	 */
	public ProductFunctionWrapper() {
		parameterMap = new HashMap>();
		variableMap = new HashMap>();
		originalFunctions = new ArrayList();
	}
	
	
	@Override
	public Double getValue() {
		double value = 1d;
		for (InternalMathematicalFunctionWrapper w : originalFunctions) {
			double wValue = w.getValue();
			if (wValue < 0 || Double.isNaN(wValue)) {
				int u = 0;
			}
			value *= wValue;
		}
		return value;
	}

	
	@Override
	public Matrix getGradient() {
		return getGradientFromTheseInternalWrapper(originalFunctions);
	}
	
	private Matrix getGradientFromTheseInternalWrapper(List wrappers) {
		Matrix gradient = new Matrix(getNumberOfParameters(), 1);
		for (InternalMathematicalFunctionWrapper w : wrappers) {
			double multiplier = 1d;
			for (InternalMathematicalFunctionWrapper w2 : wrappers) {
				if (!w2.equals(w)) {
					multiplier *= w2.getValue();
				}
			}
			Matrix wGradient = reformateGradient(w);
			gradient = gradient.add(wGradient.scalarMultiply(multiplier));
//			MatrixUtility.add(gradient, wGradient.scalarMultiply(multiplier));
		}
		return gradient;
	}

	private Matrix reformateGradient(InternalMathematicalFunctionWrapper w) {
		Matrix gradientTmp = new Matrix(getNumberOfParameters(), 1);
		Matrix wGradient = w.getGradient();
		if (wGradient != null) {	// the function has no parameter
			for (int i = 0; i < wGradient.m_iRows; i++) {
				gradientTmp.setValueAt(w.reverseParmMap.get(i), 0, wGradient.getValueAt(i, 0));
			}
		}
		return gradientTmp;
	}
	
	private SymmetricMatrix reformateHessian(InternalMathematicalFunctionWrapper w) {
		SymmetricMatrix hessianTmp = new SymmetricMatrix(getNumberOfParameters());
		SymmetricMatrix wHessian = w.getHessian();
		if (wHessian != null) {		// the function has no parameter
			for (int i = 0; i < wHessian.m_iRows; i++) {
				for (int j = i; j < wHessian.m_iRows; j++) {
					hessianTmp.setValueAt(w.reverseParmMap.get(i), w.reverseParmMap.get(j), wHessian.getValueAt(i, j));
//					if (i !=  j) {
//						hessianTmp.setValueAt(w.reverseParmMap.get(j), w.reverseParmMap.get(i), wHessian.getValueAt(j, i));
//					}
				}
			}
		}
		return hessianTmp;
	}

	
	@Override
	public SymmetricMatrix getHessian() {
		Matrix hessian = new Matrix(getNumberOfParameters(), getNumberOfParameters());
		for (InternalMathematicalFunctionWrapper w : originalFunctions) {
			List wrappersOtherThanW = getWrapperListWithoutThisOne(w);
			double multiplier = 1d;
			for (InternalMathematicalFunctionWrapper w2 : wrappersOtherThanW) {
				multiplier *= w2.getValue();
			}
			Matrix theirGradient = getGradientFromTheseInternalWrapper(wrappersOtherThanW);
			Matrix gradientPart = reformateGradient(w).multiply(theirGradient.transpose());
			SymmetricMatrix hessianPart = reformateHessian(w).scalarMultiply(multiplier);
//			MatrixUtility.add(hessian, hessianPart.add(gradientPart));
			hessian = hessian.add(hessianPart.add(gradientPart));
		}
		return SymmetricMatrix.convertToSymmetricIfPossible(hessian);
	}
	
	private List getWrapperListWithoutThisOne(InternalMathematicalFunctionWrapper w) {
		List wrappers = new ArrayList(originalFunctions);
		wrappers.remove(w);
		return wrappers;
	}

	
	
	@Override
	public final void setParameterValue(int parameterIndex, double parameterValue) {
		if (!parameterMap.containsKey(parameterIndex)) {
			throw new InvalidParameterException("The parameter index is invalid!");
		}
		for (InternalMathematicalFunctionWrapper w : parameterMap.get(parameterIndex)) {
			w.setParameterValue(parameterIndex, parameterValue);
		}
	}

	@Override
	public final double getParameterValue(int parameterIndex) {
		if (!parameterMap.containsKey(parameterIndex)) {
			throw new InvalidParameterException("The parameter index is invalid!");
		}
		return parameterMap.get(parameterIndex).get(0).getParameterValue(parameterIndex);
	}

	@Override
	public final void setVariableValue(int variableIndex, double variableValue) {
		if (!variableMap.containsKey(variableIndex)) {
			throw new InvalidParameterException("The variable index is invalid!");
		}
		for (InternalMathematicalFunctionWrapper w : variableMap.get(variableIndex)) {
			w.setVariableValue(variableIndex, variableValue);
		}
	}

	@Override
	public final double getVariableValue(int variableIndex) {
		if (!variableMap.containsKey(variableIndex)) {
			throw new InvalidParameterException("The variable index is invalid!");
		}
		return variableMap.get(variableIndex).get(0).getVariableValue(variableIndex);
	}

	@Override
	public final int getNumberOfParameters() {
		return parameterMap.size();
	}

	@Override
	public final int getNumberOfVariables() {
		return variableMap.size();
	}

	
	@Override
	public final void setVariables(Matrix xVector) {
		super.setVariables(xVector);
	}
	
	@Override
	public final void setParameters(Matrix beta) {
		super.setParameters(beta);
	}
	
	@Override
	public final Matrix getParameters() {
		return super.getParameters();
	}

	@Override
	public void setBounds(int parameterIndex, ParameterBound bound) {
		if (!parameterMap.containsKey(parameterIndex)) {
			throw new InvalidParameterException("The parameter index is not valid!");
		} else {
			for (InternalMathematicalFunctionWrapper w : parameterMap.get(parameterIndex)) {
				w.setBounds(parameterIndex, bound);
			}
		}
	}
	
	@Override
	public boolean isThisParameterValueWithinBounds(int parameterIndex, double parameterValue) {
		if (parameterMap.containsKey(parameterIndex)) {
			for (InternalMathematicalFunctionWrapper w : parameterMap.get(parameterIndex)) {
				if (!w.isThisParameterValueWithinBounds(parameterIndex, parameterValue)) {
					return false;
				}
			}
		}
		return true;
	}
	
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy