org.apache.solr.ltr.model.MultipleAdditiveTreesModel Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of solr-ltr Show documentation
Show all versions of solr-ltr Show documentation
Apache Solr Learning to Rank Package
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.solr.ltr.model;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.search.Explanation;
import org.apache.solr.ltr.feature.Feature;
import org.apache.solr.ltr.norm.Normalizer;
import org.apache.solr.util.SolrPluginUtils;
/**
* A scoring model that computes scores based on the summation of multiple weighted trees.
* Example models are LambdaMART and Gradient Boosted Regression Trees (GBRT).
*
* Example configuration:
{
"class" : "org.apache.solr.ltr.model.MultipleAdditiveTreesModel",
"name" : "multipleadditivetreesmodel",
"features":[
{ "name" : "userTextTitleMatch"},
{ "name" : "originalScore"}
],
"params" : {
"trees" : [
{
"weight" : "1",
"root": {
"feature" : "userTextTitleMatch",
"threshold" : "0.5",
"left" : {
"value" : "-100"
},
"right" : {
"feature" : "originalScore",
"threshold" : "10.0",
"left" : {
"value" : "50"
},
"right" : {
"value" : "75"
}
}
}
},
{
"weight" : "2",
"root" : {
"value" : "-10"
}
}
]
}
}
*
* Training libraries:
*
* - RankLib
*
*
* Background reading:
*
*
*/
public class MultipleAdditiveTreesModel extends LTRScoringModel {
/**
* Maps a feature name to the index of that feature in the feature vector;
* RegressionTreeNode.setFeature(String) looks split features up here.
* fname2index is filled from constructor arguments
* (that are already part of the base class hashCode)
* and therefore here it does not individually
* influence the class hashCode, equals, etc.
*/
private final HashMap<String,Integer> fname2index;
/**
* The regression trees of this model, (re)built by setTrees(Object).
* trees is part of the LTRScoringModel params map
* and therefore here it does not individually
* influence the class hashCode, equals, etc.
*/
private List<RegressionTree> trees;
/**
* Creates a {@link RegressionTree} and populates it from a configuration map
* through {@code SolrPluginUtils.invokeSetters}.
* The returned tree is not validated here; callers are expected to invoke
* {@link RegressionTree#validate()} themselves.
*
* @param map the tree configuration (the class-level example uses the keys
*            "weight" and "root"); may be {@code null}
* @return a new tree; if {@code map} is {@code null} no setter is invoked and
*         the tree is returned in its default, unconfigured state
*/
private RegressionTree createRegressionTree(Map<String,Object> map) {
  final RegressionTree rt = new RegressionTree();
  if (map != null) {
    SolrPluginUtils.invokeSetters(rt, map.entrySet());
  }
  return rt;
}
/**
* Creates a {@link RegressionTreeNode} and populates it from a configuration
* map through {@code SolrPluginUtils.invokeSetters}.
* Nested "left" / "right" maps reach this method again via
* {@link RegressionTreeNode#setLeft(Object)} and
* {@link RegressionTreeNode#setRight(Object)}, so a whole subtree is built
* from one call. The returned node is not validated here.
*
* @param map the node configuration (the class-level example uses the keys
*            "value", "feature", "threshold", "left" and "right"); may be
*            {@code null}
* @return a new node; if {@code map} is {@code null} no setter is invoked and
*         the node stays a leaf with value 0
*/
private RegressionTreeNode createRegressionTreeNode(Map<String,Object> map) {
  final RegressionTreeNode rtn = new RegressionTreeNode();
  if (map != null) {
    SolrPluginUtils.invokeSetters(rtn, map.entrySet());
  }
  return rtn;
}
/**
* A node of a regression tree.
* A node without a feature is a leaf and scores as its {@code value}; any
* other node is a split that compares one entry of the feature vector with
* its threshold and delegates to its left or right child.
* The public setters are invoked through
* {@code SolrPluginUtils.invokeSetters} (see createRegressionTreeNode), so
* their names and parameter types must stay as they are.
*/
public class RegressionTreeNode {
  /**
   * Slack added to every configured threshold by the setThreshold methods;
   * {@link #toString()} subtracts it again when reporting the threshold.
   */
  private static final float NODE_SPLIT_SLACK = 1E-6f;
  /** Score returned by a leaf node. */
  private float value = 0f;
  /** Name of the split feature; {@code null} marks the node as a leaf. */
  private String feature;
  /** Index of {@link #feature} in the feature vector, or -1 if unknown. */
  private int featureIndex = -1;
  /** Split threshold including the slack; {@code null} until configured. */
  private Float threshold;
  private RegressionTreeNode left;
  private RegressionTreeNode right;

