org.apache.solr.ltr.model.MultipleAdditiveTreesModel
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.solr.ltr.model;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.search.Explanation;
import org.apache.solr.ltr.feature.Feature;
import org.apache.solr.ltr.norm.Normalizer;
import org.apache.solr.util.SolrPluginUtils;

/**
 * A scoring model that computes scores based on the summation of multiple weighted trees.
 * Example models are LambdaMART and Gradient Boosted Regression Trees (GBRT).
 * <p>
 * Example configuration:
 * <pre>{
   "class" : "org.apache.solr.ltr.model.MultipleAdditiveTreesModel",
   "name" : "multipleadditivetreesmodel",
   "features":[
       { "name" : "userTextTitleMatch"},
       { "name" : "originalScore"}
   ],
   "params" : {
       "trees" : [
           {
               "weight" : "1",
               "root": {
                   "feature" : "userTextTitleMatch",
                   "threshold" : "0.5",
                   "left" : {
                       "value" : "-100"
                   },
                   "right" : {
                       "feature" : "originalScore",
                       "threshold" : "10.0",
                       "left" : {
                           "value" : "50"
                       },
                       "right" : {
                           "value" : "75"
                       }
                   }
               }
           },
           {
               "weight" : "2",
               "root" : {
                   "value" : "-10"
               }
           }
       ]
   }
}</pre>
 * <p>
 * Training libraries:
 * <ul>
 * <li> <a href="http://sourceforge.net/p/lemur/wiki/RankLib/">RankLib</a>
 * </ul>
 * <p>
 * Background reading:
 * <ul>
 * <li> <a href="http://research.microsoft.com/pubs/132652/MSR-TR-2010-82.pdf">
 * Christopher J.C. Burges. From RankNet to LambdaRank to LambdaMART: An Overview.
 * Microsoft Research Technical Report MSR-TR-2010-82.</a>
 * </ul>
 */
public class MultipleAdditiveTreesModel extends LTRScoringModel {

  /**
   * fname2index is filled from constructor arguments
   * (that are already part of the base class hashCode)
   * and therefore here it does not individually
   * influence the class hashCode, equals, etc.
   */
  private final HashMap<String,Integer> fname2index;

  /**
   * trees is part of the LTRScoringModel params map
   * and therefore here it does not individually
   * influence the class hashCode, equals, etc.
   */
  private List<RegressionTree> trees;

  private RegressionTree createRegressionTree(Map<String,Object> map) {
    final RegressionTree rt = new RegressionTree();
    if (map != null) {
      SolrPluginUtils.invokeSetters(rt, map.entrySet());
    }
    return rt;
  }

  private RegressionTreeNode createRegressionTreeNode(Map<String,Object> map) {
    final RegressionTreeNode rtn = new RegressionTreeNode();
    if (map != null) {
      SolrPluginUtils.invokeSetters(rtn, map.entrySet());
    }
    return rtn;
  }

  public class RegressionTreeNode {
    private static final float NODE_SPLIT_SLACK = 1E-6f;

    private float value = 0f;
    private String feature;
    private int featureIndex = -1;
    private Float threshold;
    private RegressionTreeNode left;
    private RegressionTreeNode right;

    public void setValue(float value) {
      this.value = value;
    }

    public void setValue(String value) {
      this.value = Float.parseFloat(value);
    }

    public void setFeature(String feature) {
      this.feature = feature;
      final Integer idx = fname2index.get(this.feature);
      // this happens if the tree specifies a feature that does not exist
      // this could be due to lambdaSmart building off of pre-existing trees
      // that use a feature that is no longer output during feature extraction
      featureIndex = (idx == null) ? -1 : idx;
    }

    public void setThreshold(float threshold) {
      this.threshold = threshold + NODE_SPLIT_SLACK;
    }

    public void setThreshold(String threshold) {
      this.threshold = Float.parseFloat(threshold) + NODE_SPLIT_SLACK;
    }

    public void setLeft(Object left) {
      this.left = createRegressionTreeNode((Map<String,Object>) left);
    }

    public void setRight(Object right) {
      this.right = createRegressionTreeNode((Map<String,Object>) right);
    }

    public boolean isLeaf() {
      return feature == null;
    }

    public float score(float[] featureVector) {
      if (isLeaf()) {
        return value;
      }

      // unsupported feature (tree is looking for a feature that does not exist)
      if ((featureIndex < 0) || (featureIndex >= featureVector.length)) {
        return 0f;
      }

      if (featureVector[featureIndex] <= threshold) {
        return left.score(featureVector);
      } else {
        return right.score(featureVector);
      }
    }

    public String explain(float[] featureVector) {
      if (isLeaf()) {
        return "val: " + value;
      }

      // unsupported feature (tree is looking for a feature that does not exist)
      if ((featureIndex < 0) || (featureIndex >= featureVector.length)) {
        return "'" + feature + "' does not exist in FV, Return Zero";
      }

      // could store extra information about how much training data supported
      // each branch and report that here

      if (featureVector[featureIndex] <= threshold) {
        String rval = "'" + feature + "':" + featureVector[featureIndex] + " <= "
            + threshold + ", Go Left | ";
        return rval + left.explain(featureVector);
      } else {
        String rval = "'" + feature + "':" + featureVector[featureIndex] + " > "
            + threshold + ", Go Right | ";
        return rval + right.explain(featureVector);
      }
    }

    @Override
    public String toString() {
      final StringBuilder sb = new StringBuilder();
      if (isLeaf()) {
        sb.append(value);
      } else {
        sb.append("(feature=").append(feature);
        sb.append(",threshold=").append(threshold.floatValue() - NODE_SPLIT_SLACK);
        sb.append(",left=").append(left);
        sb.append(",right=").append(right);
        sb.append(')');
      }
      return sb.toString();
    }

    public RegressionTreeNode() {
    }

    public void validate() throws ModelException {
      if (isLeaf()) {
        if (left != null || right != null) {
          throw new ModelException("MultipleAdditiveTreesModel tree node is leaf with left=" + left + " and right=" + right);
        }
        return;
      }
      if (null == threshold) {
        throw new ModelException("MultipleAdditiveTreesModel tree node is missing threshold");
      }
      if (null == left) {
        throw new ModelException("MultipleAdditiveTreesModel tree node is missing left");
      } else {
        left.validate();
      }
      if (null == right) {
        throw new ModelException("MultipleAdditiveTreesModel tree node is missing right");
      } else {
        right.validate();
      }
    }
  }

  public class RegressionTree {
    private Float weight;
    private RegressionTreeNode root;

    public void setWeight(float weight) {
      this.weight = weight;
    }

    public void setWeight(String weight) {
      this.weight = Float.valueOf(weight);
    }

    public void setRoot(Object root) {
      this.root = createRegressionTreeNode((Map<String,Object>) root);
    }

    public float score(float[] featureVector) {
      return weight.floatValue() * root.score(featureVector);
    }

    public String explain(float[] featureVector) {
      return root.explain(featureVector);
    }

    @Override
    public String toString() {
      final StringBuilder sb = new StringBuilder();
      sb.append("(weight=").append(weight);
      sb.append(",root=").append(root);
      sb.append(")");
      return sb.toString();
    }

    public RegressionTree() {
    }

    public void validate() throws ModelException {
      if (weight == null) {
        throw new ModelException("MultipleAdditiveTreesModel tree doesn't contain a weight");
      }
      if (root == null) {
        throw new ModelException("MultipleAdditiveTreesModel tree doesn't contain a tree");
      } else {
        root.validate();
      }
    }
  }

  public void setTrees(Object trees) {
    this.trees = new ArrayList<>();
    for (final Object o : (List<Object>) trees) {
      final RegressionTree rt = createRegressionTree((Map<String,Object>) o);
      this.trees.add(rt);
    }
  }

  public MultipleAdditiveTreesModel(String name, List<Feature> features,
      List<Normalizer> norms,
      String featureStoreName, List<Feature> allFeatures,
      Map<String,Object> params) {
    super(name, features, norms, featureStoreName, allFeatures, params);

    fname2index = new HashMap<>();
    for (int i = 0; i < features.size(); ++i) {
      final String key = features.get(i).getName();
      fname2index.put(key, i);
    }
  }

  @Override
  protected void validate() throws ModelException {
    super.validate();
    if (trees == null) {
      throw new ModelException("no trees declared for model " + name);
    }
    for (RegressionTree tree : trees) {
      tree.validate();
    }
  }

  @Override
  public float score(float[] modelFeatureValuesNormalized) {
    float score = 0;
    for (final RegressionTree t : trees) {
      score += t.score(modelFeatureValuesNormalized);
    }
    return score;
  }

  // /////////////////////////////////////////
  // produces a string that looks like:
  // 40.0 = multipleadditivetreesmodel [ org.apache.solr.ltr.model.MultipleAdditiveTreesModel ]
  // model applied to
  // features, sum of:
  // 50.0 = tree 0 | 'matchedTitle':1.0 > 0.500001, Go Right |
  // 'this_feature_doesnt_exist' does not
  // exist in FV, Go Left | val: 50.0
  // -10.0 = tree 1 | val: -10.0
  @Override
  public Explanation explain(LeafReaderContext context, int doc,
      float finalScore, List<Explanation> featureExplanations) {
    final float[] fv = new float[featureExplanations.size()];
    int index = 0;
    for (final Explanation featureExplain : featureExplanations) {
      fv[index] = featureExplain.getValue();
      index++;
    }

    final List<Explanation> details = new ArrayList<>();
    index = 0;

    for (final RegressionTree t : trees) {
      final float score = t.score(fv);
      final Explanation p = Explanation.match(score, "tree " + index + " | "
          + t.explain(fv));
      details.add(p);
      index++;
    }

    return Explanation.match(finalScore, toString()
        + " model applied to features, sum of:", details);
  }

  @Override
  public String toString() {
    final StringBuilder sb = new StringBuilder(getClass().getSimpleName());
    sb.append("(name=").append(getName());
    sb.append(",trees=[");
    for (int ii = 0; ii < trees.size(); ++ii) {
      if (ii > 0) {
        sb.append(',');
      }
      sb.append(trees.get(ii));
    }
    sb.append("])");
    return sb.toString();
  }
}
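
For reference, here is a minimal standalone sketch (not part of the Solr class above) that walks through how the example configuration in the class Javadoc would score a document. The Node class, the feature-vector layout, and the sample feature values (userTextTitleMatch = 1.0, originalScore = 20.0) are assumptions made for illustration only; in Solr the trees are built from the stored model JSON via SolrPluginUtils.invokeSetters and evaluated by RegressionTree.score.

// Hypothetical illustration of the weighted-tree summation; not Solr code.
public class MultipleAdditiveTreesScoringSketch {

  // A simplified regression tree node: either a leaf (value) or a split
  // (featureIndex, threshold, left, right).
  static final class Node {
    final Float value;          // non-null for leaves
    final Integer featureIndex; // non-null for splits
    final Float threshold;
    final Node left;
    final Node right;

    static Node leaf(float value) {
      return new Node(value, null, null, null, null);
    }

    static Node split(int featureIndex, float threshold, Node left, Node right) {
      return new Node(null, featureIndex, threshold, left, right);
    }

    private Node(Float value, Integer featureIndex, Float threshold, Node left, Node right) {
      this.value = value;
      this.featureIndex = featureIndex;
      this.threshold = threshold;
      this.left = left;
      this.right = right;
    }

    float score(float[] fv) {
      if (value != null) {
        return value; // leaf node
      }
      return (fv[featureIndex] <= threshold) ? left.score(fv) : right.score(fv);
    }
  }

  public static void main(String[] args) {
    // Assumed feature vector layout: [0] = userTextTitleMatch, [1] = originalScore.
    final float[] fv = { 1.0f, 20.0f };

    // Tree 0 (weight 1), mirroring the Javadoc example configuration; the real
    // class adds a tiny NODE_SPLIT_SLACK (1E-6) to each threshold, omitted here.
    final Node tree0 = Node.split(0, 0.5f,
        Node.leaf(-100f),
        Node.split(1, 10.0f, Node.leaf(50f), Node.leaf(75f)));

    // Tree 1 (weight 2) is a single leaf.
    final Node tree1 = Node.leaf(-10f);

    // Weighted sum over all trees: 1 * 75 + 2 * (-10) = 55.0
    final float score = 1f * tree0.score(fv) + 2f * tree1.score(fv);
    System.out.println(score); // prints 55.0
  }
}

With these values, tree 0 routes right twice (1.0 is above the userTextTitleMatch threshold, 20.0 is above the originalScore threshold) and returns 75, while tree 1 contributes its single leaf value of -10, so the model score is 1 * 75 + 2 * (-10) = 55. Once such a model JSON has been uploaded to the LTR model store, it is applied at query time through Solr's LTR rerank query parser, e.g. a request parameter of the form rq={!ltr model=multipleadditivetreesmodel reRankDocs=100}.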