com.etsy.conjecture.topics.lda.LDADoc

package com.etsy.conjecture.topics.lda;

import com.etsy.conjecture.Utilities;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;

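/**
 * A bag-of-words document for LDA inference. Stores the distinct word
 * indices and counts of a single document, and fits the document's topic
 * proportions (variational Dirichlet parameters, gamma) against a fixed
 * topic model via mean-field variational inference. The fitted phi values
 * can then be exported as partial topic models for aggregation across
 * documents.
 */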
public class LDADoc implements Serializable {

    private static final long serialVersionUID = 1536967875771864807L;
    double[] topic_proportions; // variational Dirichlet parameters (gamma), one per topic
    double total_words; // total token count of the document
    int[] word_idx; // dictionary indices of the distinct words, sorted ascending
    double[] word_count; // word counts, parallel to word_idx
    double[][] phi; // per-word topic responsibilities, [word][topic]
    boolean phi_dirty; // true until updateTopicProportions() has run

    public LDADoc(Map<String, Double> word_counts, LDADict dict) {
        total_words = 0.0;
        word_idx = new int[word_counts.size()];
        word_count = new double[word_counts.size()];
        phi_dirty = true;
        int i = 0;
        for (Map.Entry<String, Double> e : word_counts.entrySet()) {
            word_idx[i++] = dict.index(e.getKey());
            total_words += e.getValue();
        }
        // Keep the parallel arrays sorted by word index, for easier
        // aggregation of partial topic models.
        Arrays.sort(word_idx);
        for (int w = 0; w < word_idx.length; w++) {
            word_count[w] = word_counts.get(dict.word(word_idx[w]));
        }
    }

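    /**
     * Returns the fitted variational Dirichlet parameters (gamma), or null
     * if updateTopicProportions() has not been called yet. Note these are
     * unnormalized: they sum to total_words + K * alpha, so normalize them
     * if actual proportions are needed.
     */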
    public double[] topicProportions() {
        return topic_proportions;
    }

    public double wordCount() {
        return total_words;
    }

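    /**
     * Fits this document's topic proportions to a fixed topic model using
     * the standard mean-field variational updates for LDA:
     *
     *   phi[w][k]  proportional to  exp(digamma(gamma[k])) * beta[k][w]
     *   gamma[k]   =  alpha + sum over w of word_count[w] * phi[w][k]
     *
     * where beta[k][w] = topics.wordProb(word_idx[w], k). The phi update is
     * computed in log space and normalized with log-sum-exp for numerical
     * stability; in the code, phi[w][k] is stored already scaled by the
     * word count, so the gamma update adds it directly.
     *
     * @param topics the fixed word-topic probabilities (beta)
     * @param alpha  the symmetric Dirichlet prior on topic proportions
     */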
    public void updateTopicProportions(LDATopics topics, double alpha) {
        int K = topics.numTopics();
        // Reuse the old topic proportions unless the topic model has changed.
        if (topic_proportions == null || topic_proportions.length != K) {
            topic_proportions = new double[K];
            for (int k = 0; k < K; k++) {
                topic_proportions[k] = total_words / (double)K;
            }
        }
        // Guard phi.length == 0 so empty documents don't index phi[0].
        if (phi == null || phi.length == 0 || phi[0].length != K) {
            phi = new double[word_idx.length][K];
        }
        // Alternate the phi and gamma updates until convergence.
        double[] topic_proportions_new = new double[K];
        double[] phi_z = new double[word_idx.length];
        while (true) {
            // Compute phi.
            for (int k = 0; k < K; k++) {
                double digamma_k = LDAUtils.digamma(topic_proportions[k]);
                for (int w = 0; w < word_idx.length; w++) {
                    double wp = Math.log(topics.wordProb(word_idx[w], k));
                    phi[w][k] = digamma_k + wp;
                    if (k == 0) {
                        phi_z[w] = phi[w][k];
                    } else {
                        phi_z[w] = LDAUtils.logSumExp(phi_z[w], phi[w][k]);
                    }
                }
            }
            // Compute updated gamma.
            double conv = 0.0;
            for (int k = 0; k < K; k++) {
                topic_proportions_new[k] = alpha;
                for (int w = 0; w < word_idx.length; w++) {
                    phi[w][k] = Math.exp(phi[w][k] - phi_z[w]) * word_count[w];
                    topic_proportions_new[k] += phi[w][k];
                }
                double diff = topic_proportions[k] - topic_proportions_new[k];
                topic_proportions[k] = topic_proportions_new[k];
                conv += diff * diff;
            }
            // Stop once the squared change in the gamma vector is below threshold.
            if (conv < 1000.0) {
                break;
            }
        }
        phi_dirty = false;
    }

    // Only call this after calling updateTopicProportions().
    public LDAPartialTopics toPartialTopics() throws Exception {
        if (phi_dirty) {
            throw new Exception(
                    "Called toPartialTopics() on a doc that hasn't been updated");
        }
        return new LDAPartialTopics(word_idx, phi);
    }

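    /**
     * Like toPartialTopics(), but keeps only the phi column for the given
     * topic, yielding a single-topic partial model.
     */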
    public LDAPartialTopics toPartialTopic(int topic) throws Exception {
        if (phi_dirty) {
            throw new Exception(
                    "Called toPartialTopic() on a doc that hasn't been updated");
        }
        double[][] phi_k = new double[word_idx.length][1]; // one column: just the requested topic
        for (int i = 0; i < word_idx.length; i++) {
            phi_k[i][0] = phi[i][topic];
        }
        return new LDAPartialTopics(word_idx, phi_k);
    }

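    /**
     * Builds a sparse partial topic model that keeps, for each word, only
     * its n strongest topics, renormalizing the kept weights so the word's
     * total mass is preserved. Entries are keyed by word_index * K + topic.
     */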
    public LDAPartialSparseTopics toPartialSparseTopics(int n) throws Exception {
        if (phi_dirty) {
            throw new Exception(
                    "Called toPartialSparseTopics() on a doc that hasn't been updated");
        }
        int K = topic_proportions.length;
        Map<Integer, Double> partial_phi = new HashMap<Integer, Double>();
        Map<Integer, Double> word_topic_prob = new HashMap<Integer, Double>();
        for (int w = 0; w < word_idx.length; w++) {
            word_topic_prob.clear();
            for (int k = 0; k < K; k++) {
                word_topic_prob.put(k, phi[w][k]);
            }
            ArrayList<Integer> sorted_topics = Utilities.orderKeysByValue(
                    word_topic_prob, true);
            // Guard against n exceeding the number of topics.
            int top = Math.min(n, sorted_topics.size());
            double z = 0.0;
            for (int i = 0; i < top; i++) {
                z += phi[w][sorted_topics.get(i)];
            }
            word_topic_prob.clear();
            for (int i = 0; i < top; i++) {
                int k = sorted_topics.get(i);
                int v = word_idx[w];
                partial_phi.put(v * K + k, (phi[w][k] / z) * word_count[w]);
            }
        }
        return new LDAPartialSparseTopics(K, partial_phi);
    }

}
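
For orientation, a minimal usage sketch follows. It is not part of the library source: it assumes an LDADict and an LDATopics constructed elsewhere in the package, uses only the LDADoc methods shown above, and the alpha value is an arbitrary illustrative choice.

import com.etsy.conjecture.topics.lda.LDADict;
import com.etsy.conjecture.topics.lda.LDADoc;
import com.etsy.conjecture.topics.lda.LDATopics;

import java.util.Map;

public class LDADocSketch {

    // Infer a document's (unnormalized) topic weights against a fixed topic
    // model. The dict and topics arguments are assumed to be built elsewhere.
    public static double[] inferTopicWeights(Map<String, Double> wordCounts,
            LDADict dict, LDATopics topics) {
        LDADoc doc = new LDADoc(wordCounts, dict);
        doc.updateTopicProportions(topics, 0.1); // 0.1: illustrative symmetric alpha
        return doc.topicProportions();
    }
}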