  public RegressionTreeNode() {
  }

  /** Sets the leaf value. */
  public void setValue(float value) {
    this.value = value;
  }

  /**
   * Sets the leaf value from its string form.
   *
   * @throws NumberFormatException if {@code value} is not a parsable float
   */
  public void setValue(String value) {
    this.value = Float.parseFloat(value);
  }

  /**
   * Sets the split feature and resolves its index in the feature vector.
   *
   * @param feature name of the feature this node splits on
   */
  public void setFeature(String feature) {
    this.feature = feature;
    final Integer idx = fname2index.get(this.feature);
    // this happens if the tree specifies a feature that does not exist
    // this could be due to lambdaSmart building off of pre-existing trees
    // that use a feature that is no longer output during feature extraction
    featureIndex = (idx == null) ? -1 : idx;
  }

  /** Sets the split threshold; the stored value includes the slack. */
  public void setThreshold(float threshold) {
    this.threshold = threshold + NODE_SPLIT_SLACK;
  }

  /**
   * Sets the split threshold from its string form; the stored value
   * includes the slack.
   *
   * @throws NumberFormatException if {@code threshold} is not a parsable float
   */
  public void setThreshold(String threshold) {
    this.threshold = Float.parseFloat(threshold) + NODE_SPLIT_SLACK;
  }

  /**
   * Builds the left child from its configuration map.
   *
   * @param left a {@code Map} describing the child node
   * @throws ClassCastException if {@code left} is not a {@code Map}
   */
  @SuppressWarnings("unchecked")
  public void setLeft(Object left) {
    this.left = createRegressionTreeNode((Map<String,Object>) left);
  }

  /**
   * Builds the right child from its configuration map.
   *
   * @param right a {@code Map} describing the child node
   * @throws ClassCastException if {@code right} is not a {@code Map}
   */
  @SuppressWarnings("unchecked")
  public void setRight(Object right) {
    this.right = createRegressionTreeNode((Map<String,Object>) right);
  }

  /** Returns {@code true} if no split feature has been set on this node. */
  public boolean isLeaf() {
    return feature == null;
  }

  /**
   * Scores the feature vector against the subtree rooted at this node.
   * Expects a node that passed {@link #validate()}: a split node without a
   * threshold or children fails with a NullPointerException here.
   *
   * @param featureVector feature values indexed as in fname2index
   * @return the value of the leaf that is reached, or 0 if a split on the way
   *         refers to a feature that is not part of the feature vector
   */
  public float score(float[] featureVector) {
    if (isLeaf()) {
      return value;
    }
    // unsupported feature (tree is looking for a feature that does not exist)
    if ((featureIndex < 0) || (featureIndex >= featureVector.length)) {
      return 0f;
    }
    if (featureVector[featureIndex] <= threshold) {
      return left.score(featureVector);
    } else {
      return right.score(featureVector);
    }
  }

  /**
   * Describes the path {@link #score(float[])} takes through this subtree.
   *
   * @param featureVector feature values indexed as in fname2index
   * @return one " | "-separated entry per visited split, ending with the
   *         leaf value or with a note that the split feature is missing
   */
  public String explain(float[] featureVector) {
    if (isLeaf()) {
      return "val: " + value;
    }
    // unsupported feature (tree is looking for a feature that does not exist)
    if ((featureIndex < 0) || (featureIndex >= featureVector.length)) {
      return "'" + feature + "' does not exist in FV, Return Zero";
    }
    // could store extra information about how much training data supported
    // each branch and report
    // that here
    if (featureVector[featureIndex] <= threshold) {
      String rval = "'" + feature + "':" + featureVector[featureIndex] + " <= "
          + threshold + ", Go Left | ";
      return rval + left.explain(featureVector);
    } else {
      String rval = "'" + feature + "':" + featureVector[featureIndex] + " > "
          + threshold + ", Go Right | ";
      return rval + right.explain(featureVector);
    }
  }

