All Downloads are FREE. Search and download functionalities use the official Maven repository.

date.iterator.count.isodata.ISOData Maven / Gradle / Ivy

There is a newer version: 1.1.5
Show newest version
package date.iterator.count.isodata;

import com.saaavsaaa.client.event.EventCenter;
import date.iterator.count.event.PlotEvent;
import date.iterator.count.util.CalculationUtil;
import date.iterator.count.util.Distance;
import date.iterator.count.util.DistanceList;
import date.iterator.count.util.GlobalKeyCenter;
import org.apache.commons.lang3.tuple.Pair;

import java.util.*;

/**
 * ISODATA clustering: each pass assigns every point to its nearest center, drops tiny clusters,
 * recomputes centers, then splits wide clusters and merges clusters whose centers are close.
 * A fixed number of passes ({@code totalLoopI}) is run.
 *
 * <p>Exception handling is deliberately minimal so that errors surface to the caller.
 *
 * <p>Instances hold mutable, unsynchronized state (clusters, points, cached global statistics).
 * {@link #calculate()} may be re-run after points or clusters are added.
 */
public class ISOData {

    /** A dimension whose total over all points equals this value is dropped from the global key set. */
    // TODO(original): is a zero total really a problem when converting to vectors?
    private static final double INVALID_VALUE = 0;

    /**
     * When true, points are assigned with CalculationUtil.distanceStandardEuclidean (given the
     * global standard deviations) instead of the plain Euclidean distance.
     */
    private static final boolean USE_STANDARD_DISTANCE = true;

    /** θN: a cluster with fewer points than this is not kept as an independent cluster. */
    private static final int THETA_N = 2;
    // L (max number of center pairs merged in one pass, originally 100) is not implemented.

    // TODO(original): could ISODATA record only differences instead of every dimension to save space?
    private final int expectK;       // K: expected number of cluster centers
    private final int totalLoopI;    // number of passes run by calculate()
    private final double theta_S;    // θS: split a cluster whose largest per-dimension standard deviation exceeds this
    private final double theta_c;    // θc: merge two clusters whose centers are closer than this

    private final List<Cluster> clusters = new ArrayList<>();
    private final List<Point> points;

    private int pointSize = 0;          // number of points at the start of the current calculate() call
    private boolean debugging = false;  // nothing in this class sets it; flip it manually to enable debugShow

    private int constantLoop = 0;
    private boolean constant = true;

    // Global per-dimension means and standard deviations; computed once, on the first calculate().
    private final Map<String, Double> totalAverageValues = new HashMap<>();
    // NOTE(review): assumes CalculationUtil.calculateStandardDeviation returns Map<String, Double> — confirm.
    private Map<String, Double> totalStandardErrors = new HashMap<>();

    /**
     * Creates the clusterer and picks {@code initK} starting centers from {@code points}.
     *
     * @param expectK    expected number of clusters K
     * @param totalLoopI number of passes run by {@link #calculate()}
     * @param theta_S    split threshold on a cluster's largest per-dimension standard deviation
     * @param theta_c    merge threshold on the distance between two centers
     * @param initK      number of initial centers
     * @param points     points to cluster; copied, so later additions never modify the caller's list
     * @throws IllegalArgumentException if {@code initK > 0} and {@code points} is empty
     */
    public ISOData(final int expectK, final int totalLoopI, final double theta_S, final double theta_c,
                   final int initK, final List<Point> points) {
        this.expectK = expectK;
        this.totalLoopI = totalLoopI;
        this.theta_S = theta_S;
        this.theta_c = theta_c;

        // Defensive copy: calculateAddition* append to this list.
        this.points = new ArrayList<>(points);
        init(initK, this.points);
    }

    /**
     * Debug hook: publishes the clusters as a PlotEvent, pauses for one second, then prints every
     * cluster and the given position label. Does nothing unless {@code debugging} is true.
     */
    private void debugShow(String position) {
        if (!debugging) {
            return;
        }
        try {
            EventCenter.INSTANCE.sent(new PlotEvent(clusters));
            Thread.sleep(1000);
        } catch (InterruptedException e) {
            // Keep the interrupt visible to callers instead of swallowing it.
            Thread.currentThread().interrupt();
        }
        int assigned = 0;
        for (Cluster eachCluster : clusters) {
            assigned += eachCluster.getPoints().size();
            System.out.println(eachCluster.toString());
        }
        if (assigned > pointSize) {
            // Breakpoint anchor: more point assignments than points means a point sits in two clusters.
            System.out.println();
        }
        System.out.println(position + "-----------------------------");
    }

