All downloads are free. Search and download functionality uses the official Maven repository.

org.wikibrain.sr.phrasesim.CosimilarityMatrix Maven / Gradle / Ivy

There is a newer version: 0.9.1
Show newest version
package org.wikibrain.sr.phrasesim;

import gnu.trove.list.TIntList;
import gnu.trove.map.TIntIntMap;
import gnu.trove.map.hash.TIntIntHashMap;
import gnu.trove.procedure.TIntProcedure;
import gnu.trove.set.TIntSet;
import gnu.trove.set.hash.TIntHashSet;
import org.wikibrain.sr.SRResult;
import org.wikibrain.sr.SRResultList;
import org.wikibrain.sr.utils.Leaderboard;

import java.io.Serializable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * A dense, symmetric cosimilarity matrix that can be expanded on demand.
 *
 * <p>Callers address entries by arbitrary "sparse" integer ids. The first time
 * a sparse id is passed to {@link #update(int, SRResultList)} it is assigned
 * the next free "dense" index, which is the row/column used in the backing
 * {@code float[][]}. When all dense slots are taken the backing arrays are
 * reallocated with extra capacity and the old contents are copied over.
 *
 * <p>The id mappings and the references to the backing arrays are only read
 * or written while holding this object's monitor. Read methods take the lock
 * with {@code synchronized (this)} blocks instead of the {@code synchronized}
 * method modifier so that the modifiers of public methods, which feed into
 * the default {@code serialVersionUID} of this {@link Serializable} class,
 * stay exactly as they were.
 *
 * @author Shilad Sen
 */
public class CosimilarityMatrix implements Serializable {
    /** Growth factor applied to the (minimum 500) current size when the matrix is full. */
    private static final double EXPANSION_FRACTION = 1.3;

    private static final Logger LOGGER = LoggerFactory.getLogger(CosimilarityMatrix.class);

    /** Square matrix of scores indexed by dense id; may have unused trailing capacity. */
    private float[][] matrix = new float[0][0];
    /** Maps a caller-supplied sparse id to its dense row/column index. */
    private TIntIntMap sparse2Dense = new TIntIntHashMap();
    /** Inverse of {@link #sparse2Dense}: dense index to sparse id. */
    private int[] dense2Sparse = new int[0];
    /** {@code completed[d]} is true once {@code update} has finished for dense index d. */
    private boolean [] completed = new boolean[0];

    /**
     * Records the similarity scores between {@code sparseId} and its neighbors.
     *
     * <p>If {@code sparseId} has not been seen before it is registered first
     * (expanding the matrix if it is full). Neighbors whose ids are not yet
     * registered are skipped. Each score is written to both
     * {@code [id][neighbor]} and {@code [neighbor][id]}, and the id is then
     * marked as completed.
     *
     * @param sparseId  id of the item whose similarities are being stored
     * @param neighbors scored neighbors of {@code sparseId}
     */
    public synchronized void update(int sparseId, SRResultList neighbors) {
        int denseId;
        if (sparse2Dense.containsKey(sparseId)) {
            denseId = sparse2Dense.get(sparseId);
        } else {
            expandIfNecessary();
            // Dense indexes are handed out consecutively, so the next free one is the current size.
            denseId = sparse2Dense.size();
            sparse2Dense.put(sparseId, denseId);
            dense2Sparse[denseId] = sparseId;
        }

        for (SRResult r : neighbors) {
            int neighborId = r.getId();
            if (!sparse2Dense.containsKey(neighborId)) {
                continue;
            }
            int denseId2 = sparse2Dense.get(neighborId);
            float score = (float) r.getScore();
            matrix[denseId][denseId2] = score;
            matrix[denseId2][denseId] = score;
        }
        completed[denseId] = true;
    }

    /**
     * Returns the matrix row for the given sparse id.
     *
     * <p>The returned array is the internal storage row, not a copy. It is
     * indexed by dense id and its length is the current capacity of the
     * matrix, which may exceed {@link #size()}.
     *
     * @param id sparse id to look up
     * @return the row for {@code id}, or {@code null} if the id is not registered
     */
    public float[] getVector(int id) {
        synchronized (this) {
            if (sparse2Dense.containsKey(id)) {
                return matrix[sparse2Dense.get(id)];
            } else {
                return null;
            }
        }
    }

    /**
     * @return the number of sparse ids registered in the matrix
     */
    public synchronized int size() {
        return sparse2Dense.size();
    }

    /**
     * Returns the highest scoring entries in the row belonging to {@code id}.
     *
     * @param id           sparse id whose row is searched
     * @param maxResults   capacity passed to the {@link Leaderboard} that collects the results
     * @param candidateIds if non-null, only these sparse ids are considered;
     *                     candidates that are not registered in the matrix are ignored
     * @return the top results with ids translated back to sparse ids; no
     *         scores are tallied if {@code id} is not registered
     */
    public SRResultList mostSimilar(int id, int maxResults, TIntSet candidateIds) {
        final int n;
        final int denseId;
        final float[][] cosims;
        final int[] denseToSparse;
        boolean denseCandidateIds[] = null;
        synchronized (this) {
            n = sparse2Dense.size();
            // get() on a missing key returns the map's no-entry value, which would
            // silently alias another row, so check membership explicitly.
            denseId = sparse2Dense.containsKey(id) ? sparse2Dense.get(id) : -1;
            // Snapshot the array references: expansion replaces them with larger copies.
            cosims = matrix;
            denseToSparse = dense2Sparse;
            if (candidateIds != null) {
                // Indexed by dense id, so it must span every registered id,
                // not merely the number of candidates.
                denseCandidateIds = new boolean[n];
                final boolean[] finalDenseCandidateIds = denseCandidateIds;
                candidateIds.forEach(new TIntProcedure() {
                    @Override
                    public boolean execute(int id2) {
                        if (sparse2Dense.containsKey(id2)) {
                            finalDenseCandidateIds[sparse2Dense.get(id2)] = true;
                        }
                        return true;
                    }
                });
            }
        }

        Leaderboard top = new Leaderboard(maxResults);
        if (denseId >= 0) {
            float[] row = cosims[denseId];
            for (int i = 0; i < n; i++) {
                if (denseCandidateIds == null || denseCandidateIds[i]) {
                    top.tallyScore(i, row[i]);
                }
            }
        }
        SRResultList results = top.getTop();
        // The leaderboard was fed dense indexes; map them back to caller-visible ids.
        for (int i = 0; i < results.numDocs(); i++) {
            results.setId(i, denseToSparse[results.getId(i)]);
        }
        return results;
    }


    /**
     * Extracts the sub-matrix of scores for the given row and column ids.
     *
     * @param rows    sparse ids for the rows of the result
     * @param columns sparse ids for the columns of the result
     * @return a {@code rows.length x columns.length} matrix; cells whose row
     *         or column id is not registered are left at {@code 0.0}
     */
    public double[][] cosimilarity(int rows[], int columns[]) {
        double cosims[][] = new double[rows.length][columns.length];
        int denseRowIds[] = new int[rows.length];
        int denseColIds[] = new int[columns.length];
        final float[][] snapshot;
        synchronized (this) {
            // Snapshot the matrix reference: expansion replaces it with a larger copy.
            snapshot = matrix;
            for (int i = 0; i < rows.length; i++) {
                int rowId = rows[i];
                denseRowIds[i] = sparse2Dense.containsKey(rowId) ? sparse2Dense.get(rowId) : -1;
            }
            for (int i = 0; i < columns.length; i++) {
                int colId = columns[i];
                denseColIds[i] = sparse2Dense.containsKey(colId) ? sparse2Dense.get(colId) : -1;
            }
        }
        for (int i = 0; i < denseRowIds.length; i++) {
            if (denseRowIds[i] < 0) {
                continue;   // unknown row id: leave the whole row at zero
            }
            float[] row = snapshot[denseRowIds[i]];
            for (int j = 0; j < denseColIds.length; j++) {
                if (denseColIds[j] >= 0) {
                    cosims[i][j] = row[denseColIds[j]];
                }
            }
        }
        return cosims;
    }


    /**
     * Grows the backing arrays when every dense slot is in use.
     *
     * <p>The new capacity is {@code max(500, size) * EXPANSION_FRACTION}.
     * Existing scores, id mappings and completion flags are copied into the
     * new arrays before the fields are switched over.
     *
     * @throws IllegalStateException if the sizes of the internal structures disagree
     */
    private synchronized void expandIfNecessary() {
        if (sparse2Dense.size() < dense2Sparse.length) {
            return;
        }
        if (dense2Sparse.length != sparse2Dense.size()) {
            throw new IllegalStateException(
                    "dense2Sparse length " + dense2Sparse.length
                    + " does not match sparse2Dense size " + sparse2Dense.size());
        }
        if (dense2Sparse.length != matrix.length) {
            throw new IllegalStateException(
                    "dense2Sparse length " + dense2Sparse.length
                    + " does not match matrix row count " + matrix.length);
        }
        if (dense2Sparse.length > 0 && dense2Sparse.length != matrix[0].length) {
            throw new IllegalStateException(
                    "dense2Sparse length " + dense2Sparse.length
                    + " does not match matrix column count " + matrix[0].length);
        }
        int oldn = sparse2Dense.size();
        final int n = (int) (Math.max(500, oldn) * EXPANSION_FRACTION);
        LOGGER.info("expanding cosimilarity matrix to length {}", n);
        float newCosims[][] = new float[n][n];
        for (int i = 0; i < oldn; i++) {
            System.arraycopy(matrix[i], 0, newCosims[i], 0, oldn);
        }
        int newDense2Sparse[] = new int[n];
        boolean newCompleted[] = new boolean[n];
        System.arraycopy(dense2Sparse, 0, newDense2Sparse, 0, oldn);
        // Copy from the OLD flags; copying the new array onto itself would reset them all.
        System.arraycopy(completed, 0, newCompleted, 0, oldn);
        matrix = newCosims;
        dense2Sparse = newDense2Sparse;
        completed = newCompleted;

        LOGGER.info("finished expanding cosimilarity matrix");
    }

    /**
     * Looks up the stored score between two sparse ids.
     *
     * @param id1 first sparse id
     * @param id2 second sparse id
     * @return the stored score, or {@code 0f} if either id is not registered
     */
    public float similarity(int id1, int id2) {
        synchronized (this) {
            if (sparse2Dense.containsKey(id1) && sparse2Dense.containsKey(id2)) {
                return matrix[sparse2Dense.get(id1)][sparse2Dense.get(id2)];
            } else {
                return 0f;
            }
        }
    }

    /**
     * Returns all ids whose "update" method has successfully completed.
     *
     * @return a newly allocated set of sparse ids
     */
    public TIntSet getCompleted() {
        TIntSet result = new TIntHashSet();
        synchronized (this) {
            for (int i = 0; i < completed.length; i++) {
                if (completed[i]) {
                    result.add(dense2Sparse[i]);
                }
            }
        }
        return result;
    }

}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy