com.enterprisemath.math.statistics.DiagonalNormalDistributionMixtureEstimator Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of em-math Show documentation
Advanced mathematical algorithms.
The newest version!
package com.enterprisemath.math.statistics;
import org.apache.commons.lang3.builder.ToStringBuilder;
import com.enterprisemath.math.algebra.Hypercube;
import com.enterprisemath.math.algebra.Vector;
import com.enterprisemath.math.probability.DiagonalNormalDistribution;
import com.enterprisemath.math.probability.DiagonalNormalDistributionMixture;
import com.enterprisemath.math.statistics.observation.ObservationIterator;
import com.enterprisemath.math.statistics.observation.ObservationProvider;
import com.enterprisemath.utils.ValidationUtils;
/**
 * This class is responsible for estimating the diagonal normal distribution mixture.
 * Purpose of this class is to let the user set high level limits and let the algorithm do the rest.
 * The estimation starts with one component. In every iteration one component is selected and split
 * into two, and then the classic EM algorithm is invoked on the enlarged mixture.
 *
 * @author radek.hecl
 *
 */
public class DiagonalNormalDistributionMixtureEstimator implements Estimator {

    /**
     * Constant for truncating components.
     * When ln(w_j) + ln(p_j(x)) minus the per-observation maximum is not above this value,
     * its exponential is at most about Double.MIN_VALUE, so the component is treated as exactly 0.
     */
    private static final double COMP_TRUNC = Math.log(Double.MIN_VALUE);

    /**
     * Builder object.
     */
    public static class Builder {

        /**
         * Maximum allowed number of components.
         */
        private Integer maxComponents = 20;

        /**
         * The minimum weight which is allowed for a component during estimation.
         * Every component which has less than the minWeight will be removed.
         * This essentially means that a group of observations statistically less important than minWeight
         * might (but not necessarily will) be ignored.
         * Value must be in interval [0, 1).
         */
        private Double minWeight = 0.05;

        /**
         * Minimum allowed sigma.
         */
        private Double minSigma = 0.01;

        /**
         * Step listener.
         */
        private EstimatorStepListener stepListener = EmptyEstimatorStepListener.create();

        /**
         * Sets maximum allowed number of components.
         *
         * @param maxComponents maximum of number of components in the result
         * @return this instance
         */
        public Builder setMaxComponents(int maxComponents) {
            this.maxComponents = maxComponents;
            return this;
        }

        /**
         * Sets minimal weight for the components.
         * Every component with weight less than minWeight will be removed from the result.
         * This essentially means that a group of observations statistically less important than minWeight
         * might (but not necessarily will) be ignored.
         * Value must be in interval [0, 1).
         *
         * @param minWeight component minimal weight value
         * @return this instance
         */
        public Builder setMinWeight(double minWeight) {
            this.minWeight = minWeight;
            return this;
        }

        /**
         * Sets minimal allowed sigma for all components.
         *
         * @param minSigma minimal allowed sigma for all components
         * @return this instance
         */
        public Builder setMinSigma(Double minSigma) {
            this.minSigma = minSigma;
            return this;
        }

        /**
         * Sets step listener.
         *
         * @param stepListener step listener
         * @return this instance
         */
        public Builder setStepListener(EstimatorStepListener stepListener) {
            this.stepListener = stepListener;
            return this;
        }

        /**
         * Builds the result object.
         *
         * @return created object
         */
        public DiagonalNormalDistributionMixtureEstimator build() {
            return new DiagonalNormalDistributionMixtureEstimator(this);
        }
    }

    /**
     * Maximum allowed number of components.
     */
    private Integer maxComponents;

    /**
     * The minimum weight which is allowed for a component during estimation.
     * Every component which has less than the minWeight will be removed.
     * This essentially means that a group of observations statistically less important than minWeight
     * might (but not necessarily will) be ignored.
     * Value must be in interval [0, 1).
     */
    private Double minWeight;

    /**
     * Minimum allowed sigma.
     */
    private Double minSigma;

    /**
     * Step listener.
     */
    private EstimatorStepListener stepListener;

    /**
     * Creates new instance.
     *
     * @param builder builder object
     */
    public DiagonalNormalDistributionMixtureEstimator(Builder builder) {
        minWeight = builder.minWeight;
        maxComponents = builder.maxComponents;
        minSigma = builder.minSigma;
        stepListener = builder.stepListener;
        guardInvariants();
    }