    /**
     * Adds existing clusters (their centers are kept, their points are moved into the point pool)
     * and re-runs the clustering.
     *
     * @param existentClusters clusters to add; each one is cleared after its points are taken
     * @return the live internal cluster list
     */
    public List<Cluster> calculateAdditionClusters(final List<Cluster> existentClusters) {
        this.clusters.addAll(existentClusters);
        for (Cluster eachCluster : existentClusters) {
            this.points.addAll(eachCluster.getPoints());
            eachCluster.clear();
        }
        return calculate();
    }

    /**
     * Adds points and re-runs the clustering.
     * Note: the cached global means / standard deviations are not recomputed for the new points.
     *
     * @param newPoints points to add
     * @return the live internal cluster list
     */
    public List<Cluster> calculateAdditionPoints(final List<Point> newPoints) {
        this.points.addAll(newPoints);
        return calculate();
    }

    /**
     * Runs {@code totalLoopI} ISODATA passes over all points.
     *
     * <p>Original use-case notes: product sub-categories; customers with high value and the
     * difference between avg_interval and the last purchase time.
     * TODO(original): accept parameters (including the initial ones) across multiple calls.
     *
     * @return the live internal cluster list
     */
    public List<Cluster> calculate() {
        // Step 1: parameters were fixed in the constructor; record the population size.
        this.pointSize = points.size();
        // Standardized distances need global statistics; cached after the first computation.
        prepareStandardData(points);

        for (int loop = 0; loop < totalLoopI; loop++) {
            // Start every pass (including the first of a repeated call) from empty clusters so a
            // point is never assigned twice; only the centers carry over.
            clusters.forEach(Cluster::clear);

            // Step 2: assign points, drop tiny clusters, recompute centers.
            initClusterDistribution(points);
            debugShow("initClusterDistribution");

            cancelTinyClusters();
            debugShow("cancelTinyClusters");

            updateClusterCenter();
            debugShow("updateClusterCenter");
            double totalAverage = calculatePointsDistance();

            // Step 3: split when there are too few clusters; merge on even passes or when there are too many.
            if (clusters.size() <= expectK / 2) {
                splitCluster(totalAverage);
                debugShow("分裂");
            }

            if (loop % 2 == 0 || clusters.size() > 2 * expectK) {
                calculateCenterDistance(loop);
                debugShow("合并");
            }
            // NOTE(review): `constant` is never reset to true, so this counts every pass after the
            // first split/merge; no early exit uses it.
            if (!constant) {
                constantLoop += 1;
            }
        }
        System.out.println("constantLoop : " + constantLoop);
        return clusters;
    }

    /*
     * Step 1: pick initK starting centers by striding through the point list (wrapping around).
     * The spacing is deliberately approximate; the first pass does not need an accurate result.
     */
    private void init(final int initK, final List<Point> points) {
        if (initK <= 0) {
            return;
        }
        if (points.isEmpty()) {
            throw new IllegalArgumentException("cannot pick " + initK + " initial centers from an empty point list");
        }
        // Roughly points-per-expected-cluster; the "- 3" is an arbitrary margin kept from the original.
        int stride = expectK > 0 ? (points.size() - 3) / expectK : 1;
        if (stride <= 0) {
            stride = 1;
        }
        for (int i = 0; i < initK; i++) {
            // Stride (not i % stride) so centers are spread out instead of repeating the first few points.
            int index = (int) (((long) i * stride) % points.size());
            clusters.add(new Cluster().setCenter(points.get(index)));
        }
    }

    /*
     * Step 2.1: assign every point to the nearest cluster center.
     */
    private void initClusterDistribution(final List<Point> toAssign) {
        initClusterDistribution(clusters, toAssign);
    }

    /**
     * Adds each point to whichever of {@code targetClusters} has the closest center (first one on ties
     * or when every distance is NaN).
     *
     * @throws IllegalStateException if there are points but no clusters to receive them
     * @throws ArithmeticException   if a distance comes out negative (overflow)
     */
    private void initClusterDistribution(final List<Cluster> targetClusters, final List<Point> toAssign) {
        if (toAssign.isEmpty()) {
            return;
        }
        if (targetClusters.isEmpty()) {
            throw new IllegalStateException("no cluster available to receive " + toAssign.size() + " point(s)");
        }
        for (Point eachPoint : toAssign) {
            Cluster nearest = targetClusters.get(0);
            double nearestDistance = Double.MAX_VALUE;
            for (Cluster eachCluster : targetClusters) {
                double currentDistance = distanceToCenter(eachPoint, eachCluster.getCenter());
                if (currentDistance < 0) {
                    throw new ArithmeticException("欧式距离值溢出:" + eachPoint.toString());
                }
                if (currentDistance < nearestDistance) {
                    nearestDistance = currentDistance;
                    nearest = eachCluster;
                }
            }
            nearest.setPoint(eachPoint);
        }
    }

