com.enterprisemath.math.statistics.NormalDistributionMixtureEstimator Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of em-math Show documentation
Show all versions of em-math Show documentation
Advanced mathematical algorithms.
The newest version!
package com.enterprisemath.math.statistics;
import org.apache.commons.lang3.builder.ToStringBuilder;
import com.enterprisemath.math.algebra.Interval;
import com.enterprisemath.math.probability.NormalDistribution;
import com.enterprisemath.math.probability.NormalDistributionMixture;
import com.enterprisemath.math.statistics.observation.ObservationIterator;
import com.enterprisemath.math.statistics.observation.ObservationProvider;
import com.enterprisemath.utils.ValidationUtils;
/**
 * Estimates a {@link NormalDistributionMixture} from one dimensional observations.
 * The user only sets high level limits (maximum number of components, minimum component
 * weight, minimum sigma) and the algorithm does the rest.
 * <p>
 * The estimation starts with a single component fitted to all observations. In every outer
 * iteration one component is selected and split into two, and then the classic EM algorithm
 * is run on the split mixture. The outer loop stops when the log-likelihood no longer
 * improves enough, when the maximum number of components is reached, or after a fixed
 * number of split rounds.
 *
 * @author radek.hecl
 *
 */
public class NormalDistributionMixtureEstimator implements Estimator {

    /**
     * Largest absolute observation value accepted by {@link #estimate(ObservationProvider)}.
     */
    private static final int OBSERVATION_LIMIT = 1000000;

    /**
     * Minimum log-likelihood gain which is considered an improvement, both for the
     * outer split loop and for the inner EM loop.
     */
    private static final double LN_L_TOLERANCE = 0.01;

    /**
     * Maximum number of split rounds performed by the outer loop.
     */
    private static final int MAX_SPLIT_ITERATIONS = 100;

    /**
     * Number of EM iterations which are always performed after a split,
     * regardless of the log-likelihood gain.
     */
    private static final int MIN_EM_ITERATIONS = 5;

    /**
     * Builder object.
     */
    public static class Builder {

        /**
         * Maximum allowed number of components.
         */
        private Integer maxComponents = 20;

        /**
         * The minimum weight which is allowed for a component during estimation.
         * Every component which has less than the minWeight will be removed.
         * This essentially means that a group of observations statistically less important
         * than minWeight might be (but not necessarily is) ignored.
         * Value must be in interval [0, 1).
         */
        private Double minWeight = 0.05;

        /**
         * Minimum allowed sigma.
         */
        private Double minSigma = 0.01;

        /**
         * Step listener.
         */
        private EstimatorStepListener stepListener = EmptyEstimatorStepListener.create();

        /**
         * Sets maximum allowed number of components.
         *
         * @param maxComponents maximum number of components in the result, must be positive
         * @return this instance
         */
        public Builder setMaxComponents(int maxComponents) {
            this.maxComponents = maxComponents;
            return this;
        }

        /**
         * Sets minimal weight for the components.
         * Every component with weight less than minWeight will be removed from the result.
         * This essentially means that a group of observations statistically less important
         * than minWeight might be (but not necessarily is) ignored.
         * Value must be in interval [0, 1).
         *
         * @param minWeight component minimal weight value
         * @return this instance
         */
        public Builder setMinWeight(double minWeight) {
            this.minWeight = minWeight;
            return this;
        }

        /**
         * Sets minimal allowed sigma for all components.
         *
         * @param minSigma minimal allowed sigma for all components, must be positive
         * @return this instance
         */
        public Builder setMinSigma(Double minSigma) {
            this.minSigma = minSigma;
            return this;
        }

        /**
         * Sets step listener.
         *
         * @param stepListener step listener, cannot be null
         * @return this instance
         */
        public Builder setStepListener(EstimatorStepListener stepListener) {
            this.stepListener = stepListener;
            return this;
        }

        /**
         * Builds the result object. The configured values are validated by the
         * estimator constructor.
         *
         * @return created object
         */
        public NormalDistributionMixtureEstimator build() {
            return new NormalDistributionMixtureEstimator(this);
        }
    }

    /**
     * Maximum allowed number of components.
     */
    private final Integer maxComponents;

    /**
     * The minimum weight which is allowed for a component during estimation.
     * Every component which has less than the minWeight will be removed.
     * Value must be in interval [0, 1).
     */
    private final Double minWeight;

    /**
     * Minimum allowed sigma. Every estimated sigma is clamped to at least this value.
     */
    private final Double minSigma;

    /**
     * Step listener. Receives every accepted intermediate mixture and the final result.
     */
    private final EstimatorStepListener stepListener;

    /**
     * Creates new instance.
     *
     * @param builder builder object
     */
    public NormalDistributionMixtureEstimator(Builder builder) {
        minWeight = builder.minWeight;
        maxComponents = builder.maxComponents;
        minSigma = builder.minSigma;
        stepListener = builder.stepListener;
        guardInvariants();
    }