    /**
     * Guards this object to be consistent. Throws exception if this is not the case.
     */
    private void guardInvariants() {
        ValidationUtils.guardPositiveInt(maxComponents, "maxComponents must be positive");
        ValidationUtils.guardNotNegativeDouble(minWeight, "minWeight cannot be negative");
        ValidationUtils.guardGreaterDouble(1, minWeight, "minWeight must be less than 1");
        ValidationUtils.guardPositiveDouble(minSigma, "minSigma must be positive");
        ValidationUtils.guardNotNull(stepListener, "stepListener cannot be null");
    }

    /**
     * Estimates the mixture for the given observations.
     * Observations with any component outside [-1000000, 1000000] are rejected through ValidationUtils.
     * Starting from a single component, the mixture is repeatedly split and refined by EM until ln(L)
     * improves by no more than 0.01, the number of components reaches maxComponents, or 100 splits were made.
     * The step listener is notified with every accepted intermediate mixture and with the final result.
     *
     * @param observations observations
     * @return estimated mixture
     */
    @Override
    public DiagonalNormalDistributionMixture estimate(ObservationProvider observations) {
        Hypercube minMax = extractHypercube(observations);
        for (int i = 0; i < minMax.getDimension(); ++i) {
            ValidationUtils.guardGreaterOrEqualDouble(minMax.getMin().getComponent(i), -1000000,
                    "observation is out of range for calcualtion");
            ValidationUtils.guardGreaterOrEqualDouble(1000000, minMax.getMax().getComponent(i),
                    "observation is out of range for calcualtion");
        }
        DiagonalNormalDistributionMixture res = initializeOneComponent(observations, minMax.getDimension());
        stepListener.stepDone(res);
        double resL = Double.NEGATIVE_INFINITY;
        double newL = countLnL(observations, res);
        int iteration = 0;
        while (newL - resL > 0.01 && res.getNumComponents() < maxComponents && iteration < 100) {
            ++iteration;
            resL = newL;
            DiagonalNormalDistributionMixture newMixture = splitComponent(res);
            // EM iterations for the new mixture: at least 5, then until ln(L) stops improving by more than 0.01
            double help = Double.NEGATIVE_INFINITY;
            int emIteration = 0;
            while (emIteration < 5 || newL - help > 0.01) {
                ++emIteration;
                help = newL;
                newMixture = nextIteration(observations, newMixture);
                newL = countLnL(observations, newMixture);
                if (getMinWeight(newMixture) < minWeight) {
                    newMixture = newMixture.createSignificantComponentMixture(minWeight);
                    newL = countLnL(observations, newMixture);
                }
                // accept the new mixture whenever it beats the likelihood before this split
                if (newL > resL) {
                    res = newMixture;
                    stepListener.stepDone(res);
                }
            }
        }
        stepListener.stepDone(res);
        return res;
    }

    /**
     * Creates a new mixture by splitting one component of the given mixture into two.
     * The component and dimension with the largest weight * sigma are chosen. Both halves get half of the
     * original weight, keep the original sigma, and have their means shifted by +sigma/2 and -sigma/2
     * in the chosen dimension. Other components are copied, except those with weight below minWeight.
     *
     * @param mixture mixture to split
     * @return new mixture with the split component
     */
    private DiagonalNormalDistributionMixture splitComponent(DiagonalNormalDistributionMixture mixture) {
        // find the component and dimension with the maximal weight * sigma
        double splitValue = 0;
        int splitCompIdx = 0;
        int splitDimIdx = 0;
        for (int i = 0; i < mixture.getNumComponents(); ++i) {
            DiagonalNormalDistribution comp = mixture.getComponents().get(i);
            double weight = mixture.getWeights().get(i);
            for (int j = 0; j < comp.getDimension(); ++j) {
                double value = weight * comp.getSigma().getComponent(j);
                if (value > splitValue) {
                    splitValue = value;
                    splitCompIdx = i;
                    splitDimIdx = j;
                }
            }
        }
        // build the mixture with that component replaced by its two halves
        DiagonalNormalDistributionMixture.Builder builder = new DiagonalNormalDistributionMixture.Builder();
        for (int i = 0; i < mixture.getNumComponents(); ++i) {
            DiagonalNormalDistribution comp = mixture.getComponents().get(i);
            double weight = mixture.getWeights().get(i);
            if (i == splitCompIdx) {
                double[] mi1 = new double[mixture.getDimension()];
                double[] mi2 = new double[mixture.getDimension()];
                for (int j = 0; j < comp.getDimension(); ++j) {
                    if (j == splitDimIdx) {
                        mi1[j] = comp.getMi().getComponent(j) + comp.getSigma().getComponent(j) / 2;
                        mi2[j] = comp.getMi().getComponent(j) - comp.getSigma().getComponent(j) / 2;
                    }
                    else {
                        mi1[j] = comp.getMi().getComponent(j);
                        mi2[j] = comp.getMi().getComponent(j);
                    }
                }
                builder.addComponent(weight / 2, Vector.create(mi1), comp.getSigma());
                builder.addComponent(weight / 2, Vector.create(mi2), comp.getSigma());
            }
            else if (weight >= minWeight) {
                builder.addComponent(weight, comp);
            }
        }
        return builder.build();
    }