  /**
   * Renders a leaf as its value and a split as
   * {@code (feature=..,threshold=..,left=..,right=..)}, with the slack
   * removed from the threshold again.
   */
  @Override
  public String toString() {
    final StringBuilder sb = new StringBuilder();
    if (isLeaf()) {
      sb.append(value);
    } else {
      sb.append("(feature=").append(feature);
      sb.append(",threshold=");
      // validate() embeds not-yet-validated children in its error message, so
      // this can run on a split node that has no threshold; print "null"
      // instead of failing with a NullPointerException while unboxing.
      if (threshold == null) {
        sb.append("null");
      } else {
        sb.append(threshold.floatValue()-NODE_SPLIT_SLACK);
      }
      sb.append(",left=").append(left);
      sb.append(",right=").append(right);
      sb.append(')');
    }
    return sb.toString();
  }

  /**
   * Checks that the subtree rooted at this node is well formed: a leaf must
   * not have children, a split must have a threshold and two valid children.
   *
   * @throws ModelException on the first violation found
   */
  public void validate() throws ModelException {
    if (isLeaf()) {
      if (left != null || right != null) {
        throw new ModelException("MultipleAdditiveTreesModel tree node is leaf with left="+left+" and right="+right);
      }
      return;
    }
    if (null == threshold) {
      throw new ModelException("MultipleAdditiveTreesModel tree node is missing threshold");
    }
    if (null == left) {
      throw new ModelException("MultipleAdditiveTreesModel tree node is missing left");
    } else {
      left.validate();
    }
    if (null == right) {
      throw new ModelException("MultipleAdditiveTreesModel tree node is missing right");
    } else {
      right.validate();
    }
  }
}
/**
* One tree of the ensemble: a root node plus the weight its score is
* multiplied with. Configured through its public setters, see
* createRegressionTree.
*/
public class RegressionTree {
  /** Multiplier for the root's score; {@code null} until configured. */
  private Float weight;
  /** Root of the tree; {@code null} until configured. */
  private RegressionTreeNode root;

  public RegressionTree() {
  }

  /** Sets the weight of this tree. */
  public void setWeight(float weight) {
    this.weight = Float.valueOf(weight);
  }

  /**
   * Sets the weight of this tree from its string form.
   *
   * @throws NumberFormatException if {@code weight} is not a parsable float
   */
  public void setWeight(String weight) {
    this.weight = Float.parseFloat(weight);
  }

  /**
   * Builds the root node from its configuration map.
   *
   * @param root a {@code Map} describing the root node
   */
  public void setRoot(Object root) {
    final Map rootConfig = (Map) root;
    this.root = createRegressionTreeNode(rootConfig);
  }

  /**
   * Returns the score of the root node for the given feature vector,
   * multiplied by the weight of this tree.
   */
  public float score(float[] featureVector) {
    final float treeWeight = weight.floatValue();
    final float rootScore = root.score(featureVector);
    return treeWeight * rootScore;
  }

  /** Returns the root node's explanation for the given feature vector. */
  public String explain(float[] featureVector) {
    return root.explain(featureVector);
  }

  /** Renders the tree as {@code (weight=..,root=..)}. */
  @Override
  public String toString() {
    return "(weight=" + weight + ",root=" + root + ")";
  }

  /**
   * Checks that a weight and a root are present and that the root's subtree
   * is valid.
   *
   * @throws ModelException on the first violation found
   */
  public void validate() throws ModelException {
    if (null == weight) {
      throw new ModelException("MultipleAdditiveTreesModel tree doesn't contain a weight");
    }
    if (null == root) {
      throw new ModelException("MultipleAdditiveTreesModel tree doesn't contain a tree");
    }
    root.validate();
  }
}
public void setTrees(Object trees) {
this.trees = new ArrayList();
for (final Object o : (List
© 2015 - 2025 Weber Informatics LLC | Privacy Policy