    /** Distance used for assignment: standardized or plain Euclidean, per USE_STANDARD_DISTANCE. */
    private double distanceToCenter(final Point point, final Point center) {
        return USE_STANDARD_DISTANCE
                ? CalculationUtil.distanceStandardEuclidean(point, center, totalStandardErrors)
                : center.distanceEuclidean(point);
    }

    /*
     * The standardized Euclidean distance needs the global per-dimension means and standard deviations.
     * Computed once; dimensions whose total is INVALID_VALUE are removed from GlobalKeyCenter.
     */
    private void prepareStandardData(final List<Point> points) {
        if (!USE_STANDARD_DISTANCE || !totalAverageValues.isEmpty()) {
            return;
        }
        // NOTE(review): despite its name, the division below implies this returns per-dimension totals — confirm.
        Map<String, Double> pointTotalValues = CalculationUtil.calculateTotalAverages(points);
        for (Map.Entry<String, Double> entry : pointTotalValues.entrySet()) {
            double total = entry.getValue();
            if (total == INVALID_VALUE) {
                GlobalKeyCenter.INSTANCE.deleteKey(entry.getKey());
                continue;
            }
            totalAverageValues.put(entry.getKey(), total / points.size());
        }
        totalStandardErrors = CalculationUtil.calculateStandardDeviation(points, totalAverageValues);
    }

    /*
     * Step 2.2: remove clusters with fewer than THETA_N points and re-home their points:
     * a single orphan joins the last remaining cluster; otherwise the orphans form a new cluster.
     */
    private void cancelTinyClusters() {
        List<Point> orphans = new ArrayList<>();
        List<Cluster> tinyClusters = new ArrayList<>();
        for (Cluster eachCluster : clusters) {
            if (eachCluster.getPoints().size() < THETA_N) {
                // Keep the points (previously only done for empty clusters, so single points were lost).
                orphans.addAll(eachCluster.getPoints());
                eachCluster.clear();
                tinyClusters.add(eachCluster);
            }
        }
        if (tinyClusters.isEmpty()) {
            return;
        }
        clusters.removeAll(tinyClusters);
        if (orphans.isEmpty()) {
            return;
        }
        if (orphans.size() < THETA_N && !clusters.isEmpty()) {
            clusters.get(clusters.size() - 1).setPoints(orphans);
        } else {
            clusters.add(new Cluster().setCenter(orphans.get(0)).setPoints(orphans));
        }
    }

    /*
     * Step 2.3: recompute each cluster's center from its points.
     */
    private void updateClusterCenter() {
        clusters.forEach(Cluster::updateCenterValue);
    }

    /*
     * Step 2.4: mean over clusters of each cluster's average point-to-center distance
     * (0 when there are no clusters).
     */
    private double calculatePointsDistance() {
        if (clusters.isEmpty()) {
            return 0D;
        }
        double eachDistanceTotal = 0D;
        for (Cluster eachCluster : clusters) {
            eachDistanceTotal += eachCluster.getAverageDistance();
        }
        return eachDistanceTotal / clusters.size();
    }

    /**
     * Mean of every distance produced by Point.calculateCenterDistances over all clustered points.
     * Currently unused ("kept for later" in the original); NaN when there are no distances.
     */
    double calculateAllAverage() {
        List<Double> allDistances = new ArrayList<>();
        for (Cluster eachCluster : clusters) {
            for (Point eachPoint : eachCluster.getPoints()) {
                allDistances.addAll(eachPoint.calculateCenterDistances(clusters));
            }
        }
        double allDistanceValue = 0;
        for (double eachDistance : allDistances) {
            allDistanceValue += eachDistance;
        }
        return allDistanceValue / allDistances.size();
    }