    /**
     * Extracts hypercube from the specified observations.
     *
     * @param observations observations
     * @return extracted interval
     */
    private Hypercube extractHypercube(ObservationProvider observations) {
        ObservationIterator iterator = observations.getIterator();
        Hypercube.Builder res = new Hypercube.Builder();
        while (iterator.isNextAvailable()) {
            res.addVector(iterator.getNext());
        }
        return res.build();
    }

    /**
     * Makes the one component initialization for the EM algorithm.
     * The component mean is the sample mean and its sigma is the sample standard deviation
     * (per dimension), bounded from below by minSigma.
     *
     * @param observations observation for which the initialization should be calculated
     * @param dimension dimension of the observations
     * @return mixture with initial parameters
     */
    private DiagonalNormalDistributionMixture initializeOneComponent(ObservationProvider observations, int dimension) {
        // accumulate sums of x and x^2
        double[] m1 = new double[dimension];
        double[] m2 = new double[dimension];
        ObservationIterator iterator = observations.getIterator();
        while (iterator.isNextAvailable()) {
            Vector x = iterator.getNext();
            for (int i = 0; i < dimension; ++i) {
                m1[i] += x.getComponent(i);
                m2[i] += x.getComponent(i) * x.getComponent(i);
            }
        }
        for (int i = 0; i < dimension; ++i) {
            m1[i] /= iterator.getNumIterated();
            m2[i] /= iterator.getNumIterated();
            m2[i] -= m1[i] * m1[i];
            // E[x^2] - E[x]^2 can round slightly below zero for (near) constant data; clamp it,
            // otherwise sqrt yields NaN and Math.max(minSigma, NaN) stays NaN
            m2[i] = Math.max(minSigma, Math.sqrt(Math.max(0, m2[i])));
        }
        // determine the components
        return new DiagonalNormalDistributionMixture.Builder().
                addComponent(1, DiagonalNormalDistribution.create(Vector.create(m1), Vector.create(m2))).
                build();
    }

    /**
     * Makes one iteration of the EM algorithm. Returns the mixture after the iteration.
     * A component which received no responsibility from any observation keeps its previous
     * mean and sigma (with weight 0) instead of getting NaN parameters from a 0/0 division.
     *
     * @param observations observation for which the iteration should be calculated
     * @param start starting position
     * @return mixture after the iteration
     */
    private DiagonalNormalDistributionMixture nextIteration(ObservationProvider observations, DiagonalNormalDistributionMixture start) {
        int numComponents = start.getNumComponents();
        int dim = start.getDimension();
        double[] lnWeights = lnWeights(start);
        DiagonalNormalDistribution[] components = componentArray(start);
        double[] newW = new double[numComponents];
        double[][] newMi = new double[numComponents][dim];
        // holds weighted sums of x^2 until the finishing step converts them to sigma
        double[][] newSigma = new double[numComponents][dim];
        double[] c = new double[numComponents];
        // E-step and accumulation of the weighted sums
        ObservationIterator iterator = observations.getIterator();
        while (iterator.isNextAvailable()) {
            Vector x = iterator.getNext();
            evaluateShiftedTerms(lnWeights, components, x, c);
            double h = sum(c);
            for (int j = 0; j < numComponents; ++j) {
                if (c[j] == 0) {
                    continue;
                }
                double q = c[j] / h;
                newW[j] += q;
                double[] mi = newMi[j];
                double[] sigma = newSigma[j];
                for (int k = 0; k < dim; ++k) {
                    double xk = x.getComponent(k);
                    mi[k] += xk * q;
                    sigma[k] += xk * xk * q;
                }
            }
        }
        // M-step: turn the sums into weights, means and sigmas
        DiagonalNormalDistributionMixture.Builder builder = new DiagonalNormalDistributionMixture.Builder();
        for (int i = 0; i < numComponents; ++i) {
            double weight = newW[i] / iterator.getNumIterated();
            if (newW[i] == 0) {
                builder.addComponent(weight, components[i]);
                continue;
            }
            double[] mi = newMi[i];
            double[] sigma = newSigma[i];
            for (int j = 0; j < dim; ++j) {
                mi[j] /= newW[i];
                double variance = sigma[j] / newW[i] - mi[j] * mi[j];
                sigma[j] = Math.max(minSigma, Math.sqrt(Math.max(0, variance)));
                if (Double.isNaN(sigma[j])) {
                    sigma[j] = minSigma;
                }
            }
            builder.addComponent(weight, new DiagonalNormalDistribution.Builder().
                    setMi(Vector.create(mi)).
                    setSigma(Vector.create(sigma)).
                    build());
        }
        return builder.build();
    }

