// io.github.repir.Strategy.RetrievalModelMBF (source listing from a Maven / Gradle / Ivy artifact page)
// The newest version!
package io.github.repir.Strategy;
import io.github.repir.Repository.DocForward;
import io.github.repir.Repository.DocTF;
import io.github.repir.Repository.Feature;
import io.github.repir.Repository.ReportableFeature;
import io.github.repir.Repository.Repository;
import io.github.repir.Repository.Stopwords.StopWords;
import io.github.repir.Repository.Term;
import io.github.repir.Repository.TermCF;
import io.github.repir.Repository.TermString;
import io.github.repir.Retriever.Document;
import io.github.repir.Retriever.Query;
import io.github.repir.Retriever.ReportedFeature;
import io.github.repir.Retriever.Retriever;
import io.github.repir.Strategy.Collector.CollectorDocument;
import io.github.repir.Strategy.Operator.Operator;
import io.github.repir.Strategy.Operator.QTerm;
import io.github.repir.Strategy.RetrievalModel;
import io.github.repir.Strategy.ScoreFunction.ScoreFunctionKLD;
import io.github.repir.tools.lib.Log;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Map;
import java.util.TreeSet;
/**
 * Implementation of Model Based Feedback proposed by Zhai and Lafferty (2001).
 * <p>
 * After a first retrieval pass, the terms of the top-ranked documents are
 * collected in an {@link FBModel}, a feedback term distribution is estimated
 * with EM against the collection model, and the query is rewritten as a
 * weighted list of terms that is handed back for a second pass.
 * <p>
 * Configuration keys read in the constructor (with their defaults):
 * {@code mbf.lambda} (0.95), {@code mbf.alpha} (0.5), {@code mbf.fbdocs} (10),
 * {@code mbf.fbstopwords} (false) and {@code mbf.fbterms} (1000).
 *
 * @author jer
 */
public class RetrievalModelMBF extends RetrievalModel {

    public static Log log = new Log(RetrievalModelMBF.class);

    /** Number of EM runs performed in {@link #finishReduceTask()}. */
    private static final int EM_RUNS = 20;
    /** Terms that are not in the original query need at least this probability to be added. */
    private static final double MIN_EXPANSION_PROBABILITY = 0.001;

    /** Number of top-ranked documents whose terms feed the feedback model. */
    public int fbmaxdocs;
    /** Weight of the collection model in the EM mixture; (1 - lambda) goes to the feedback model. */
    public double lambda;
    /** Interpolation weight of the feedback model against the original query weights. */
    public double alpha;
    /** When true, stopwords found in the feedback documents may enter the feedback model. */
    public boolean fbstopwords;
    /** Maximum number of terms considered for expansion beyond the original query terms. */
    public int fbterms;

    /**
     * Reads the feedback parameters from the repository configuration.
     *
     * @param retriever retriever that is passed on to the superclass
     */
    public RetrievalModelMBF(Retriever retriever) {
        super(retriever);
        lambda = repository.getConf().getFloat("mbf.lambda", 0.95f);
        alpha = repository.getConf().getFloat("mbf.alpha", 0.5f);
        fbmaxdocs = repository.getConf().getInt("mbf.fbdocs", 10);
        fbstopwords = repository.getConf().getBoolean("mbf.fbstopwords", false);
        fbterms = repository.getConf().getInt("mbf.fbterms", 1000);
    }

    /**
     * Forces the KLD score function on the query before delegating to the superclass.
     *
     * @return the query string produced by the superclass
     */
    @Override
    public String getQueryToRetrieve() {
        query.setScorefunctionClassname(ScoreFunctionKLD.class.getSimpleName());
        return super.getQueryToRetrieve();
    }

    /**
     * @return the canonical name of the {@code DocForward} "all" feature, which
     * {@link FBModel} reads to obtain the tokens of the retrieved documents
     */
    @Override
    public ArrayList<String> getReportedFeatures() {
        ArrayList<String> features = new ArrayList<String>();
        features.add(Feature.canonicalName(DocForward.class, "all")); // need to test
        return features;
    }

    /**
     * @return {@link #fbmaxdocs}
     */
    @Override
    public int getDocumentLimit() {
        return fbmaxdocs;
    }

    /**
     * @return the simple class name of {@link ScoreFunctionKLD}
     */
    @Override
    public String getScorefunctionClass() {
        return ScoreFunctionKLD.class.getSimpleName();
    }

    /**
     * The model is re-estimated by taking the top-k documents returned in the first
     * retrieval pass, using an Expectation Maximization algorithm to estimate whether
     * words are more likely to originate from the pseudo-relevance documents or the
     * general collection model.
     * <p>
     * Terms already in the query get
     * {@code (1 - alpha) * queryWeight / #queryNodes + alpha * p}; other terms get
     * {@code alpha * p}, provided their probability reaches
     * {@value #MIN_EXPANSION_PROBABILITY} and fewer than {@link #fbterms} such
     * candidates were seen before them. The weights are normalised to sum to one
     * and written as {@code term#weight} pairs.
     *
     * @return the same query object, rewritten with the expanded weighted terms,
     * strategy class name "RetrievalModel", stopword removal disabled and results cleared
     */
    @Override
    public Query finishReduceTask() {
        // fill fbterm with the words in the top-fbmaxdocs documents
        FBModel fbterm = new FBModel(this, retriever, fbmaxdocs, fbstopwords);
        FBModel fbmax = null;
        // EM to estimate p based on feedback; keep a copy of the best scoring run
        for (int i = 0; i < EM_RUNS; i++) {
            fbterm.EM(lambda);
            if (fbmax == null || fbmax.score < fbterm.score) {
                fbmax = fbterm.clone();
            }
        }
        for (Operator f : root.containednodes) {
            if (f instanceof QTerm) {
                QTerm ft = (QTerm) f;
                if (ft.exists()) {
                    T t = fbterm.get(ft.term.getID());
                    if (t != null) {
                        log.info("existing term %s mle %f", t.term, t.p);
                    }
                }
            }
        }
        double sumoldtermweight = root.containednodes.size();
        // candidates in descending order of estimated probability
        TreeSet<T> sorted = new TreeSet<T>(fbmax.values());
        int expandterms = 0;
        TreeSet<T> newterms = new TreeSet<T>();
        for (T term : sorted) {
            // add to querymodel: existing terms and terms above the cutoff point
            QTerm n = root.getTerm(term.term.getProcessedTerm());
            Operator f = root.find(n);
            if (f != null) {
                newterms.add(new T(term.term, (1 - alpha) * f.getQueryWeight() / sumoldtermweight + alpha * term.p));
            } else if (term.p >= MIN_EXPANSION_PROBABILITY && expandterms++ < fbterms) {
                newterms.add(new T(term.term, alpha * term.p));
            }
        }
        double sumweight = 0;
        for (T t : newterms) {
            sumweight += t.p;
        }
        // Dividing by a zero (or NaN) total would write NaN weights into the query.
        // Scaling every element by the same positive factor keeps the set's order intact.
        if (sumweight > 0) {
            for (T t : newterms) {
                t.p /= sumweight;
            }
        }
        StringBuilder sb = new StringBuilder();
        for (T t : newterms) {
            sb.append(t.term).append("#").append(t.p).append(" ");
        }
        query.query = sb.toString();
        query.setStrategyClassname("RetrievalModel");
        query.removeStopwords = false;
        query.clearResults();
        return query;
    }

    /**
     * Feedback model: maps a term id to the statistics of that term.
     * <p>
     * Entries are keyed by term id because the document tokens read from the
     * forward feature are term ids; storing under one key type and looking up
     * under another would make every lookup miss.
     */
    static class FBModel extends HashMap<Integer, T> {

        /** The E-step weights must change by no more than this in total for EM to stop. */
        private static final double EM_CONVERGENCE = 0.001;

        /** Log-likelihood of the feedback counts under the mixture, set by {@link #EM(double)}. */
        double score;

        private FBModel() {
        }

        /**
         * Seeds the model with the existing query terms (count 0) and then counts
         * every token of the first {@code fb} retrieved documents.
         *
         * @param rm retrieval model providing the query tree, the collectors and the forward feature
         * @param retriever retriever whose repository supplies term and collection statistics
         * @param fb maximum number of retrieved documents to read
         * @param fbstopwords when false, stopwords that are not query terms are skipped
         */
        public FBModel(RetrievalModel rm, Retriever retriever, int fb, boolean fbstopwords) {
            Repository repository = retriever.repository;
            ReportedFeature forward = rm.getReportedFeature(DocForward.class, "all");
            TermCF termcf = TermCF.get(repository);
            termcf.loadMem();
            double collectionSize = repository.getCF();
            for (Operator f : rm.root.containednodes) {
                if (f instanceof QTerm) {
                    QTerm term = (QTerm) f;
                    if (term.exists()) {
                        T t = new T(term.term, term.getCF() / collectionSize);
                        put(term.term.getID(), t);
                    }
                }
            }
            int doccount = 0;
            // NOTE(review): this handle is not used directly in this constructor;
            // presumably repository.getTerm() relies on it being open - confirm before removing.
            TermString termstring = TermString.get(repository);
            termstring.openRead();
            for (Document d : ((CollectorDocument) rm.collectors.get(0)).getRetrievedDocs()) {
                if (doccount++ >= fb) {
                    break;
                }
                int tokens[] = d.getIntArray(forward);
                if (tokens == null) {
                    continue;
                }
                for (int termid : tokens) {
                    T t = get(termid);
                    if (t != null) {
                        t.cf++;
                        continue;
                    }
                    Term term = repository.getTerm(termid);
                    if (fbstopwords || !term.isStopword()) {
                        long cf = termcf.readValue(termid);
                        t = new T(term, cf / collectionSize);
                        t.cf = 1;
                        put(termid, t);
                    }
                }
            }
        }

