// date.iterator.count.isodata.Cluster (Maven / Gradle / Ivy)
package date.iterator.count.isodata;
import date.iterator.count.util.GlobalKeyCenter;
import org.apache.commons.lang3.tuple.ImmutablePair;
import org.apache.commons.lang3.tuple.Pair;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collector;
import java.util.stream.Collectors;
/**
 * A cluster: a center point, the member points assigned to it, and a running
 * per-key sum of the members' values.
 *
 * <p>The per-key sums are maintained incrementally as points are added through
 * {@link #setPoint(Point)} / {@link #setPoints(List)}, so the centroid and the
 * per-key standard deviations can be derived without re-summing the members.
 * Points added by mutating the list returned from {@link #getPoints()} directly
 * are NOT reflected in those sums.
 *
 * <p>NOTE(review): the key set comes from {@code GlobalKeyCenter.INSTANCE.getKeys()};
 * this class assumes every point's {@code getValues()} map holds a non-null
 * {@code Double} for each of those keys - confirm against {@code Point}.
 */
public class Cluster {
    private Point center;
    /** Center as it was before the latest {@link #updateCenterValue()}; stored but not read in this class. */
    private Point oldCenter;
    private double averageDistance;
    /** Per-key standard deviations from the latest {@link #maxStandardDeviation()}; stored but not read in this class. */
    private List<Double> squaredErrors;
    private final List<Point> points = new ArrayList<>();
    /** Sum of the member points' values, one entry per key. */
    private final Map<String, Double> pointTotalValues = new HashMap<>();

    /**
     * Removes all member points and their accumulated per-key sums.
     * The center and the last computed average distance are left untouched.
     */
    public void clear() {
        points.clear();
        pointTotalValues.clear();
        squaredErrors = null;
    }

    /**
     * Finds the key (component) whose values have the largest population
     * standard deviation across the member points.
     *
     * @return a pair of (largest standard deviation, key it belongs to);
     *         {@code (0.0, "")} when the cluster is empty or no key has a
     *         deviation greater than zero
     */
    public Pair<Double, String> maxStandardDeviation() {
        double maxDeviation = 0;
        String maxKey = "";
        squaredErrors = new ArrayList<>();

        final int size = points.size();
        if (size == 0) {
            // Nothing to measure; avoids 0/0 below.
            return new ImmutablePair<>(maxDeviation, maxKey);
        }

        for (String key : GlobalKeyCenter.INSTANCE.getKeys()) {
            // The mean is the same for every point, so compute it once per key.
            final double mean = pointTotalValues.get(key) / size;
            double sumOfSquares = 0;
            for (Point point : points) {
                final double diff = point.getValues().get(key) - mean;
                sumOfSquares += diff * diff;
            }
            final double deviation = Math.sqrt(sumOfSquares / size);
            squaredErrors.add(deviation);
            if (deviation > maxDeviation) {
                maxDeviation = deviation;
                maxKey = key;
            }
        }
        return new ImmutablePair<>(maxDeviation, maxKey);
    }

    /**
     * Recomputes the average distance from the member points to the current
     * center, using {@code Point.distanceEuclidean}. An empty cluster gets 0.
     */
    private void calculateAverageDistance() {
        if (points.isEmpty()) {
            averageDistance = 0;
            return;
        }
        double totalDistance = 0;
        for (Point point : points) {
            totalDistance += point.distanceEuclidean(center);
        }
        averageDistance = totalDistance / points.size();
    }

    /**
     * Moves the center to the per-key mean of the member points and refreshes
     * the average distance. The new center is a freshly created point and need
     * not coincide with any member.
     *
     * <p>For an empty cluster there is no mean to compute: the current center
     * is kept and the average distance becomes 0.
     */
    public void updateCenterValue() {
        oldCenter = center;
        if (points.isEmpty()) {
            averageDistance = 0;
            return;
        }

        final int size = points.size();
        final Point mean = new Point();
        for (String key : GlobalKeyCenter.INSTANCE.getKeys()) {
            final double total = pointTotalValues.get(key);
            mean.getValues().put(key, total / size);
        }
        this.center = mean;
        calculateAverageDistance();
    }

    /**
     * Sets the center without touching the member points or the average distance.
     *
     * @param center the new center
     * @return this cluster, for chaining
     */
    public Cluster setCenter(final Point center) {
        this.center = center;
        return this;
    }

    /** @return the current center; {@code null} if none has been set or computed yet */
    public Point getCenter() {
        return center;
    }

    /**
     * Adds one point to the cluster and folds its values into the per-key sums.
     *
     * @param point the point to add
     * @return this cluster, for chaining
     */
    public Cluster setPoint(final Point point) {
        this.points.add(point);
        countPointValue(point);
        return this;
    }

    /**
     * Adds all given points to the cluster and folds their values into the
     * per-key sums. Existing members are kept.
     *
     * @param points the points to add
     * @return this cluster, for chaining
     */
    public Cluster setPoints(final List<Point> points) {
        this.points.addAll(points);
        points.forEach(this::countPointValue);
        return this;
    }

    /** @return the live, mutable list of member points (not a copy) */
    public List<Point> getPoints() {
        return points;
    }

    /** @return the average distance computed by the latest {@link #updateCenterValue()} */
    public double getAverageDistance() {
        return averageDistance;
    }

    /**
     * Adds the point's value for every known key to the running sums. Done at
     * insertion time so a cluster whose membership has not changed can reuse
     * the sums.
     */
    private void countPointValue(Point point) {
        final Map<String, Double> values = point.getValues();
        for (String key : GlobalKeyCenter.INSTANCE.getKeys()) {
            pointTotalValues.merge(key, values.get(key), Double::sum);
        }
    }

    @Override
    public String toString() {
        // Concatenation prints "null" for an unset center instead of throwing.
        return "point size:" + points.size() + ", center:" + center;
    }

    /** @return the live, mutable map of per-key value sums (not a copy) */
    public Map<String, Double> getPointTotalValues() {
        return pointTotalValues;
    }
}