    /**
     * Guards this object to be consistent. Throws exception if this is not the case.
     */
    private void guardInvariants() {
        ValidationUtils.guardPositiveInt(maxComponents, "maxComponents must be positive");
        ValidationUtils.guardNotNegativeDouble(minWeight, "minWeight cannot be negative");
        ValidationUtils.guardGreaterDouble(1, minWeight, "minWeight must be less than 1");
        ValidationUtils.guardPositiveDouble(minSigma, "minSigma must be positive");
        ValidationUtils.guardNotNull(stepListener, "stepListener cannot be null");
    }

    /**
     * Estimates the mixture for the given observations.
     * The observation interval is first checked against the supported range
     * (plus/minus {@value #OBSERVATION_LIMIT}). Then the split-and-EM procedure
     * described in the class documentation is executed. The step listener is notified about
     * the initial mixture, about every accepted intermediate mixture and about the final result.
     *
     * @param observations observations for which the mixture is estimated
     * @return estimated mixture
     */
    @Override
    public NormalDistributionMixture estimate(ObservationProvider observations) {
        Interval minMax = extractInterval(observations);
        ValidationUtils.guardGreaterOrEqualDouble(minMax.getMin(), -OBSERVATION_LIMIT,
                "observation interval is out of range for calcualtion");
        ValidationUtils.guardGreaterOrEqualDouble(OBSERVATION_LIMIT, minMax.getMax(),
                "observation interval is out of range for calcualtion");

        NormalDistributionMixture result = initializeOneComponent(observations);
        stepListener.stepDone(result);

        double resultLnL = Double.NEGATIVE_INFINITY;
        double candidateLnL = countLnL(observations, result);
        int iteration = 0;
        while (candidateLnL - resultLnL > LN_L_TOLERANCE
                && result.getNumComponents() < maxComponents
                && iteration < MAX_SPLIT_ITERATIONS) {
            ++iteration;
            resultLnL = candidateLnL;

            NormalDistributionMixture candidate = splitComponent(result, findSplitIndex(result));

            // EM iterations for the split mixture; a fixed minimum number is always executed,
            // afterwards the loop continues only while the log-likelihood keeps improving
            double previousLnL = Double.NEGATIVE_INFINITY;
            int emIteration = 0;
            while (emIteration < MIN_EM_ITERATIONS || candidateLnL - previousLnL > LN_L_TOLERANCE) {
                ++emIteration;
                previousLnL = candidateLnL;
                candidate = nextIteration(observations, candidate);
                candidateLnL = countLnL(observations, candidate);
                if (getMinWeight(candidate) < minWeight) {
                    // drop insignificant components and re-evaluate the pruned mixture
                    candidate = candidate.createSignificantComponentMixture(minWeight);
                    candidateLnL = countLnL(observations, candidate);
                }
                // accept the candidate whenever it beats the mixture from before the split
                if (candidateLnL > resultLnL) {
                    result = candidate;
                    stepListener.stepDone(result);
                }
            }
        }
        stepListener.stepDone(result);
        return result;
    }

    /**
     * Selects the component which will be split. The selected component is the one
     * with the highest product of weight and sigma. Index 0 is returned if no product is positive.
     *
     * @param mixture mixture whose component should be selected
     * @return index of the component to split
     */
    private int findSplitIndex(NormalDistributionMixture mixture) {
        double bestScore = 0;
        int bestIdx = 0;
        for (int i = 0; i < mixture.getNumComponents(); ++i) {
            double score = mixture.getWeights().get(i) * mixture.getComponents().get(i).getSigma();
            if (score > bestScore) {
                bestScore = score;
                bestIdx = i;
            }
        }
        return bestIdx;
    }

    /**
     * Creates a mixture where the component at the given index is replaced by two components.
     * Both halves get half of the original weight, the original sigma and the mean shifted
     * by half a sigma up and down respectively. Any other component is copied only if its
     * weight is at least minWeight.
     *
     * @param mixture source mixture
     * @param splitIdx index of the component to split
     * @return mixture with the split component
     */
    private NormalDistributionMixture splitComponent(NormalDistributionMixture mixture, int splitIdx) {
        NormalDistributionMixture.Builder builder = new NormalDistributionMixture.Builder();
        for (int i = 0; i < mixture.getNumComponents(); ++i) {
            double weight = mixture.getWeights().get(i);
            NormalDistribution comp = mixture.getComponents().get(i);
            if (i == splitIdx) {
                double shift = comp.getSigma() / 2;
                builder.addComponent(weight / 2, comp.getMi() + shift, comp.getSigma());
                builder.addComponent(weight / 2, comp.getMi() - shift, comp.getSigma());
            }
            else if (weight >= minWeight) {
                builder.addComponent(weight, comp);
            }
        }
        return builder.build();
    }

