// org.fbk.cit.hlt.thewikimachine.index.PageIncomingOutgoingWeightedSearcher Maven / Gradle / Ivy
/*
* Copyright (2013) Fondazione Bruno Kessler (http://www.fbk.eu/)
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.fbk.cit.hlt.thewikimachine.index;
import org.apache.commons.cli.*;
import org.apache.log4j.Logger;
import org.apache.log4j.PropertyConfigurator;
import org.apache.lucene.document.Document;
import org.apache.lucene.index.Term;
import org.apache.lucene.index.TermDocs;
import org.fbk.cit.hlt.core.math.Node;
import org.fbk.cit.hlt.thewikimachine.index.util.AbstractSearcher;
import org.fbk.cit.hlt.thewikimachine.util.CharacterTable;
import org.fbk.cit.hlt.thewikimachine.util.StringTable;
import org.fbk.cit.hlt.thewikimachine.xmldump.AbstractWikipediaExtractor;
import java.io.*;
import java.text.DecimalFormat;
import java.util.*;
import java.util.regex.Pattern;
/**
 * Looks up the vector stored for a page in a Lucene index.
 * <p>
 * Documents are addressed by page name through the field
 * {@code PageVectorIndexer.PAGE_FIELD_NAME}; the vector is read from the binary field
 * {@code PageIncomingOutgoingWeightedIndexer.VECTOR_FIELD_NAME} and deserialized by
 * {@link #toNode(byte[])} into an array of {@link Node} (index, value) pairs.
 * Presumably the vector is built from the page's incoming and outgoing links, as the class
 * name suggests; that is decided by the indexer, not by this class.
 * <p>
 * The vectors of the most frequent pages can be preloaded into an in-memory cache with one of
 * the {@code loadCache} methods. Only the cache map is synchronized when the searcher is created
 * as thread safe; nothing else in this class adds synchronization.
 *
 * @author giuliano
 * Created on 1/24/13.
 */
public class PageIncomingOutgoingWeightedSearcher extends AbstractSearcher {
    /**
     * Define a static logger variable so that it references the
     * Logger instance named PageIncomingOutgoingWeightedSearcher.
     */
    static Logger logger = Logger.getLogger(PageIncomingOutgoingWeightedSearcher.class.getName());

    /** Minimum frequency a key needs to be cached when the caller does not specify one. */
    public static final int DEFAULT_MIN_FREQ = 1000;

    /** Thread-safety setting used by the single-argument constructor. */
    public static final boolean DEFAULT_THREAD_SAFE = false;

    // NOTE: java.text.DecimalFormat is documented as not synchronized; these shared
    // formatters are only used for log output here.
    protected static DecimalFormat df = new DecimalFormat("###,###,###,###");

    protected static DecimalFormat vf = new DecimalFormat("###,###,###,##0.000");

    private static final DecimalFormat tf = new DecimalFormat("000,000,000.#");

    public static final int BOW_INDEX = 1;

    public static final int LS_INDEX = 0;

    private static final Pattern tabPattern = Pattern.compile(StringTable.HORIZONTAL_TABULATION);

    /** Serialized size of one node: an int index followed by a double value. */
    private static final int NODE_SIZE_IN_BYTES = 4 + 8;

    /** Whether the cache created by {@code loadCache} is wrapped in a synchronized map. */
    protected boolean threadSafe;

    /** Page name to vector; {@code null} until one of the {@code loadCache} methods is called. */
    private Map<String, Node[]> cache;

    /** Template term on the page field; {@code createTerm} derives the per-page lookup term. */
    private final Term keyTerm;

    /**
     * Opens the index with the default thread-safety setting.
     *
     * @param indexName name of the index, passed to the superclass constructor
     * @throws IOException if the superclass fails to open the index
     */
    public PageIncomingOutgoingWeightedSearcher(String indexName) throws IOException {
        this(indexName, DEFAULT_THREAD_SAFE);
    }

    /**
     * Opens the index.
     *
     * @param indexName  name of the index, passed to the superclass constructor
     * @param threadSafe if {@code true}, the cache built by {@code loadCache} is a synchronized map
     * @throws IOException if the superclass fails to open the index
     */
    public PageIncomingOutgoingWeightedSearcher(String indexName, boolean threadSafe) throws IOException {
        super(indexName);
        this.threadSafe = threadSafe;
        keyTerm = new Term(PageVectorIndexer.PAGE_FIELD_NAME, "");
        logger.debug(keyTerm);
        // Build the index dump only when it is going to be logged.
        if (logger.isTraceEnabled()) {
            logger.trace(toString(10));
        }
    }

    /**
     * Returns the live cache map, not a copy.
     *
     * @return the cache, or {@code null} if no cache has been loaded
     */
    public Map<String, Node[]> getCache() {
        return cache;
    }

    /**
     * Loads the cache from the named file using {@link #DEFAULT_MIN_FREQ}.
     *
     * @param name path of the key-frequency file
     * @throws IOException if the file or the index cannot be read
     */
    public void loadCache(String name) throws IOException {
        loadCache(new File(name));
    }

    /**
     * Loads the cache from the named file.
     *
     * @param name    path of the key-frequency file
     * @param minFreq minimum frequency (inclusive) of the keys to cache
     * @throws IOException if the file or the index cannot be read
     */
    public void loadCache(String name, int minFreq) throws IOException {
        loadCache(new File(name), minFreq);
    }

    /**
     * Loads the cache from the given file using {@link #DEFAULT_MIN_FREQ}.
     *
     * @param f the key-frequency file
     * @throws IOException if the file or the index cannot be read
     */
    public void loadCache(File f) throws IOException {
        loadCache(f, DEFAULT_MIN_FREQ);
    }

    /**
     * Replaces the cache with the vectors of the keys listed in {@code f}.
     * <p>
     * The file is read as UTF-8, one {@code frequency<TAB>key} pair per line. Lines that do not
     * split into exactly two fields, or whose frequency is not an integer, are skipped. Reading
     * stops at the first line whose frequency is below {@code minFreq}, so the file is expected
     * to be sorted by decreasing frequency. Keys that are not in the index are not cached.
     *
     * @param f       the key-frequency file
     * @param minFreq minimum frequency (inclusive) of the keys to cache
     * @throws IOException if the file or the index cannot be read
     */
    public void loadCache(File f, int minFreq) throws IOException {
        logger.info("loading cache from " + f + " (freq>=" + minFreq + ")...");
        long begin = System.nanoTime();

        if (threadSafe) {
            logger.info(this.getClass().getName() + "'s cache is thread safe");
            cache = Collections.synchronizedMap(new HashMap<String, Node[]>());
        }
        else {
            logger.warn(this.getClass().getName() + "'s cache isn't thread safe");
            cache = new HashMap<String, Node[]>();
        }

        LineNumberReader lnr = new LineNumberReader(new InputStreamReader(new FileInputStream(f), "UTF-8"));
        try {
            String line;
            while ((line = lnr.readLine()) != null) {
                int lineNumber = lnr.getLineNumber();
                String[] t = tabPattern.split(line);
                if (t.length == 2 && !cacheKeyIfFrequent(t[0], t[1], minFreq, lineNumber)) {
                    // First key below the threshold: stop reading.
                    break;
                }

                if ((lineNumber % notificationPoint) == 0) {
                    logger.debug(lineNumber + " keys read (" + cache.size() + ") " + new Date());
                }
            }
        } finally {
            // Release the file even when parsing or the index lookup fails.
            lnr.close();
        }

        long end = System.nanoTime();
        logger.info(df.format(cache.size()) + " (" + df.format(indexReader.numDocs()) + ") keys cached in " + tf.format(end - begin) + " ns");
    }

    /**
     * Caches the vector of {@code key} when its frequency reaches the threshold.
     *
     * @return {@code false} if the frequency is below {@code minFreq} (the caller stops reading),
     *         {@code true} otherwise, including when the line is skipped as malformed
     */
    private boolean cacheKeyIfFrequent(String freqText, String key, int minFreq, int lineNumber) throws IOException {
        int freq;
        try {
            freq = Integer.parseInt(freqText);
        } catch (NumberFormatException e) {
            logger.warn("skipping line " + lineNumber + ": invalid frequency '" + freqText + "'");
            return true;
        }

        if (freq < minFreq) {
            return false;
        }

        Node[] vector = readVectorFromIndex(key);
        if (vector != null) {
            cache.put(key, vector);
        }
        return true;
    }

    /**
     * Reads the vector of the first document indexed under {@code page}.
     *
     * @return the vector, or {@code null} if no document matches the page
     */
    private Node[] readVectorFromIndex(String page) throws IOException {
        TermDocs termDocs = indexReader.termDocs(keyTerm.createTerm(page));
        try {
            if (!termDocs.next()) {
                return null;
            }
            Document doc = indexReader.document(termDocs.doc());
            return toNode(doc.getBinaryValue(PageIncomingOutgoingWeightedIndexer.VECTOR_FIELD_NAME));
        } finally {
            termDocs.close();
        }
    }

    /**
     * Deserializes a vector.
     * <p>
     * Layout, as read by {@link DataInputStream}: an int element count followed by that many
     * (int index, double value) pairs.
     *
     * @param byteArray the serialized vector; {@code null} is treated as "no vector stored"
     * @return the deserialized nodes; an empty array if {@code byteArray} is {@code null}
     * @throws IOException if the element count is negative or the data is shorter than the
     *                     element count requires
     */
    public static Node[] toNode(byte[] byteArray) throws IOException {
        if (byteArray == null) {
            return new Node[0];
        }

        ByteArrayInputStream byteStream = new ByteArrayInputStream(byteArray);
        DataInputStream dataStream = new DataInputStream(byteStream);
        int m = dataStream.readInt();
        if (m < 0) {
            throw new IOException("invalid vector: negative element count " + m);
        }
        // Check the count against the remaining bytes before allocating, so that corrupt
        // data cannot trigger a huge allocation.
        if ((long) m * NODE_SIZE_IN_BYTES > byteStream.available()) {
            throw new EOFException("invalid vector: " + m + " elements declared but only " + byteStream.available() + " bytes left");
        }

        Node[] node = new Node[m];
        for (int j = 0; j < m; j++) {
            node[j] = new Node(dataStream.readInt(), dataStream.readDouble());
        }
        return node;
    }

    /**
     * Returns the vector of a page, looking in the cache first and then in the index.
     * <p>
     * A cache hit returns the cached array itself, so callers must not modify the result.
     *
     * @param page the page name
     * @return the vector of the page; an empty array if the page is not found or the index
     *         cannot be read (the failure is logged), never {@code null}
     */
    public Node[] search(String page) {
        if (cache != null) {
            Node[] cached = cache.get(page);
            if (cached != null) {
                return cached;
            }
        }

        try {
            Node[] vector = readVectorFromIndex(page);
            if (vector != null) {
                return vector;
            }
        } catch (Exception e) {
            // Best effort: a failed lookup yields an empty vector, but is no longer silent.
            logger.warn("unable to read the vector of '" + page + "'", e);
        }
        return new Node[0];
    }

