org.apache.mahout.classifier.sgd.ModelDissector Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of mahout-mr Show documentation
Show all versions of mahout-mr Show documentation
Scalable machine learning libraries
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.mahout.classifier.sgd;
import com.google.common.collect.Ordering;
import org.apache.mahout.classifier.AbstractVectorClassifier;
import org.apache.mahout.common.RandomUtils;
import org.apache.mahout.math.Vector;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.PriorityQueue;
import java.util.Queue;
import java.util.Set;
/**
* Uses sample data to reverse engineer a feature-hashed model.
*
* The result gives approximate weights for features and interactions
* in the original space.
*
* The idea is that the hashed encoders have the option of having a trace dictionary. This
* tells us where each feature is hashed to, or each feature/value combination in the case
* of word-like values. Using this dictionary, we can put values into a synthetic feature
* vector in just the locations specified by a single feature or interaction. Then we can
* push this through a linear part of a model to see the contribution of that input. For
* any generalized linear model like logistic regression, there is a linear part of the
* model that allows this.
*
* What the ModelDissector does is to accept a trace dictionary and a model in an update
* method. It figures out the weights for the elements in the trace dictionary and stashes
* them. Then in a summary method, the biggest weights are returned. This update/flush
* style is used so that the trace dictionary doesn't have to grow to enormous levels,
* but instead can be cleared between updates.
*/
public class ModelDissector {
private final Map weightMap;
public ModelDissector() {
weightMap = new HashMap<>();
}
/**
* Probes a model to determine the effect of a particular variable. This is done
* with the ade of a trace dictionary which has recorded the locations in the feature
* vector that are modified by various variable values. We can set these locations to
* 1 and then look at the resulting score. This tells us the weight the model places
* on that variable.
* @param features A feature vector to use (destructively)
* @param traceDictionary A trace dictionary containing variables and what locations
* in the feature vector are affected by them
* @param learner The model that we are probing to find weights on features
*/
public void update(Vector features, Map> traceDictionary, AbstractVectorClassifier learner) {
// zero out feature vector
features.assign(0);
for (Map.Entry> entry : traceDictionary.entrySet()) {
// get a feature and locations where it is stored in the feature vector
String key = entry.getKey();
Set value = entry.getValue();
// if we haven't looked at this feature yet
if (!weightMap.containsKey(key)) {
// put probe values in the feature vector
for (Integer where : value) {
features.set(where, 1);
}
// see what the model says
Vector v = learner.classifyNoLink(features);
weightMap.put(key, v);
// and zero out those locations again
for (Integer where : value) {
features.set(where, 0);
}
}
}
}
/**
* Returns the n most important features with their
* weights, most important category and the top few
* categories that they affect.
* @param n How many results to return.
* @return A list of the top variables.
*/
public List summary(int n) {
Queue pq = new PriorityQueue<>();
for (Map.Entry entry : weightMap.entrySet()) {
pq.add(new Weight(entry.getKey(), entry.getValue()));
while (pq.size() > n) {
pq.poll();
}
}
List r = new ArrayList<>(pq);
Collections.sort(r, Ordering.natural().reverse());
return r;
}
private static final class Category implements Comparable {
private final int index;
private final double weight;
private Category(int index, double weight) {
this.index = index;
this.weight = weight;
}
@Override
public int compareTo(Category o) {
int r = Double.compare(Math.abs(weight), Math.abs(o.weight));
if (r == 0) {
if (o.index < index) {
return -1;
}
if (o.index > index) {
return 1;
}
return 0;
}
return r;
}
@Override
public boolean equals(Object o) {
if (!(o instanceof Category)) {
return false;
}
Category other = (Category) o;
return index == other.index && weight == other.weight;
}
@Override
public int hashCode() {
return RandomUtils.hashDouble(weight) ^ index;
}
}
public static class Weight implements Comparable {
private final String feature;
private final double value;
private final int maxIndex;
private final List categories;
public Weight(String feature, Vector weights) {
this(feature, weights, 3);
}
public Weight(String feature, Vector weights, int n) {
this.feature = feature;
// pick out the weight with the largest abs value, but don't forget the sign
Queue biggest = new PriorityQueue<>(n + 1, Ordering.natural());
for (Vector.Element element : weights.all()) {
biggest.add(new Category(element.index(), element.get()));
while (biggest.size() > n) {
biggest.poll();
}
}
categories = new ArrayList<>(biggest);
Collections.sort(categories, Ordering.natural().reverse());
value = categories.get(0).weight;
maxIndex = categories.get(0).index;
}
@Override
public int compareTo(Weight other) {
int r = Double.compare(Math.abs(this.value), Math.abs(other.value));
if (r == 0) {
return feature.compareTo(other.feature);
}
return r;
}
@Override
public boolean equals(Object o) {
if (!(o instanceof Weight)) {
return false;
}
Weight other = (Weight) o;
return feature.equals(other.feature)
&& value == other.value
&& maxIndex == other.maxIndex
&& categories.equals(other.categories);
}
@Override
public int hashCode() {
return feature.hashCode() ^ RandomUtils.hashDouble(value) ^ maxIndex ^ categories.hashCode();
}
public String getFeature() {
return feature;
}
public double getWeight() {
return value;
}
public double getWeight(int n) {
return categories.get(n).weight;
}
public double getCategory(int n) {
return categories.get(n).index;
}
public int getMaxImpact() {
return maxIndex;
}
}
}
© 2015 - 2024 Weber Informatics LLC | Privacy Policy