org.apache.mahout.clustering.OnlineGaussianAccumulator Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of mahout-mr Show documentation
Show all versions of mahout-mr Show documentation
Scalable machine learning libraries
/**
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.mahout.clustering;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.function.SquareRootFunction;
/**
 * Accumulates the weighted mean and variance of a stream of vectors in a single pass, using the
 * weighted incremental form of Welford's algorithm (as presented by Knuth), which is numerically
 * stable. See http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance
 *
 * <p>All statistics are brought up to date by every call to {@link #observe(Vector, double)}, so
 * {@link #compute()} has no work left to do.
 */
public class OnlineGaussianAccumulator implements GaussianAccumulator {

  /** Total weight of all observations folded in so far. */
  private double sumWeight;
  /** Running weighted mean; {@code null} until the first observation. */
  private Vector mean;
  /** Running weighted sum of squared deviations from the mean; {@code null} until the first observation. */
  private Vector s;
  /** {@code s / (sumWeight - 1)}, refreshed after every observation; {@code null} until the first one. */
  private Vector variance;

  /** Returns the total weight observed so far. */
  @Override
  public double getN() {
    return sumWeight;
  }

  /** Returns the running mean vector itself (not a copy), or {@code null} before any observation. */
  @Override
  public Vector getMean() {
    return mean;
  }

  /**
   * Returns the element-wise square root of the current variance, computed on a clone so the stored
   * variance is left untouched. Because the variance is {@code null} until something has been
   * observed, calling this first results in a {@link NullPointerException}.
   */
  @Override
  public Vector getStd() {
    Vector copy = variance.clone();
    return copy.assign(new SquareRootFunction());
  }

  /**
   * Folds one weighted observation into the running statistics.
   *
   * <p>Reference pseudocode (weighted incremental algorithm):
   * <pre>
   * temp = weight + sumweight
   * Q = x - mean
   * R = Q * weight / temp
   * S = S + sumweight * Q * R
   * mean = mean + R
   * sumweight = temp
   * Variance = S / (sumweight - 1)   # if sample is the population, omit -1
   * </pre>
   *
   * @param x the observed vector
   * @param weight the weight of this observation
   */
  @Override
  public void observe(Vector x, double weight) {
    double updatedWeight = weight + sumWeight;

    // Deviation of x from the mean as it stood *before* this observation.
    Vector delta;
    if (mean != null) {
      delta = x.minus(mean);
    } else {
      // First observation: the mean starts out as a zero vector shaped like x,
      // so the deviation is simply a copy of x.
      mean = x.like();
      delta = x.clone();
    }

    // How far this observation pulls the mean.
    Vector shift = delta.times(weight).divide(updatedWeight);
    // Contribution to the squared-deviation sum; deliberately uses the total weight
    // from *before* this observation, as the algorithm requires.
    Vector increment = delta.times(sumWeight).times(shift);
    s = (s == null) ? increment : s.plus(increment);

    mean = mean.plus(shift);
    sumWeight = updatedWeight;
    // Bessel-corrected (sample) estimate; divide by sumWeight alone for the population variance.
    // NOTE(review): when the total weight is <= 1 this divisor is zero or negative, so the
    // result may be infinite, NaN or negative.
    variance = s.divide(sumWeight - 1);
  }

  /** No-op: {@link #observe(Vector, double)} already keeps every statistic current. */
  @Override
  public void compute() {
    // Intentionally empty: there is nothing deferred to finish here.
  }

  /**
   * Returns the mean of the per-component standard deviations, or {@code 0.0} when the total
   * observed weight is zero.
   */
  @Override
  public double getAverageStd() {
    if (sumWeight == 0.0) {
      return 0.0;
    }
    Vector std = getStd();
    return std.zSum() / std.size();
  }

  /** Returns the current variance vector itself (not a copy), or {@code null} before any observation. */
  @Override
  public Vector getVariance() {
    return variance;
  }
}
© 2015 - 2024 Weber Informatics LLC | Privacy Policy