    /**
     * Reads queries from standard input until end of input and logs the results.
     * <p>
     * A line without a tab is searched as a page name; a line with exactly two tab-separated
     * fields is treated as a pair of pages and both {@link #compare} and {@link #dot} are logged.
     *
     * @throws Exception if standard input cannot be read
     */
    public void interactive() throws Exception {
        // One reader for the whole session: a new BufferedReader per query could drop
        // input that the previous one had already buffered.
        BufferedReader stdin = new BufferedReader(new InputStreamReader(System.in));
        while (true) {
            System.out.println("\nPlease write a query and type <return> to continue (CTRL C to exit):");
            String query = stdin.readLine();
            if (query == null) {
                // End of input.
                return;
            }

            String[] s = query.split("\t");
            if (s.length == 1) {
                long begin = System.nanoTime();
                Node[] node = search(query);
                long end = System.nanoTime();
                logger.info(query + "\t<" + nodeToString(node) + ">\t" + tf.format(end - begin) + " ns");
            }
            else if (s.length == 2) {
                double sim = compare(s[0], s[1]);
                logger.info(query);
                logger.info("compare\t" + vf.format(sim));

                double product = dot(s[0], s[1]);
                logger.info(query);
                logger.info("dot\t" + vf.format(product));
            }
        }
    }

    /**
     * Returns the dot product of the normalized vectors of two pages.
     * <p>
     * {@code Node.normalize} is called for its side effect and therefore appears to modify its
     * argument; it is applied to copies so that vectors held in the cache are left untouched.
     *
     * @param p1 first page name
     * @param p2 second page name
     * @return the value of {@code Node.dot} on the two normalized vectors
     */
    public double dot(String p1, String p2) {
        Node[] vec0 = deepCopy(search(p1));
        Node[] vec1 = deepCopy(search(p2));
        Node.normalize(vec0);
        Node.normalize(vec1);
        return Node.dot(vec0, vec1);
    }

    /** Copies the array and its nodes, so that the copy can be modified freely. */
    private static Node[] deepCopy(Node[] source) {
        Node[] copy = new Node[source.length];
        for (int i = 0; i < source.length; i++) {
            if (source[i] != null) {
                copy[i] = new Node(source[i].index, source[i].value);
            }
        }
        return copy;
    }

    /**
     * Returns the cosine similarity of the vectors of two pages:
     * {@code dot(v0, v1) / sqrt(dot(v0, v0) * dot(v1, v1))}.
     *
     * @param p1 first page name
     * @param p2 second page name
     * @return the similarity, or 0 when either vector has zero norm (for instance because the
     *         page is not in the index), instead of the NaN that 0/0 would produce
     */
    public double compare(String p1, String p2) {
        Node[] vec0 = search(p1);
        Node[] vec1 = search(p2);
        double d00 = Node.dot(vec0, vec0);
        double d11 = Node.dot(vec1, vec1);
        double norm = Math.sqrt(d00 * d11);
        if (norm == 0) {
            return 0;
        }
        return Node.dot(vec0, vec1) / norm;
    }

    /**
     * Formats a vector as space-separated {@code index:value} pairs.
     *
     * @param node the vector, may be {@code null}
     * @return the formatted vector; an empty string for a {@code null} or empty vector
     */
    static public String nodeToString(Node[] node) {
        StringBuilder sb = new StringBuilder();
        if (node == null) {
            return sb.toString();
        }
        for (int i = 0; i < node.length; i++) {
            if (i > 0) {
                sb.append(' ');
            }
            sb.append(node[i].index).append(':').append(node[i].value);
        }
        return sb.toString();
    }

    /**
     * Returns the integer value of a command line option.
     *
     * @return the parsed value, or {@code defaultValue} if the option is absent
     * @throws ParseException if the value is not an integer, so that the usage message is shown
     */
    private static int intOption(CommandLine line, String name, int defaultValue) throws ParseException {
        if (!line.hasOption(name)) {
            return defaultValue;
        }
        String value = line.getOptionValue(name);
        try {
            return Integer.parseInt(value);
        } catch (NumberFormatException e) {
            throw new ParseException("option " + name + " requires an integer value, found '" + value + "'");
        }
    }