    /**
     * Extracts interval from the specified observations.
     *
     * @param observations observations
     * @return interval built from all observation points
     */
    private Interval extractInterval(ObservationProvider observations) {
        Interval.Builder interval = new Interval.Builder();
        ObservationIterator iterator = observations.getIterator();
        while (iterator.isNextAvailable()) {
            interval.addPoint(iterator.getNext());
        }
        return interval.build();
    }

    /**
     * Makes the one component initialization for the EM algorithm.
     * The component has weight 1, mean equal to the sample mean and sigma equal to the
     * sample standard deviation, clamped to at least minSigma.
     *
     * @param observations observation for which the initialization should be calculated
     * @return mixture with initial parameters
     */
    private NormalDistributionMixture initializeOneComponent(ObservationProvider observations) {
        // accumulate the first and second raw moments
        double sum = 0;
        double sumOfSquares = 0;
        ObservationIterator iterator = observations.getIterator();
        while (iterator.isNextAvailable()) {
            double x = iterator.getNext();
            sum += x;
            sumOfSquares += x * x;
        }
        double mean = sum / iterator.getNumIterated();
        double variance = sumOfSquares / iterator.getNumIterated() - mean * mean;
        // Floating point cancellation can make the variance slightly negative for (nearly)
        // constant observations. Without the clamp the square root would be NaN and the
        // minSigma check below would not catch it.
        double sigma = Math.sqrt(Math.max(0.0, variance));
        if (sigma < minSigma) {
            sigma = minSigma;
        }
        return new NormalDistributionMixture.Builder().
                addComponent(1, NormalDistribution.create(mean, sigma)).
                build();
    }

    /**
     * Makes one iteration of the EM algorithm. Returns the mixture after the iteration.
     * The returned mixture has the same number of components as the starting one.
     * Every sigma is clamped to at least minSigma.
     *
     * @param observations observation for which the iteration should be calculated
     * @param start starting position
     * @return mixture after the iteration
     */
    private NormalDistributionMixture nextIteration(ObservationProvider observations, NormalDistributionMixture start) {
        int numComponents = start.getNumComponents();
        // per component sums of q, x * q and x * x * q, where q is the component responsibility
        double[] sumQ = new double[numComponents];
        double[] sumQx = new double[numComponents];
        double[] sumQxx = new double[numComponents];

        ObservationIterator iterator = observations.getIterator();
        while (iterator.isNextAvailable()) {
            double x = iterator.getNext();
            // mixture value does not depend on the component, evaluate it once per observation
            double mixtureValue = start.getValue(x);
            for (int j = 0; j < numComponents; ++j) {
                double q = start.getWeights().get(j) * start.getComponents().get(j).getValue(x) / mixtureValue;
                if (Double.isNaN(q)) {
                    // e.g. 0 / 0 when the observation has zero density under the whole mixture
                    q = 0;
                }
                sumQ[j] += q;
                sumQx[j] += x * q;
                sumQxx[j] += x * x * q;
            }
        }

        NormalDistributionMixture.Builder builder = new NormalDistributionMixture.Builder();
        for (int i = 0; i < numComponents; ++i) {
            double mean = sumQx[i] / sumQ[i];
            double sigma = Math.sqrt(-mean * mean + sumQxx[i] / sumQ[i]);
            if (Double.isNaN(sigma)) {
                // negative variance caused by rounding, or component without any responsibility
                sigma = Double.MIN_VALUE;
            }
            if (Double.isNaN(mean)) {
                // component received no responsibility (0 / 0); keep its previous mean
                // so that a NaN parameter does not leak into the new mixture
                mean = start.getComponents().get(i).getMi();
            }
            double weight = sumQ[i] / iterator.getNumIterated();
            builder.addComponent(weight, new NormalDistribution.Builder().
                    setMi(mean).
                    setSigma(Math.max(minSigma, sigma)).
                    build());
        }
        return builder.build();
    }

    /**
     * Calculates the ln(L) value.
     * Where L = prod_x( sum_i(w(i|x)P(i|x)) ) and ln(L) = sum_x( ln(sum_i(w(i|x)P(i|x))) ).
     *
     * @param observations observations
     * @param mixture mixture
     * @return ln(L) value, sum of the mixture ln values over all observations
     */
    private double countLnL(ObservationProvider observations, NormalDistributionMixture mixture) {
        double lnL = 0;
        ObservationIterator iterator = observations.getIterator();
        while (iterator.isNextAvailable()) {
            lnL += mixture.getLnValue(iterator.getNext());
        }
        return lnL;
    }

    /**
     * Returns minimal weight.
     *
     * @param mixture mixture
     * @return the smallest component weight, or 1 if no weight is less than 1
     */
    private double getMinWeight(NormalDistributionMixture mixture) {
        double min = 1;
        for (double weight : mixture.getWeights()) {
            if (weight < min) {
                min = weight;
            }
        }
        return min;
    }

    @Override
    public String toString() {
        return ToStringBuilder.reflectionToString(this);
    }
}