        /**
         * @return a copy of this model in which every {@link T} is duplicated, so
         * that later EM runs on this model do not alter the copy
         */
        @Override
        public FBModel clone() {
            FBModel fb = new FBModel();
            for (Map.Entry<Integer, T> entry : entrySet()) {
                T source = entry.getValue();
                T t = new T(source.term, source.termcorpusmle);
                t.p = source.p;
                t.weight = source.weight;
                t.cf = source.cf;
                t.querymle = source.querymle;
                fb.put(entry.getKey(), t);
            }
            fb.score = score;
            return fb;
        }

        /**
         * Estimates {@code p} for every term with EM on a two-component mixture
         * of the feedback model (weight {@code 1 - lambda}) and the collection
         * model (weight {@code lambda}), starting from the relative feedback counts.
         * Stores the resulting log-likelihood in {@link #score}.
         * <p>
         * When no term has a feedback count the method returns without touching
         * {@code p}, because there is nothing to estimate from.
         *
         * @param lambda weight of the collection model in the mixture
         */
        public void EM(double lambda) {
            double sum = 0;
            for (T term : values()) {
                sum += term.cf;
            }
            if (!(sum > 0)) {
                return;
            }
            for (T term : values()) {
                term.p = term.cf / sum;
                // start every run from the same state so repeated calls are reproducible
                term.weight = 0;
            }
            double diff = 1;
            while (diff > EM_CONVERGENCE) {
                diff = 0;
                // E-step
                for (T term : values()) {
                    double feedback = (1 - lambda) * term.p;
                    double mixture = feedback + lambda * term.termcorpusmle;
                    // a term with no mass in either component would otherwise yield 0/0
                    double newweight = (mixture > 0) ? feedback / mixture : 0;
                    diff += Math.abs(term.weight - newweight);
                    term.weight = newweight;
                }
                // M-step
                double sump = 0;
                for (T term : values()) {
                    term.p = term.cf * term.weight;
                    sump += term.p;
                }
                if (!(sump > 0)) {
                    // no mass assigned to the feedback model; normalising would give NaN
                    break;
                }
                for (T term : values()) {
                    term.p /= sump;
                }
            }
            double loglikelihood = 0;
            for (T term : values()) {
                if (term.cf > 0) {
                    loglikelihood += term.cf * Math.log((1 - lambda) * term.p + lambda * term.termcorpusmle);
                }
            }
            score = loglikelihood;
        }
    }

    /**
     * Statistics of a single term. The natural order is descending {@link #p},
     * with the term id as tie-breaker; it is not consistent with {@code equals}.
     */
    static class T implements Comparable<T> {

        Term term;
        /** Value passed to the constructor; used as the collection model probability in EM. */
        double termcorpusmle;
        /** Number of occurrences counted in the feedback documents. */
        int cf;
        /** E-step weight from the most recent EM iteration. */
        double weight;
        /** Current probability or query weight of the term. */
        double p;
        double querymle;

        /**
         * @param term the term described
         * @param corpusmle initial value for both {@link #termcorpusmle} and {@link #p}
         */
        public T(Term term, double corpusmle) {
            this.term = term;
            termcorpusmle = p = corpusmle;
        }

        /**
         * Orders by descending {@link #p}. Ties are broken on the term id so that
         * distinct terms with equal {@code p} can coexist in a sorted set while
         * comparing an element with itself still yields 0.
         */
        @Override
        public int compareTo(T o) {
            int byProbability = Double.compare(o.p, p);
            if (byProbability != 0) {
                return byProbability;
            }
            int id = term.getID();
            int otherId = o.term.getID();
            return (id < otherId) ? -1 : ((id > otherId) ? 1 : 0);
        }
    }

    /**
     * Term with a weight, with equality and hash code based on the term id.
     */
    static class NEWTERM {

        Term term;
        double p;

        /**
         * @param term the term, must not be null
         * @param p weight of the term
         * @throws NullPointerException when {@code term} is null
         */
        public NEWTERM(Term term, double p) {
            if (term == null) {
                throw new NullPointerException("term");
            }
            this.term = term;
            this.p = p;
        }

        @Override
        public int hashCode() {
            return term.getID();
        }

        @Override
        public boolean equals(Object o) {
            return (o instanceof NEWTERM && ((NEWTERM) o).term.getID() == term.getID());
        }
    }
}
// © 2015 - 2025 Weber Informatics LLC | Privacy Policy