    /*
     * Step 3.1: split every cluster whose largest per-dimension standard deviation exceeds θS
     * (and that is wider than average). Its two replacements are placed either side of the old center
     * along that dimension. Returns true if any cluster was split.
     */
    private boolean splitCluster(final double totalAverage) {
        List<Cluster> replaced = new ArrayList<>();
        List<Cluster> created = new ArrayList<>();
        for (Cluster eachCluster : clusters) {
            // NOTE(review): assumes maxStandardDeviation() yields (largest deviation, axis key) as
            // Pair<Double, String> — confirm against Cluster.
            Pair<Double, String> top = eachCluster.maxStandardDeviation();
            System.out.println("StandardDeviation : " + top.getLeft());
            // calculate() only calls this when clusters.size() <= expectK / 2, so the second disjunct
            // is true unless the two are equal. TODO(original): decide whether to drop it.
            if (top.getLeft() > theta_S
                    && (eachCluster.getAverageDistance() > totalAverage || clusters.size() < expectK / 2)) {
                // Split factor fixed at 0.5 for now. TODO(original): split along the line between the
                // largest and smallest components instead of using a hyper-parameter.
                double offset = 0.5 * top.getLeft();
                Point center = eachCluster.getCenter();
                List<Cluster> halves = new ArrayList<>();
                halves.add(new Cluster().setCenter(center.Copy().displaceAxisOpposite(top.getRight(), offset)));
                halves.add(new Cluster().setCenter(center.Copy().displaceAxisOpposite(top.getRight(), -offset)));
                initClusterDistribution(halves, eachCluster.getPoints());
                replaced.add(eachCluster);
                created.addAll(halves);
                constant = false;
            }
        }
        clusters.removeAll(replaced);
        clusters.addAll(created);
        return !replaced.isEmpty();
    }

    /*
     * Step 3.2: merge clusters whose centers are closer than θc.
     * Every close pair (i, j), i < j, goes into a DistanceList, which presumably (per the original
     * notes) resolves pairs sharing a cluster by keeping the shorter or earlier one. For each pair,
     * cluster i's center becomes the point-count-weighted mean of both centers. Points are moved and
     * cluster j removed only on the final pass; earlier passes re-assign all points anyway.
     * TODO(original): the final pass only merges when it is even or there are too many clusters.
     */
    private void calculateCenterDistance(final int loop) {
        DistanceList centerDistances = new DistanceList();
        for (int i = 0; i < clusters.size(); i++) {
            Point center1 = clusters.get(i).getCenter();
            for (int j = i + 1; j < clusters.size(); j++) {
                double currentDistance = center1.distanceEuclidean(clusters.get(j).getCenter());
                System.out.println("loop:" + loop + " currentDistance : " + currentDistance);
                if (currentDistance < theta_c) {
                    centerDistances.add(i, j, currentDistance);
                }
            }
        }

        // calculate() runs loop = 0 .. totalLoopI - 1, so this is the final pass.
        boolean lastLoop = loop == totalLoopI - 1;
        // Absorbed indices are removed afterwards, highest first, so the indices held by the
        // remaining pairs stay valid while merging.
        NavigableSet<Integer> absorbed = new TreeSet<>();
        for (Distance each : centerDistances.getCenterDistances()) {
            Cluster target = clusters.get(each.getKeys()[0]);  // earlier index
            int sourceIndex = each.getKeys()[1];               // later index
            Cluster source = clusters.get(sourceIndex);
            target.setCenter(mergedCenter(target, source));
            if (lastLoop) {
                target.setPoints(source.getPoints());
                source.clear();
                absorbed.add(sourceIndex);
            }
            constant = false;
        }
        for (int index : absorbed.descendingSet()) {
            clusters.remove(index);
        }
    }

    /**
     * Mean of the two clusters' centers over the global keys, weighted by each cluster's point count
     * (equal weights when both are empty).
     */
    private Point mergedCenter(final Cluster first, final Cluster second) {
        Point center1 = first.getCenter();
        Point center2 = second.getCenter();
        // ISODATA weights by cluster size; weighting by the number of dimensions (as before) is
        // normally identical for both centers and degenerates into a plain average.
        double weight1 = first.getPoints().size();
        double weight2 = second.getPoints().size();
        if (weight1 + weight2 == 0) {
            weight1 = 1;
            weight2 = 1;
        }
        Point merged = new Point();
        for (String eachProperty : GlobalKeyCenter.INSTANCE.getKeys()) {
            double p = center1.getValues().get(eachProperty) * weight1;
            double s = center2.getValues().get(eachProperty) * weight2;
            merged.getValues().put(eachProperty, (p + s) / (weight1 + weight2));
        }
        return merged;
    }

    /** Returns the live internal cluster list (not a copy). */
    public List<Cluster> getClusters() {
        return clusters;
    }
}





© 2015 - 2024 Weber Informatics LLC | Privacy Policy