// NOTE(review): listing header captured with the file, not Java source:
// org.wikimedia.search.glent.DictionarySuggester Maven / Gradle / Ivy
package org.wikimedia.search.glent;
import static java.util.stream.Collectors.toList;
import static org.apache.spark.sql.functions.col;
import static org.apache.spark.sql.functions.lit;
import static org.apache.spark.sql.functions.max;
import static org.apache.spark.sql.functions.udf;
import java.time.Instant;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import java.util.Map;
import java.util.function.BiFunction;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import javax.annotation.Nullable;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.api.java.UDF2;
import org.apache.spark.sql.expressions.UserDefinedFunction;
import org.apache.spark.sql.types.DataTypes;
import org.wikimedia.search.glent.analysis.GlentTokenizer;
import org.wikimedia.search.glent.analysis.Tokenizers;
import lombok.AllArgsConstructor;
public class DictionarySuggester implements BiFunction, Dataset, Dataset> {
private final Instant earliestLegalTs;
public DictionarySuggester(Instant earliestLegalTs) {
this.earliestLegalTs = earliestLegalTs;
}
/**
*
* @param dfLog CirrusSearch query logs since previous run of suggester
* @param dfOld Previous run of DictionarySuggestor
* @return Dictionary based query suggestions
*/
public Dataset apply(Dataset dfLog, Dataset dfOld) {
// Verify we have the expected named columns and rename to match our outputs.
// A more generalized way to assert the shape of inputs would be preferred.
dfLog = dfLog.select(
"query", "queryNorm", "lang", "wikiid", "ts", "hitsTotal")
.withColumnRenamed("hitsTotal", "queryHitsTotal");
// Verify the same on dfOld
dfOld = dfOld.select(
"query", "queryNorm", "lang", "wikiid", "ts", "queryHitsTotal");
// The query suggestion algorithm runs in isolation, there is no cross-query interaction,
// and the historical dataset is very small compared to the logs we are ingesting. To
// simplify operations such as updating the algorithm always run the historical suggestions
// through the current version of the suggester.
Dataset df = legalReqs(dfOld.union(dfLog));
df = findM2QueryMatch(df);
return reshape(df);
}
static List tokenizeString(String query, String lang) {
GlentTokenizer tokenizer;
if (null == lang) {
return Collections.emptyList();
} else {
switch (lang) {
case "ko":
tokenizer = Tokenizers.korean();
break;
case "ja":
tokenizer = Tokenizers.japanese();
break;
case "zh":
tokenizer = Tokenizers.simplifiedChinese();
break;
default:
return Collections.emptyList();
}
}
return tokenizer.tokenize(query, " ");
}
static UDF2 buildSuggsM2Udf() {
return (query, lang) -> {
M2Resources resources = M2Resources.getInstance();
return buildSuggsM2(query, lang,
resources.confusions(), resources.wordFreq().get(lang));
};
}
@AllArgsConstructor
static class TokenConfusion {
private static final Pattern CJK_CHAR_PAT =
Pattern.compile("\\p{IsHan}|\\p{IsHangul}|\\p{IsHiragana}|\\p{IsKatakana}|[\\u3099-\\u309F\\u30FC-\\u30FF\\uFF70]");
final String token;
@Nullable
private final List confusions;
public boolean isSingleCJK() {
if (token.length() > 1) {
return false;
}
Matcher matcher = CJK_CHAR_PAT.matcher(token);
return matcher.matches();
}
List confusions() {
return confusions == null ? Collections.emptyList() : confusions;
}
}
/**
* identify candidates for dym using tokenizer + dictionary + confusion matrix.
*
* @param query queryNorm value
* @param lang lang value
* @param confusion map from char seq to list of replacement char sequences
* @param dictionary per-lang map from token to frequency count
* @return dym possible suggestion
*
*/
static String buildSuggsM2(String query, String lang,
Map> confusion, Map dictionary) {
if (query.isEmpty() || lang.isEmpty() || dictionary == null) {
return "";
}
List tokens = tokenizeString(query, lang);
if (tokens.isEmpty()) {
return "";
}
List suggCM = new ArrayList<>();
for (String token : tokens) {
List confusions = null;
if (token.length() == 1) {
confusions = confusion.get(token);
}
suggCM.add(new TokenConfusion(token, confusions));
}
String dym = runBuildBestConfusionPiece(suggCM, dictionary);
String queryOrig = String.join("", tokens);
return dym.equals(queryOrig) ? "" : dym;
}
/**
* function that runs buildBestConfusionPiece.
*
* @param suggCM List of List of confusion values
* @param dictionary Word frequency counts
* @return possible "did you mean" suggestion
*
*/
@SuppressWarnings("ModifiedControlVariable")
static String runBuildBestConfusionPiece(List suggCM,
Map dictionary) {
StringBuilder sb = new StringBuilder();
int end = suggCM.size() - 1;
for (int i = 0; i <= end; i++) {
// Find continuous run of single cjk characters
int j = i;
for (; j <= end; j++) {
if (!suggCM.get(j).isSingleCJK()) {
break;
}
}
// We need at least two sequential cjk tokens
if (j - i < 2) {
sb.append(suggCM.get(i).token);
} else {
sb.append(buildBestConfusionPiece(suggCM.subList(i, j), dictionary));
// Continue iteration with the non-single cjk that ended our window
i = j - 1;
}
}
return sb.toString();
}
/**
* build suggestion based on list of tokens and confusions.
*
* @param suggCM list of single character tokens and their confusions
* @param dictionary word frequency statistics for choosing best
* @return sugg possible suggestion
*
*/
static String buildBestConfusionPiece(List suggCM,
Map dictionary) {
List tokens = suggCM.stream().map(tc -> tc.token).collect(toList());
List suggList = new ArrayList<>();
suggList.add(String.join("", tokens));
for (int i = 0; i < suggCM.size(); i++) {
String left = String.join("", tokens.subList(0, i));
String right = String.join("", tokens.subList(i + 1, suggCM.size()));
for (String c : suggCM.get(i).confusions()) {
suggList.add(left + c + right);
}
}
return suggList.stream()
.filter(dictionary::containsKey)
.max(Comparator.comparingInt(dictionary::get))
.orElseGet(() -> suggList.get(0));
}
/**
* find query match based on M2.
*
* @param dfUserQuery user query dataframe
* @return dataframe with possible suggestions that match user query
*
*/
static Dataset findM2QueryMatch(Dataset dfUserQuery) {
UserDefinedFunction buildSuggsM2Udf = udf(buildSuggsM2Udf(), DataTypes.StringType);
dfUserQuery = dfUserQuery.withColumn("dym",
buildSuggsM2Udf.apply(col("queryNorm"), col("lang")));
return dfUserQuery
.where(col("dym").notEqual(""))
.where(col("dym").notEqual(col("queryNorm")))
.distinct();
}
/**
* Reshape for output to shared suggestions table
*
* Flattens multiple occurances of the same suggestion and renames columns to match
* our outputs. The fields of the shared format not used here are set to 0.
*
* @param df M2 Suggestions dataframe
* @return dataframe of "did you mean" results
*/
static Dataset reshape(Dataset df) {
return df
.groupBy("query", "dym", "wikiid", "lang")
.agg(
max("ts").alias("ts"),
max("queryHitsTotal").alias("queryHitsTotal"))
.withColumn("q1q2EditDist", lit(0F))
.withColumn("dymHitsTotal", lit(0))
.withColumn("suggCount", lit(0));
}
/**
* Removes dataframe entries that have timestamp earlier than required by legal.
*
* @param df M1Prep dataframe
* @return dataframe that satisfies legal requirements
*
*/
Dataset legalReqs(Dataset df) {
return df.where(col("ts").geq(earliestLegalTs.getEpochSecond()));
}
}
// NOTE(review): web page footer captured with the listing, not part of this source file:
// © 2015 - 2025 Weber Informatics LLC | Privacy Policy