    /**
     * Command line entry point: opens an index, optionally preloads the cache, runs a single
     * search and/or enters the interactive mode. Prints the usage message on invalid arguments.
     * <p>
     * The log4j configuration is read from the file named by the {@code log-config} system
     * property, or from {@code configuration/log-config.txt} if the property is not set.
     *
     * @param args the command line arguments
     * @throws Exception if the index or the key-frequency file cannot be read
     */
    public static void main(String args[]) throws Exception {
        String logConfig = System.getProperty("log-config");
        if (logConfig == null) {
            logConfig = "configuration/log-config.txt";
        }
        PropertyConfigurator.configure(logConfig);

        Options options = new Options();
        try {
            Option indexNameOpt = OptionBuilder.withArgName("index").hasArg().withDescription("open an index with the specified name").isRequired().withLongOpt("index").create("i");
            Option interactiveModeOpt = OptionBuilder.withArgName("interactive-mode").withDescription("enter in the interactive mode").withLongOpt("interactive-mode").create("t");
            Option searchOpt = OptionBuilder.withArgName("search").hasArg().withDescription("search for the specified key").withLongOpt("search").create("s");
            Option freqFileOpt = OptionBuilder.withArgName("key-freq").hasArg().withDescription("read the keys' frequencies from the specified file").withLongOpt("key-freq").create("f");
            Option minimumKeyFreqOpt = OptionBuilder.withArgName("minimum-freq").hasArg().withDescription("minimum key frequency of cached values (default is " + DEFAULT_MIN_FREQ + ")").withLongOpt("minimum-freq").create("m");
            Option notificationPointOpt = OptionBuilder.withArgName("int").hasArg().withDescription("receive notification every n pages (default is " + AbstractWikipediaExtractor.DEFAULT_NOTIFICATION_POINT + ")").withLongOpt("notification-point").create("b");

            options.addOption("h", "help", false, "print this message");
            options.addOption("v", "version", false, "output version information and exit");
            options.addOption(indexNameOpt);
            options.addOption(interactiveModeOpt);
            options.addOption(searchOpt);
            options.addOption(freqFileOpt);
            options.addOption(minimumKeyFreqOpt);
            options.addOption(notificationPointOpt);

            CommandLineParser parser = new PosixParser();
            CommandLine line = parser.parse(options, args);

            if (line.hasOption("help") || line.hasOption("version")) {
                // An empty message makes the handler below print only the usage text.
                throw new ParseException("");
            }

            int minFreq = intOption(line, "minimum-freq", DEFAULT_MIN_FREQ);
            int notificationPoint = intOption(line, "notification-point", AbstractWikipediaExtractor.DEFAULT_NOTIFICATION_POINT);

            PageIncomingOutgoingWeightedSearcher pageVectorSearcher = new PageIncomingOutgoingWeightedSearcher(line.getOptionValue("index"));
            pageVectorSearcher.setNotificationPoint(notificationPoint);

            if (line.hasOption("key-freq")) {
                pageVectorSearcher.loadCache(line.getOptionValue("key-freq"), minFreq);
            }

            if (line.hasOption("search")) {
                logger.debug("searching " + line.getOptionValue("search") + "...");
                Node[] result = pageVectorSearcher.search(line.getOptionValue("search"));
                // Log the content of the vector; logging the array itself prints only its identity.
                logger.info(nodeToString(result));
            }

            if (line.hasOption("interactive-mode")) {
                pageVectorSearcher.interactive();
            }
        } catch (ParseException e) {
            // oops, something went wrong
            String message = e.getMessage();
            if (message != null && message.length() > 0) {
                System.out.println("Parsing failed: " + message + "\n");
            }
            HelpFormatter formatter = new HelpFormatter();
            formatter.printHelp(400, "java -cp dist/thewikimachine.jar org.fbk.cit.hlt.thewikimachine.index.PageIncomingOutgoingWeightedSearcher", "\n", options, "\n", true);
        }
    }
}
// © 2015 - 2025 Weber Informatics LLC | Privacy Policy