org.apache.solr.ltr.model.MultipleAdditiveTreesModel Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of solr-ltr Show documentation
Show all versions of solr-ltr Show documentation
Apache Solr Learning to Rank Package
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.solr.ltr.model;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Deque;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.search.Explanation;
import org.apache.solr.ltr.feature.Feature;
import org.apache.solr.ltr.feature.FeatureException;
import org.apache.solr.ltr.norm.Normalizer;
import org.apache.solr.util.SolrPluginUtils;
/**
* A scoring model that computes scores based on the summation of multiple weighted trees. Example
* models are LambdaMART and Gradient Boosted Regression Trees (GBRT).
*
* <p>Example configuration:
*
* <pre>
* {
*   "class" : "org.apache.solr.ltr.model.MultipleAdditiveTreesModel",
*   "name" : "multipleadditivetreesmodel",
*   "features":[
*     { "name" : "userTextTitleMatch"},
*     { "name" : "originalScore"}
*   ],
*   "params" : {
*     "trees" : [
*       {
*         "weight" : "1",
*         "root": {
*           "feature" : "userTextTitleMatch",
*           "threshold" : "0.5",
*           "left" : {
*             "value" : "-100"
*           },
*           "right" : {
*             "feature" : "originalScore",
*             "threshold" : "10.0",
*             "left" : {
*               "value" : "50"
*             },
*             "right" : {
*               "value" : "75"
*             }
*           }
*         }
*       },
*       {
*         "weight" : "2",
*         "root" : {
*           "value" : "-10"
*         }
*       }
*     ]
*   }
* }
* </pre>
*
* <p>Training libraries:
*
* <ul>
*   <li>RankLib
* </ul>
*
* Background reading:
*
*
*
*
*/
public class MultipleAdditiveTreesModel extends LTRScoringModel {
/**
 * Maps a feature name to an integer index; {@code RegressionTreeNode.setFeature} looks a node's
 * feature name up here to resolve its feature index.
 *
 * <p>fname2index is filled from constructor arguments (that are already part of the base class
 * hashCode) and therefore here it does not individually influence the class hashCode, equals,
 * etc.
 */
private final HashMap<String, Integer> fname2index;

/**
 * The trees of this model, (re)assigned by {@code setTrees}.
 *
 * <p>trees is part of the LTRScoringModel params map and therefore here it does not individually
 * influence the class hashCode, equals, etc.
 */
private List<RegressionTree> trees;

/**
 * Selects the node-scoring routine used by {@code RegressionTree.score}: {@code scoreNode} when
 * {@code true} (the default), {@code scoreNodeWithNullSupport} when {@code false}.
 */
private boolean isNullSameAsZero = true;
/**
 * Creates a {@link RegressionTree} and configures it from the given map.
 *
 * @param map entries handed to {@code SolrPluginUtils.invokeSetters} together with the new tree;
 *     when {@code null} no setters are invoked and the tree is returned unconfigured
 * @return the newly created tree, never {@code null}
 */
private RegressionTree createRegressionTree(Map<String, Object> map) {
  final RegressionTree tree = new RegressionTree();
  if (map == null) {
    // nothing to configure; validation of missing weight/root happens in RegressionTree.validate
    return tree;
  }
  SolrPluginUtils.invokeSetters(tree, map.entrySet());
  return tree;
}
/**
 * Creates a {@link RegressionTreeNode} and configures it from the given map.
 *
 * <p>The node's {@code setLeft}/{@code setRight} setters call back into this method, so a nested
 * map produces a nested node structure.
 *
 * @param map entries handed to {@code SolrPluginUtils.invokeSetters} together with the new node;
 *     when {@code null} no setters are invoked and the node is returned unconfigured
 * @return the newly created node, never {@code null}
 */
private RegressionTreeNode createRegressionTreeNode(Map<String, Object> map) {
  final RegressionTreeNode node = new RegressionTreeNode();
  if (map == null) {
    // an unconfigured node has no feature and therefore reports itself as a leaf
    return node;
  }
  SolrPluginUtils.invokeSetters(node, map.entrySet());
  return node;
}
/**
 * Sets the flag that {@code RegressionTree.score} consults to pick its node-scoring routine:
 * {@code scoreNode} when {@code true}, {@code scoreNodeWithNullSupport} when {@code false}.
 *
 * @param nullSameAsZero the new value of the flag
 */
public void setIsNullSameAsZero(boolean nullSameAsZero) {
  this.isNullSameAsZero = nullSameAsZero;
}
/**
 * One node of a regression tree. A node without a feature name is a leaf and carries only a
 * {@code value}; any other node carries a feature name, a threshold and optional left/right
 * children. Instances are populated through their setters.
 */
public class RegressionTreeNode {
  /** Added to every threshold when it is set and subtracted again when the node is printed. */
  // NOTE(review): presumably makes feature values equal to the configured threshold fall on the
  // "below threshold" side despite float rounding -- confirm against scoreNode.
  private static final float NODE_SPLIT_SLACK = 1E-6f;

  /** Value printed for a leaf; defaults to 0. */
  private float value = 0f;

  /** Name of the feature this node refers to; {@code null} marks a leaf. */
  private String feature;

  /** Index of {@link #feature} according to fname2index, or -1 when the name is unknown. */
  private int featureIndex = -1;

  /** Configured threshold plus {@link #NODE_SPLIT_SLACK}; {@code null} until a setter is called. */
  private Float threshold;

  private RegressionTreeNode left;
  private RegressionTreeNode right;

  /** Optional direction string stored verbatim by {@link #setMissing(String)}. */
  private String missing;

  public RegressionTreeNode() {}

  public void setValue(float value) {
    this.value = value;
  }

  /**
   * @param value textual float
   * @throws NumberFormatException if {@code value} is not a parsable float
   */
  public void setValue(String value) {
    this.value = Float.parseFloat(value);
  }

  public void setMissing(String direction) {
    this.missing = direction;
  }

  /**
   * Stores the feature name and resolves its index through fname2index.
   *
   * @param feature name of the feature referenced by this node
   */
  public void setFeature(String feature) {
    this.feature = feature;
    final Integer idx = fname2index.get(this.feature);
    // A missing entry means the tree names a feature that does not exist, e.g. because the model
    // was built on top of pre-existing trees that use a feature which is no longer output during
    // feature extraction. Such a node keeps the sentinel index -1.
    if (idx == null) {
      featureIndex = -1;
    } else {
      featureIndex = idx;
    }
  }

  public void setThreshold(float threshold) {
    this.threshold = threshold + NODE_SPLIT_SLACK;
  }

  /**
   * @param threshold textual float
   * @throws NumberFormatException if {@code threshold} is not a parsable float
   */
  public void setThreshold(String threshold) {
    this.threshold = Float.parseFloat(threshold) + NODE_SPLIT_SLACK;
  }

  /**
   * Builds the left child from a nested configuration map.
   *
   * @param left a {@code Map<String, Object>} describing the child node
   * @throws ClassCastException if {@code left} is not a {@link Map}
   */
  @SuppressWarnings({"unchecked"})
  public void setLeft(Object left) {
    this.left = createRegressionTreeNode((Map<String, Object>) left);
  }

  /**
   * Builds the right child from a nested configuration map.
   *
   * @param right a {@code Map<String, Object>} describing the child node
   * @throws ClassCastException if {@code right} is not a {@link Map}
   */
  @SuppressWarnings({"unchecked"})
  public void setRight(Object right) {
    this.right = createRegressionTreeNode((Map<String, Object>) right);
  }

  /** @return {@code true} when no feature name has been set on this node */
  public boolean isLeaf() {
    return feature == null;
  }

  /**
   * Renders a leaf as its value and any other node as {@code
   * (feature=..,threshold=..[,missing=..],left=..,right=..)}, recursing into the children.
   */
  @Override
  public String toString() {
    if (isLeaf()) {
      return String.valueOf(value);
    }
    final StringBuilder sb = new StringBuilder();
    sb.append("(feature=").append(feature);
    if (threshold == null) {
      // A node can be printed before it has been validated (for example inside an error
      // message); report the missing threshold instead of failing with a NullPointerException.
      sb.append(",threshold=null");
    } else {
      sb.append(",threshold=").append(threshold.floatValue() - NODE_SPLIT_SLACK);
    }
    if (missing != null) {
      sb.append(",missing=").append(missing);
    }
    sb.append(",left=").append(left);
    sb.append(",right=").append(right);
    sb.append(')');
    return sb.toString();
  }
}
/**
 * A weighted regression tree: a root {@link RegressionTreeNode} plus the weight that multiplies
 * the root's score. Instances are populated through their setters and checked with {@link
 * #validate()}.
 */
public class RegressionTree {
  /** Multiplier applied to the root's score; {@code null} until a setter is called. */
  private Float weight;

  /** Root node of the tree; {@code null} until {@link #setRoot(Object)} is called. */
  private RegressionTreeNode root;

  public RegressionTree() {}

  public void setWeight(float weight) {
    this.weight = weight;
  }

  /**
   * @param weight textual float
   * @throws NumberFormatException if {@code weight} is not a parsable float
   */
  public void setWeight(String weight) {
    this.weight = Float.valueOf(weight);
  }

  /**
   * Builds the root node from a nested configuration map.
   *
   * @param root a {@code Map<String, Object>} describing the root node
   * @throws ClassCastException if {@code root} is not a {@link Map}
   */
  @SuppressWarnings({"unchecked"})
  public void setRoot(Object root) {
    this.root = createRegressionTreeNode((Map<String, Object>) root);
  }

  /**
   * Scores the feature vector against this tree and applies the tree weight.
   *
   * @param featureVector feature values passed on to the node-scoring routine
   * @return weight multiplied by the root's score, computed by {@code scoreNode} when
   *     isNullSameAsZero is set and by {@code scoreNodeWithNullSupport} otherwise
   */
  public float score(float[] featureVector) {
    final float rootScore =
        isNullSameAsZero
            ? scoreNode(featureVector, root)
            : scoreNodeWithNullSupport(featureVector, root);
    return weight.floatValue() * rootScore;
  }

  /**
   * @param featureVector feature values passed on to {@code explainNode}
   * @return whatever {@code explainNode} reports for the root of this tree
   */
  public String explain(float[] featureVector) {
    return explainNode(featureVector, root);
  }

  @Override
  public String toString() {
    return "(weight=" + weight + ",root=" + root + ")";
  }

  /**
   * Checks that both weight and root were configured, then delegates to {@code validateNode}
   * for the root.
   *
   * @throws ModelException if the weight or the root is missing, or if {@code validateNode}
   *     rejects the root
   */
  public void validate() throws ModelException {
    if (weight == null) {
      throw new ModelException("MultipleAdditiveTreesModel tree doesn't contain a weight");
    }
    if (root == null) {
      throw new ModelException("MultipleAdditiveTreesModel tree doesn't contain a tree");
    }
    validateNode(root);
  }
}
@SuppressWarnings({"unchecked"})
public void setTrees(Object trees) {
this.trees = new ArrayList();
for (final Object o : (List
© 2015 - 2024 Weber Informatics LLC | Privacy Policy