All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.mayabot.nlp.algorithm.TopMaxK Maven / Gradle / Ivy

package com.mayabot.nlp.algorithm;

import com.mayabot.nlp.common.Pair;

import java.lang.reflect.Array;
import java.util.ArrayList;

import static java.lang.Math.min;

/**
 * 求最大Top K
 * 内部是小顶堆
 *
 * @author jimichan
 */
public class TopMaxK {
    private final Class clazz;
    private int k = 10;


    private float[] heap;
    private T[] idIndex;

    int size = 0;

    public TopMaxK(int k, Class clazz) {
        this.k = k;
        heap = new float[k];
        this.clazz = clazz;
        idIndex = (T[]) Array.newInstance(clazz, k);

    }


    public void push(T id, float score) {
        if (size < k) {
            heap[size] = score;
            idIndex[size] = id;
            size++;

            if (size == k) {
                buildMinHeap();
            }
        } else {
            // 如果这个数据大于最下值,那么有资格进入
            if (score > heap[0]) {
                heap[0] = score;
                idIndex[0] = id;
                mintopify(0);
            }
        }
    }

    public ArrayList> result() {
        int top = min(k, size);
        ArrayList> list = new ArrayList<>(top);
        for (int i = 0; i < top; i++) {
            list.add(new Pair<>(idIndex[i], heap[i]));
        }

        list.sort((a, b) -> -1 * Float.compare(a.second, b.second));
        return list;
    }

    private void buildMinHeap() {
        for (int i = k / 2 - 1; i >= 0; i--) {
            // 依次向上将当前子树最大堆化
            mintopify(i);
        }
    }

    /**
     * 让heap数组符合堆特性
     *
     * @param i
     */
    private void mintopify(int i) {
        int l = 2 * i + 1;
        int r = 2 * i + 2;
        int min = 0;

        if (l < k && heap[l] < heap[i]) {
            min = l;
        } else {
            min = i;
        }

        if (r < k && heap[r] < heap[min]) {
            min = r;
        }

        if (min == i || min >= k) {
            // 如果largest等于i说明i是最大元素
            // largest超出heap范围说明不存在比i节点大的子女
            return;
        }

        swap(i, min);
        mintopify(min);
    }

    private void swap(int i, int j) {
        float tmp = heap[i];
        heap[i] = heap[j];
        heap[j] = tmp;

        T tmp2 = idIndex[i];
        idIndex[i] = idIndex[j];
        idIndex[j] = tmp2;
    }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy