All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.deeplearning4j.models.word2vec.Huffman Maven / Gradle / Ivy

/*******************************************************************************
 * Copyright (c) 2015-2018 Skymind, Inc.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

package org.deeplearning4j.models.word2vec;

import org.deeplearning4j.models.sequencevectors.sequence.SequenceElement;
import org.deeplearning4j.models.word2vec.wordstore.VocabCache;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.*;


/**
 * Huffman tree builder
 * @author Adam Gibson
 *
 */
public class Huffman {

    public final int MAX_CODE_LENGTH;
    private volatile boolean buildTrigger = false;

    private Logger logger = LoggerFactory.getLogger(Huffman.class);

    public Huffman(Collection words) {
        this(words, 40);
    }

    /**
     * Builds Huffman tree for collection of SequenceElements, with defined CODE_LENGTH
     * Default CODE_LENGTH is 40
     *
     * @param words
     * @param CODE_LENGTH CODE_LENGTH defines maximum length of code path, and effectively limits vocabulary size.
     */
    public Huffman(Collection words, int CODE_LENGTH) {
        this.MAX_CODE_LENGTH = CODE_LENGTH;
        this.words = new ArrayList<>(words);
        Collections.sort(this.words, new Comparator() {
            @Override
            public int compare(SequenceElement o1, SequenceElement o2) {
                return Double.compare(o2.getElementFrequency(), o1.getElementFrequency());
            }

        });
    }

    private List words;

    public void build() {
        buildTrigger = true;
        long[] count = new long[words.size() * 2 + 1];
        byte[] binary = new byte[words.size() * 2 + 1];
        byte[] code = new byte[MAX_CODE_LENGTH];
        int[] point = new int[MAX_CODE_LENGTH];
        int[] parentNode = new int[words.size() * 2 + 1];
        int a = 0;

        while (a < words.size()) {
            count[a] = (long) words.get(a).getElementFrequency();
            a++;
        }

        a = words.size();

        while (a < words.size() * 2) {
            count[a] = Integer.MAX_VALUE;
            a++;
        }

        int pos1 = words.size() - 1;
        int pos2 = words.size();

        int min1i;
        int min2i;

        a = 0;
        // Following algorithm constructs the Huffman tree by adding one node at a time
        for (a = 0; a < words.size() - 1; a++) {
            // First, find two smallest nodes 'min1, min2'
            if (pos1 >= 0) {
                if (count[pos1] < count[pos2]) {
                    min1i = pos1;
                    pos1--;
                } else {
                    min1i = pos2;
                    pos2++;
                }
            } else {
                min1i = pos2;
                pos2++;
            }
            if (pos1 >= 0) {
                if (count[pos1] < count[pos2]) {
                    min2i = pos1;
                    pos1--;
                } else {
                    min2i = pos2;
                    pos2++;
                }
            } else {
                min2i = pos2;
                pos2++;
            }

            count[words.size() + a] = count[min1i] + count[min2i];
            parentNode[min1i] = words.size() + a;
            parentNode[min2i] = words.size() + a;
            binary[min2i] = 1;
        }
        // Now assign binary code to each vocabulary word
        int i;
        int b;
        // Now assign binary code to each vocabulary word
        for (a = 0; a < words.size(); a++) {
            b = a;
            i = 0;
            do {
                code[i] = binary[b];
                point[i] = b;
                i++;
                b = parentNode[b];

            } while (b != words.size() * 2 - 2 && i < 39);


            words.get(a).setCodeLength((short) i);
            words.get(a).getPoints().add(words.size() - 2);

            for (b = 0; b < i; b++) {
                try {
                    words.get(a).getCodes().set(i - b - 1, code[b]);
                    words.get(a).getPoints().set(i - b, point[b] - words.size());
                } catch (Exception e) {
                    logger.info("Words size: [" + words.size() + "], a: [" + a + "], b: [" + b + "], i: [" + i
                                    + "], points size: [" + words.get(a).getPoints().size() + "]");
                    throw new RuntimeException(e);
                }
            }

        }


    }

    /**
     * This method updates VocabCache and all it's elements with Huffman indexes
     * Please note: it should be the same VocabCache as was used for Huffman tree initialization
     *
     * @param cache VocabCache to be updated.
     */
    public void applyIndexes(VocabCache cache) {
        if (!buildTrigger)
            build();

        for (int a = 0; a < words.size(); a++) {
            if (words.get(a).getLabel() != null) {
                cache.addWordToIndex(a, words.get(a).getLabel());
            } else {
                cache.addWordToIndex(a, words.get(a).getStorageId());
            }

            words.get(a).setIndex(a);
        }
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy