// Source listing: org.apache.mahout.math.stats.GroupedOnlineAuc
// Artifact: mahout-mr (Maven / Gradle / Ivy) - Scalable machine learning libraries
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.mahout.math.stats;
import com.google.common.collect.Maps;
import org.apache.mahout.classifier.sgd.PolymorphicWritable;
import java.io.DataInput;
import java.io.DataOutput;
import java.io.IOException;
import java.util.Map;
/**
 * Implements a variant on AUC where the result returned is an average of several AUC measurements
 * made on sub-groups of the overall data. Controlling for the grouping factor allows the effects
 * of the grouping factor on the model to be ignored. This is useful, for instance, when using a
 * classifier as a click prediction engine. In that case you want AUC to refer only to the ranking
 * of items for a particular user, not to the discrimination of users from each other. Grouping by
 * user (or user cluster) helps avoid optimizing for the wrong quality.
 *
 * <p>One {@link OnlineAuc} is kept per group key. The replacement policy and window size passed to
 * {@link #setPolicy} and {@link #setWindowSize} are remembered and applied both to the groups that
 * already exist and to groups created later.
 */
public class GroupedOnlineAuc implements OnlineAuc {
  /** Value serialized in place of the policy ordinal when no policy has been set. */
  private static final int NO_POLICY = -1;

  /** Per-group AUC trackers, keyed by group key. */
  private final Map<String, OnlineAuc> map = Maps.newHashMap();

  /** Policy handed to every group; {@code null} until {@link #setPolicy} is called. */
  private GlobalOnlineAuc.ReplacementPolicy policy;

  /** Window size handed to every group; only applied to new groups when positive. */
  private int windowSize;

  /**
   * Adds a sample to the tracker belonging to {@code groupKey}, creating that tracker on first
   * use and configuring it with the current policy and window size.
   *
   * @param category the category of the sample, passed through to the group's tracker
   * @param groupKey the group the sample belongs to; when {@code null} the call is delegated to
   *                 {@link #addSample(int, double)}
   * @param score    the score of the sample, passed through to the group's tracker
   * @return whatever the group's tracker returns for this sample
   * @throws UnsupportedOperationException if {@code groupKey} is {@code null}, because the
   *         un-grouped overload is not supported by this class
   */
  @Override
  public double addSample(int category, String groupKey, double score) {
    if (groupKey == null) {
      // Return the delegate's result so that a null key can never be used as a map key below.
      return addSample(category, score);
    }

    OnlineAuc group = map.get(groupKey);
    if (group == null) {
      group = new GlobalOnlineAuc();
      if (policy != null) {
        group.setPolicy(policy);
      }
      if (windowSize > 0) {
        group.setWindowSize(windowSize);
      }
      map.put(groupKey, group);
    }
    return group.addSample(category, score);
  }

  /**
   * Not supported: every sample must carry a group key.
   *
   * @throws UnsupportedOperationException always
   */
  @Override
  public double addSample(int category, double score) {
    throw new UnsupportedOperationException("Can't add to " + this.getClass() + " without group key");
  }

  /**
   * Returns the unweighted mean of the AUC values of all groups seen so far.
   *
   * @return the mean of the per-group AUC values; {@code NaN} when no group exists yet, since
   *         the sum {@code 0.0} is then divided by a group count of zero
   */
  @Override
  public double auc() {
    double sum = 0;
    for (OnlineAuc auc : map.values()) {
      sum += auc.auc();
    }
    return sum / map.size();
  }

  /**
   * Remembers {@code policy} for groups created later and applies it to every existing group.
   *
   * @param policy the replacement policy to use
   */
  @Override
  public void setPolicy(GlobalOnlineAuc.ReplacementPolicy policy) {
    this.policy = policy;
    for (OnlineAuc auc : map.values()) {
      auc.setPolicy(policy);
    }
  }

  /**
   * Remembers {@code windowSize} for groups created later and applies it to every existing group.
   *
   * @param windowSize the window size to use
   */
  @Override
  public void setWindowSize(int windowSize) {
    this.windowSize = windowSize;
    for (OnlineAuc auc : map.values()) {
      auc.setWindowSize(windowSize);
    }
  }

  /**
   * Serializes this object as: the number of groups, then for each group its key (modified
   * UTF-8) followed by its tracker, then the policy ordinal and the window size.
   *
   * <p>When no policy has been set, {@code -1} is written instead of an ordinal so that an
   * unconfigured instance can be serialized without a {@link NullPointerException}.
   *
   * @param out the sink to write to
   * @throws IOException if writing to {@code out} fails
   */
  @Override
  public void write(DataOutput out) throws IOException {
    out.writeInt(map.size());
    for (Map.Entry<String, OnlineAuc> entry : map.entrySet()) {
      out.writeUTF(entry.getKey());
      PolymorphicWritable.write(out, entry.getValue());
    }
    out.writeInt(policy == null ? NO_POLICY : policy.ordinal());
    out.writeInt(windowSize);
  }

  /**
   * Replaces the state of this object with the state read from {@code in}, using the layout
   * produced by {@link #write(DataOutput)}. All existing groups are discarded first.
   *
   * @param in the source to read from
   * @throws IOException if reading from {@code in} fails or the stored policy ordinal does not
   *         correspond to a known replacement policy
   */
  @Override
  public void readFields(DataInput in) throws IOException {
    int n = in.readInt();
    map.clear();
    for (int i = 0; i < n; i++) {
      String key = in.readUTF();
      map.put(key, PolymorphicWritable.read(in, OnlineAuc.class));
    }

    int policyOrdinal = in.readInt();
    if (policyOrdinal == NO_POLICY) {
      policy = null;
    } else {
      GlobalOnlineAuc.ReplacementPolicy[] policies = GlobalOnlineAuc.ReplacementPolicy.values();
      if (policyOrdinal < 0 || policyOrdinal >= policies.length) {
        // Report corrupt input as an I/O problem instead of an unchecked index error.
        throw new IOException("Invalid replacement policy ordinal: " + policyOrdinal);
      }
      policy = policies[policyOrdinal];
    }
    windowSize = in.readInt();
  }
}
// © 2015 - 2024 Weber Informatics LLC | Privacy Policy