All Downloads are FREE. Search and download functionalities are using the official Maven repository.

io.anserini.ltr.feature.IbmModel1 Maven / Gradle / Ivy

/*
 * Anserini: A Lucene toolkit for reproducible information retrieval research
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package io.anserini.ltr.feature;

import io.anserini.ltr.DocumentContext;
import io.anserini.ltr.DocumentFieldContext;
import io.anserini.ltr.FeatureExtractor;
import io.anserini.ltr.QueryContext;
import io.anserini.ltr.QueryFieldContext;
import org.apache.commons.lang3.tuple.Pair;

import java.io.BufferedReader;
import java.io.DataInputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

public class IbmModel1 implements FeatureExtractor {
    private ConcurrentHashMap> sourceVoc;
    private ConcurrentHashMap sourceLookup;
    private ConcurrentHashMap> targetVoc;
    private ConcurrentHashMap targetLookup;
    private ConcurrentHashMap> tran;
    private double selfTrans = 0.05;
    private double lambda = 0.1;
    private double minProb = 5e-4;
    private String field;
    private String qfield;
    private String tag;

    public IbmModel1(String dir, String field, String tag, String qfield) throws IOException {
        sourceVoc = this.loadVoc(dir + File.separator + "source.vcb");
        assert !sourceVoc.containsKey("@NULL@");
        sourceVoc.put(0,Pair.of(0,"@NULL@"));
        sourceLookup = this.vocLookup(sourceVoc);
        targetVoc = this.loadVoc(dir + File.separator + "target.vcb");
        targetLookup = this.vocLookup(targetVoc);
        tran = this.loadTran(dir + File.separator + "output.t1.5.bin");
        this.rescale();
        this.field = field;
        this.tag = tag;
        this.qfield = qfield;
    }

    public IbmModel1(String field, String tag, String qfield,
                     ConcurrentHashMap> sourceVoc,
                     ConcurrentHashMap sourceLookup,
                     ConcurrentHashMap> targetVoc,
                     ConcurrentHashMap targetLookup,
                     ConcurrentHashMap> tran) {
        this.sourceVoc = sourceVoc;
        this.sourceLookup = sourceLookup;
        this.targetVoc = targetVoc;
        this.targetLookup = targetLookup;
        this.tran = tran;
        this.field = field;
        this.tag = tag;
        this.qfield = qfield;

    }

    public ConcurrentHashMap> loadVoc(String fileName) throws IOException {
        ConcurrentHashMap> res = new ConcurrentHashMap<>();
        BufferedReader reader = new BufferedReader(new InputStreamReader(new FileInputStream(fileName)));
        String line = reader.readLine();
        while (line != null) {
            String[] parts = line.split("\\s");
            int id = Integer.parseInt(parts[0]);
            String voc = parts[1];
            int freq = Integer.parseInt(parts[2]);
            assert !res.containsKey(id);
            res.put(id, Pair.of(freq, voc));
            line = reader.readLine();
        }
        reader.close();
        return res;
    }

    public ConcurrentHashMap vocLookup(Map> voc) {
        ConcurrentHashMap res = new ConcurrentHashMap<>();
        for (Integer key : voc.keySet()) {
            res.put(voc.get(key).getRight(), key);
        }
        return res;
    }

    public ConcurrentHashMap> loadTran(String fileName) throws IOException {
        ConcurrentHashMap> res = new ConcurrentHashMap<>();
        DataInputStream in = new DataInputStream(new FileInputStream(fileName));
        Map bufferSourceMap = null;
        Integer bufferSourceKey = null;
        while (in.available() > 0) {
            int sourceID = in.readInt();
            assert sourceID == 0 | sourceVoc.containsKey(sourceID);
            int targetID = in.readInt();
            assert targetVoc.containsKey(targetID);
            float tranProb = in.readFloat();
            assert tranProb >= 1e-3f;
            if(bufferSourceKey!=null&&bufferSourceKey==sourceID)
                bufferSourceMap.put(targetID, tranProb);
            else{
                if (!res.containsKey(sourceID)) {
                    Map word2prob = new ConcurrentHashMap<>();
                    word2prob.put(targetID, tranProb);
                    res.put(sourceID, word2prob);
                    bufferSourceKey = sourceID;
                    bufferSourceMap = word2prob;
                } else {
                    bufferSourceKey = sourceID;
                    bufferSourceMap = res.get(sourceID);
                    bufferSourceMap.put(targetID, tranProb);
                }
            }
        }
        return res;
    }

    public void rescale() throws IOException {
        Map probSum = new HashMap<>();
        for(int sourceID: tran.keySet()){
            Map targetProbs = tran.get(sourceID);
            float adjustMult = sourceID > 0 ? (float) (1 - selfTrans) : 1.0f;
            boolean selfTranExist = false;
            for(int targetID: targetProbs.keySet()){
                float tranProb = targetProbs.get(targetID);
                String sourceWord = sourceVoc.get(sourceID).getRight();
                String targetWord = targetVoc.get(targetID).getRight();
                // should use string match, but author use id match, maybe author only use source or target vocabulary
                // to convert string to id?
//                if(sourceWord.equals(targetWord)&&sourceID!=targetID)
//                    System.out.println(sourceWord + ';' + sourceID + ';' + targetID);
                probSum.put(sourceID, probSum.getOrDefault(sourceID, 0f) + tranProb);
                tranProb *= adjustMult;
                if (sourceWord.equals(targetWord)) {
                    tranProb += selfTrans;
                    selfTranExist = true;
                }
                targetProbs.put(targetID, tranProb);
            }
            // in theroy selftrans should be add to every source word except null however when the selftrans is filtered
//            assert sourceID==0|selfTranExist;
        }
        return;
    }

    public static float calculate_score(double colProb, float totTranProb, double lambda) {
        colProb = Math.max(colProb, 1e-9f);
        double res = Math.log((1 - lambda) * totTranProb + lambda * colProb) - Math.log(lambda * colProb);
        return (float)res;
    }

    public float computeQuery(String queryWord, Map docFreq, Long docSize, double colProb) throws IOException {
        double res = 0;
        float totTranProb = 0;
        if (targetLookup.containsKey(queryWord)) {
            int queryWordId = targetLookup.get(queryWord);
            for (String docTerm : docFreq.keySet()) {
                float tranProb = 0;
                int docWordId = 0;
                if (queryWord.equals(docTerm)) {
                    tranProb = (float) selfTrans;
                    if (sourceLookup.containsKey(docTerm)) {
                        docWordId = sourceLookup.get(docTerm);
                    }
                    if (tran.containsKey(docWordId)) {
                        Map targetMap = tran.get(docWordId);
                        if (targetMap.containsKey(queryWordId)) {
                            tranProb = Math.max(targetMap.get(queryWordId), tranProb);
                        }
                    }
                } else {
                    if (sourceLookup.containsKey(docTerm)) {
                        docWordId = sourceLookup.get(docTerm);
                    }
                    if (tran.containsKey(docWordId)) {
                        Map targetMap = tran.get(docWordId);
                        if (targetMap.containsKey(queryWordId)) {
                            tranProb = targetMap.get(queryWordId);
                        }
                    }
                }
                if (tranProb >= minProb) {
                    totTranProb += tranProb * ((1.0*docFreq.get(docTerm)) / docSize);
                }
            }
        }
        return calculate_score(colProb,totTranProb,lambda);
    }

    @Override
    public float extract(DocumentContext documentContext, QueryContext queryContext) throws FileNotFoundException, IOException {
        DocumentFieldContext context = documentContext.fieldContexts.get(field);
        QueryFieldContext queryFieldContext = queryContext.fieldContexts.get(qfield);
        long docSize = context.docSize;
        long totalTermFreq = context.totalTermFreq;
        float score = 0;
        if(docSize==0) return 0;
        for (String queryToken : queryFieldContext.queryTokens) {
            double collectProb = (double) context.getCollectionFreq(queryToken) / totalTermFreq;
            score += computeQuery(queryToken, context.termFreqs, context.docSize, collectProb);
        }
        return score;
    }

    @Override
    public float postEdit(DocumentContext context, QueryContext queryContext) {
        QueryFieldContext queryFieldContext = queryContext.fieldContexts.get(qfield);
        return queryFieldContext.getSelfLog(context.docId, getName());
    }

    @Override
    public FeatureExtractor clone() {
        return new IbmModel1(field, tag, qfield, sourceVoc, sourceLookup, targetVoc, targetLookup, tran);
    }

    @Override
    public String getName() {
        String name = this.getClass().getSimpleName();
        return String.format("%s_%s_%s_%s",field, qfield, name, tag);
    }

    @Override
    public String getField() {
        return field;
    }

    @Override
    public String getQField() {
        return qfield;
    }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy