![JAR search and dependency download from the Maven repository](/logo.png)
org.apache.mahout.classifier.sgd.SGDHelper Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of mahout-examples Show documentation
Show all versions of mahout-examples Show documentation
Scalable machine learning library examples
The newest version!
/**
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.mahout.classifier.sgd;
import com.google.common.collect.Multiset;
import org.apache.mahout.classifier.NewsgroupHelper;
import org.apache.mahout.ep.State;
import org.apache.mahout.math.Matrix;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.function.DoubleFunction;
import org.apache.mahout.math.function.Functions;
import org.apache.mahout.vectorizer.encoders.Dictionary;
import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.Set;
import java.util.TreeMap;
public final class SGDHelper {
private static final String[] LEAK_LABELS = {"none", "month-year", "day-month-year"};
private SGDHelper() {
}
public static void dissect(int leakType,
Dictionary dictionary,
AdaptiveLogisticRegression learningAlgorithm,
Iterable files, Multiset overallCounts) throws IOException {
CrossFoldLearner model = learningAlgorithm.getBest().getPayload().getLearner();
model.close();
Map> traceDictionary = new TreeMap<>();
ModelDissector md = new ModelDissector();
NewsgroupHelper helper = new NewsgroupHelper();
helper.getEncoder().setTraceDictionary(traceDictionary);
helper.getBias().setTraceDictionary(traceDictionary);
for (File file : permute(files, helper.getRandom()).subList(0, 500)) {
String ng = file.getParentFile().getName();
int actual = dictionary.intern(ng);
traceDictionary.clear();
Vector v = helper.encodeFeatureVector(file, actual, leakType, overallCounts);
md.update(v, traceDictionary, model);
}
List ngNames = new ArrayList<>(dictionary.values());
List weights = md.summary(100);
System.out.println("============");
System.out.println("Model Dissection");
for (ModelDissector.Weight w : weights) {
System.out.printf("%s\t%.1f\t%s\t%.1f\t%s\t%.1f\t%s%n",
w.getFeature(), w.getWeight(), ngNames.get(w.getMaxImpact() + 1),
w.getCategory(1), w.getWeight(1), w.getCategory(2), w.getWeight(2));
}
}
public static List permute(Iterable files, Random rand) {
List r = new ArrayList<>();
for (File file : files) {
int i = rand.nextInt(r.size() + 1);
if (i == r.size()) {
r.add(file);
} else {
r.add(r.get(i));
r.set(i, file);
}
}
return r;
}
static void analyzeState(SGDInfo info, int leakType, int k, State best) throws IOException {
int bump = info.getBumps()[(int) Math.floor(info.getStep()) % info.getBumps().length];
int scale = (int) Math.pow(10, Math.floor(info.getStep() / info.getBumps().length));
double maxBeta;
double nonZeros;
double positive;
double norm;
double lambda = 0;
double mu = 0;
if (best != null) {
CrossFoldLearner state = best.getPayload().getLearner();
info.setAverageCorrect(state.percentCorrect());
info.setAverageLL(state.logLikelihood());
OnlineLogisticRegression model = state.getModels().get(0);
// finish off pending regularization
model.close();
Matrix beta = model.getBeta();
maxBeta = beta.aggregate(Functions.MAX, Functions.ABS);
nonZeros = beta.aggregate(Functions.PLUS, new DoubleFunction() {
@Override
public double apply(double v) {
return Math.abs(v) > 1.0e-6 ? 1 : 0;
}
});
positive = beta.aggregate(Functions.PLUS, new DoubleFunction() {
@Override
public double apply(double v) {
return v > 0 ? 1 : 0;
}
});
norm = beta.aggregate(Functions.PLUS, Functions.ABS);
lambda = best.getMappedParams()[0];
mu = best.getMappedParams()[1];
} else {
maxBeta = 0;
nonZeros = 0;
positive = 0;
norm = 0;
}
if (k % (bump * scale) == 0) {
if (best != null) {
File modelFile = new File(System.getProperty("java.io.tmpdir"), "news-group-" + k + ".model");
ModelSerializer.writeBinary(modelFile.getAbsolutePath(), best.getPayload().getLearner().getModels().get(0));
}
info.setStep(info.getStep() + 0.25);
System.out.printf("%.2f\t%.2f\t%.2f\t%.2f\t%.8g\t%.8g\t", maxBeta, nonZeros, positive, norm, lambda, mu);
System.out.printf("%d\t%.3f\t%.2f\t%s%n",
k, info.getAverageLL(), info.getAverageCorrect() * 100, LEAK_LABELS[leakType % 3]);
}
}
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy