All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.apache.mahout.classifier.sgd.ModelDissector Maven / Gradle / Ivy

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.mahout.classifier.sgd;

import com.google.common.collect.Ordering;
import org.apache.mahout.classifier.AbstractVectorClassifier;
import org.apache.mahout.common.RandomUtils;
import org.apache.mahout.math.Vector;

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.PriorityQueue;
import java.util.Queue;
import java.util.Set;

/**
 * Uses sample data to reverse engineer a feature-hashed model.
 *
 * The result gives approximate weights for features and interactions
 * in the original space.
 *
 * The idea is that the hashed encoders have the option of having a trace dictionary.  This
 * tells us where each feature is hashed to, or each feature/value combination in the case
 * of word-like values.  Using this dictionary, we can put values into a synthetic feature
 * vector in just the locations specified by a single feature or interaction.  Then we can
 * push this through a linear part of a model to see the contribution of that input. For
 * any generalized linear model like logistic regression, there is a linear part of the
 * model that allows this.
 *
 * What the ModelDissector does is to accept a trace dictionary and a model in an update
 * method.  It figures out the weights for the elements in the trace dictionary and stashes
 * them.  Then in a summary method, the biggest weights are returned.  This update/flush
 * style is used so that the trace dictionary doesn't have to grow to enormous levels,
 * but instead can be cleared between updates.
 */
public class ModelDissector {
  private final Map weightMap;

  public ModelDissector() {
    weightMap = new HashMap<>();
  }

  /**
   * Probes a model to determine the effect of a particular variable.  This is done
   * with the ade of a trace dictionary which has recorded the locations in the feature
   * vector that are modified by various variable values.  We can set these locations to
   * 1 and then look at the resulting score.  This tells us the weight the model places
   * on that variable.
   * @param features               A feature vector to use (destructively)
   * @param traceDictionary        A trace dictionary containing variables and what locations
   *                               in the feature vector are affected by them
   * @param learner                The model that we are probing to find weights on features
   */

  public void update(Vector features, Map> traceDictionary, AbstractVectorClassifier learner) {
    // zero out feature vector
    features.assign(0);
    for (Map.Entry> entry : traceDictionary.entrySet()) {
      // get a feature and locations where it is stored in the feature vector
      String key = entry.getKey();
      Set value = entry.getValue();

      // if we haven't looked at this feature yet
      if (!weightMap.containsKey(key)) {
        // put probe values in the feature vector
        for (Integer where : value) {
          features.set(where, 1);
        }

        // see what the model says
        Vector v = learner.classifyNoLink(features);
        weightMap.put(key, v);

        // and zero out those locations again
        for (Integer where : value) {
          features.set(where, 0);
        }
      }
    }
  }

  /**
   * Returns the n most important features with their
   * weights, most important category and the top few
   * categories that they affect.
   * @param n      How many results to return.
   * @return       A list of the top variables.
   */
  public List summary(int n) {
    Queue pq = new PriorityQueue<>();
    for (Map.Entry entry : weightMap.entrySet()) {
      pq.add(new Weight(entry.getKey(), entry.getValue()));
      while (pq.size() > n) {
        pq.poll();
      }
    }
    List r = new ArrayList<>(pq);
    Collections.sort(r, Ordering.natural().reverse());
    return r;
  }

  private static final class Category implements Comparable {
    private final int index;
    private final double weight;

    private Category(int index, double weight) {
      this.index = index;
      this.weight = weight;
    }

    @Override
    public int compareTo(Category o) {
      int r = Double.compare(Math.abs(weight), Math.abs(o.weight));
      if (r == 0) {
        if (o.index < index) {
          return -1;
        }
        if (o.index > index) {
          return 1;
        }
        return 0;
      }
      return r;
    }

    @Override
    public boolean equals(Object o) {
      if (!(o instanceof Category)) {
        return false;
      }
      Category other = (Category) o;
      return index == other.index && weight == other.weight;
    }

    @Override
    public int hashCode() {
      return RandomUtils.hashDouble(weight) ^ index;
    }

  }

  public static class Weight implements Comparable {
    private final String feature;
    private final double value;
    private final int maxIndex;
    private final List categories;

    public Weight(String feature, Vector weights) {
      this(feature, weights, 3);
    }

    public Weight(String feature, Vector weights, int n) {
      this.feature = feature;
      // pick out the weight with the largest abs value, but don't forget the sign
      Queue biggest = new PriorityQueue<>(n + 1, Ordering.natural());
      for (Vector.Element element : weights.all()) {
        biggest.add(new Category(element.index(), element.get()));
        while (biggest.size() > n) {
          biggest.poll();
        }
      }
      categories = new ArrayList<>(biggest);
      Collections.sort(categories, Ordering.natural().reverse());
      value = categories.get(0).weight;
      maxIndex = categories.get(0).index;
    }

    @Override
    public int compareTo(Weight other) {
      int r = Double.compare(Math.abs(this.value), Math.abs(other.value));
      if (r == 0) {
        return feature.compareTo(other.feature);
      }
      return r;
    }

    @Override
    public boolean equals(Object o) {
      if (!(o instanceof Weight)) {
        return false;
      }
      Weight other = (Weight) o;
      return feature.equals(other.feature)
          && value == other.value
          && maxIndex == other.maxIndex
          && categories.equals(other.categories);
    }

    @Override
    public int hashCode() {
      return feature.hashCode() ^ RandomUtils.hashDouble(value) ^ maxIndex ^ categories.hashCode();
    }

    public String getFeature() {
      return feature;
    }

    public double getWeight() {
      return value;
    }

    public double getWeight(int n) {
      return categories.get(n).weight;
    }

    public double getCategory(int n) {
      return categories.get(n).index;
    }

    public int getMaxImpact() {
      return maxIndex;
    }
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy