
// org.wikibrain.sr.ensemble.LinearEnsemble (Maven / Gradle / Ivy)
package org.wikibrain.sr.ensemble;
import gnu.trove.list.array.TDoubleArrayList;
import gnu.trove.map.hash.TIntDoubleHashMap;
import gnu.trove.set.TIntSet;
import gnu.trove.set.hash.TIntHashSet;
import org.apache.commons.io.FileUtils;
import org.apache.commons.math3.stat.regression.OLSMultipleLinearRegression;
import org.wikibrain.sr.SRResult;
import org.wikibrain.sr.SRResultList;
import java.io.*;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
*@author Matt Lesicko
*/
public class LinearEnsemble implements Ensemble{
/** SLF4J logger for this class. */
private static final Logger LOG = LoggerFactory.getLogger(LinearEnsemble.class);
/**
 * Metric count passed to the constructor. It sizes the columns of the
 * training matrices built in this class: {@code numMetrics} columns for the
 * similarity matrix and {@code numMetrics*2} columns for the most-similar matrix.
 */
final int numMetrics;
/**
 * Count passed to the constructor. When most-similar prediction is given a
 * non-null {@code validIds} set, each rank is multiplied by
 * {@code numTrainingCandidateArticles / validIds.size()} (then truncated to int)
 * before its logarithm is taken.
 */
private final int numTrainingCandidateArticles;
/**
 * Coefficients used to combine per-metric similarity scores. Element 0 is the
 * starting value of the weighted score (the constructor initialises it to 0.0).
 * Prediction throws {@link IllegalStateException} unless
 * {@code size() == scores.size() + 1}.
 * NOTE(review): the identifier is misspelt ("simlarity"); it is non-private and
 * referenced throughout the class, so renaming needs a coordinated change.
 */
TDoubleArrayList simlarityCoefficients;
/**
 * Coefficients used to combine most-similar result lists. Element 0 is the
 * initial score assigned to every candidate id; the remaining elements are read
 * in pairs, one pair per result list: a score coefficient followed by a rank
 * coefficient (applied to {@code Math.log(rank)}). Prediction throws
 * {@link IllegalStateException} unless {@code size() == 2*scores.size() + 1}.
 */
TDoubleArrayList mostSimilarCoefficients;
/** Interpolator trained through {@code trainSimilarity(simList)} on the similarity training examples. */
Interpolator similarityInterpolator;
/**
 * Interpolator trained through {@code trainMostSimilar(simList)}. During
 * most-similar prediction its {@code getInterpolatedScore} and
 * {@code getInterpolatedRank} values stand in for ids that a given result list
 * did not return.
 */
Interpolator mostSimilarInterpolator;
public LinearEnsemble(int numMetrics, int numTrainingCandidateArticles){
this.numTrainingCandidateArticles = numTrainingCandidateArticles;
this.numMetrics = numMetrics;
simlarityCoefficients = new TDoubleArrayList();
simlarityCoefficients.add(0.0);
for (int i=0; i simList) {
if (simList.isEmpty()) {
throw new IllegalArgumentException("no examples to train on!");
}
similarityInterpolator.trainSimilarity(simList);
double[][] X = new double[simList.size()][numMetrics];
double[] Y = new double[simList.size()];
for (int i = 0; i simList) {
if (simList.isEmpty()){
throw new IllegalStateException("no examples to train on!");
}
mostSimilarInterpolator.trainMostSimilar(simList);
// Remove things that have no observed metrics
List pruned = new ArrayList();
for (EnsembleSim es : simList) {
if (es != null && es.getNumMetricsWithScore() > 0) {
pruned.add(es);
}
}
double[][] X = new double[pruned.size()][numMetrics*2];
double[] Y = new double[pruned.size()];
for (int i=0; i scores) {
if (scores.size()+1!= simlarityCoefficients.size()){
throw new IllegalStateException();
}
double weightedScore = simlarityCoefficients.get(0);
for (int i=0; i scores, int maxResults, TIntSet validIds) {
if (2*scores.size()+1!= mostSimilarCoefficients.size()){
throw new IllegalStateException();
}
TIntSet allIds = new TIntHashSet(); // ids returned by at least one metric
for (SRResultList resultList : scores){
if (resultList != null) {
for (SRResult result : resultList){
allIds.add(result.getId());
}
}
}
TIntDoubleHashMap scoreMap = new TIntDoubleHashMap();
for (int id : allIds.toArray()) {
scoreMap.put(id, mostSimilarCoefficients.get(0));
}
int i =1;
for (SRResultList resultList : scores){
TIntSet unknownIds = new TIntHashSet(allIds);
double c1 = mostSimilarCoefficients.get(i); // score coeff
double c2 = mostSimilarCoefficients.get(i+1); // rank coefficient
if (resultList != null) {
for (int j = 0; j < resultList.numDocs(); j++) {
int rank = j + 1;
// expand or contract ranks proportionately
if (validIds != null) {
double k = 1.0 * numTrainingCandidateArticles / validIds.size();
rank = (int) (rank * k);
}
SRResult result = resultList.get(j);
unknownIds.remove(result.getId());
double value = c1 * result.getScore() + c2 * Math.log(rank);
if (debug) {
System.err.format("%s %d. %.3f (id=%d), computing %.3f * %.3f + %.3f * (log(%d) = %.3f)\n",
"m" + i, j, value, result.getId(),
c1, result.getScore(), c2, rank, Math.log(rank));
}
scoreMap.adjustValue(result.getId(), value);
}
}
// interpolate scores for unknown ids
double value = c1 * mostSimilarInterpolator.getInterpolatedScore(i/2)
+ c2 * Math.log(mostSimilarInterpolator.getInterpolatedRank(i/2));
for (int id : unknownIds.toArray()) {
scoreMap.adjustValue(id, value);
}
i+=2;
}
List resultList = new ArrayList();
for (int id : scoreMap.keys()){
resultList.add(new SRResult(id,scoreMap.get(id)));
}
Collections.sort(resultList);
Collections.reverse(resultList);
int size = maxResults>resultList.size()? resultList.size() : maxResults;
SRResultList result = new SRResultList(size);
for (i=0; i
// © 2015 - 2025 Weber Informatics LLC | Privacy Policy