    /**
     * Calculates the ln(L) value.
     * Where L = prod_x( sum_i(w_i P_i(x)) ) and ln(L) = sum_x( ln(sum_i(w_i P_i(x))) ).
     * The inner sum is evaluated with the maximum term factored out to avoid underflow.
     *
     * @param observations observations
     * @param mixture mixture
     * @return ln(L) value
     */
    private double countLnL(ObservationProvider observations, DiagonalNormalDistributionMixture mixture) {
        double[] lnWeights = lnWeights(mixture);
        DiagonalNormalDistribution[] components = componentArray(mixture);
        double[] c = new double[components.length];
        double lnL = 0;
        ObservationIterator iterator = observations.getIterator();
        while (iterator.isNextAvailable()) {
            double c0 = evaluateShiftedTerms(lnWeights, components, iterator.getNext(), c);
            lnL += Math.log(sum(c)) + c0;
        }
        // intentionally not divided by the number of observations
        return lnL;
    }

    /**
     * Evaluates the shifted component terms of a mixture for one observation.
     * For every component j computes t_j = lnWeights[j] + ln(p_j(x)), takes c0 = max_j t_j and stores
     * exp(t_j - c0) into c[j], or 0 when t_j - c0 is not above COMP_TRUNC.
     *
     * @param lnWeights natural logarithms of the component weights
     * @param components mixture components
     * @param x observation
     * @param c output array, one slot per component
     * @return the shift c0 (stays -Double.MAX_VALUE if no term exceeds it)
     */
    private static double evaluateShiftedTerms(double[] lnWeights, DiagonalNormalDistribution[] components,
            Vector x, double[] c) {
        double c0 = -Double.MAX_VALUE;
        for (int j = 0; j < components.length; ++j) {
            c[j] = lnWeights[j] + components[j].getLnValue(x);
            if (c[j] > c0) {
                c0 = c[j];
            }
        }
        for (int j = 0; j < components.length; ++j) {
            double shifted = c[j] - c0;
            c[j] = shifted > COMP_TRUNC ? Math.exp(shifted) : 0;
        }
        return c0;
    }

    /**
     * Returns the natural logarithm of every component weight of the mixture.
     *
     * @param mixture mixture
     * @return logarithms of weights, in component order
     */
    private static double[] lnWeights(DiagonalNormalDistributionMixture mixture) {
        double[] res = new double[mixture.getNumComponents()];
        for (int j = 0; j < res.length; ++j) {
            res[j] = Math.log(mixture.getWeights().get(j));
        }
        return res;
    }

    /**
     * Returns the components of the mixture as an array.
     *
     * @param mixture mixture
     * @return components, in component order
     */
    private static DiagonalNormalDistribution[] componentArray(DiagonalNormalDistributionMixture mixture) {
        DiagonalNormalDistribution[] res = new DiagonalNormalDistribution[mixture.getNumComponents()];
        for (int j = 0; j < res.length; ++j) {
            res[j] = mixture.getComponents().get(j);
        }
        return res;
    }

    /**
     * Returns the sum of the given values, added in index order.
     *
     * @param values values
     * @return sum
     */
    private static double sum(double[] values) {
        double res = 0;
        for (double v : values) {
            res += v;
        }
        return res;
    }

    /**
     * Returns minimal weight.
     *
     * @param mixture mixture
     * @return minimal weight (1 if no component has a smaller weight)
     */
    private double getMinWeight(DiagonalNormalDistributionMixture mixture) {
        double res = 1;
        for (double w : mixture.getWeights()) {
            if (w < res) {
                res = w;
            }
        }
        return res;
    }

    @Override
    public String toString() {
        return ToStringBuilder.reflectionToString(this);
    }
}