org.apache.mahout.benchmark.VectorBenchmarks Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of mahout-utils Show documentation
Show all versions of mahout-utils Show documentation
Utilities for preparing content into formats for Mahout.
/**
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.mahout.benchmark;
import java.util.ArrayList;
import java.util.BitSet;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.Map.Entry;
import org.apache.commons.cli2.CommandLine;
import org.apache.commons.cli2.Group;
import org.apache.commons.cli2.Option;
import org.apache.commons.cli2.OptionException;
import org.apache.commons.cli2.builder.ArgumentBuilder;
import org.apache.commons.cli2.builder.DefaultOptionBuilder;
import org.apache.commons.cli2.builder.GroupBuilder;
import org.apache.commons.cli2.commandline.Parser;
import org.apache.commons.lang.StringUtils;
import org.apache.mahout.common.CommandLineUtil;
import org.apache.mahout.common.RandomUtils;
import org.apache.mahout.common.Summarizable;
import org.apache.mahout.common.TimingStatistics;
import org.apache.mahout.common.commandline.DefaultOptionCreator;
import org.apache.mahout.common.distance.CosineDistanceMeasure;
import org.apache.mahout.common.distance.DistanceMeasure;
import org.apache.mahout.common.distance.EuclideanDistanceMeasure;
import org.apache.mahout.common.distance.SquaredEuclideanDistanceMeasure;
import org.apache.mahout.common.distance.TanimotoDistanceMeasure;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.RandomAccessSparseVector;
import org.apache.mahout.math.SequentialAccessSparseVector;
import org.apache.mahout.math.Vector;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
public class VectorBenchmarks implements Summarizable {
private static final Logger log = LoggerFactory.getLogger(VectorBenchmarks.class);
private final Vector[][] vectors;
private final List randomVectors = new ArrayList();
private final List randomVectorIndices = new ArrayList();
private final List randomVectorValues = new ArrayList();
private final int cardinality;
private final int sparsity;
private final int numVectors;
private final int loop;
private final int opsPerUnit;
private final Map implType = new HashMap();
private final Map> statsMap = new HashMap>();
/**
 * Generates the random source vectors shared by the benchmarks.
 *
 * <p>Each of the {@code numVectors} vectors is a SequentialAccessSparseVector of dimension
 * {@code cardinality} with exactly {@code sparsity} distinct non-zero positions. Positions are
 * drawn with {@code nextInt(cardinality)}, and duplicates are rejected. Values are drawn with
 * {@code nextGaussian()}. The indices and values are also recorded, in generation order, in
 * {@code randomVectorIndices} and {@code randomVectorValues}.
 *
 * @param cardinality dimension of every generated vector
 * @param sparsity number of non-zero entries per vector; must lie in {@code [0, cardinality]}
 * @param numVectors number of random vectors to generate
 * @param loop number of passes each benchmark makes over the vectors
 * @param opsPerUnit stored in the {@code opsPerUnit} field; not used by this constructor
 * @throws IllegalArgumentException if {@code sparsity} is negative or greater than
 *     {@code cardinality}; in the latter case the rejection-sampling loop could never finish
 */
public VectorBenchmarks(int cardinality, int sparsity, int numVectors, int loop, int opsPerUnit) {
  if (sparsity < 0 || sparsity > cardinality) {
    throw new IllegalArgumentException("sparsity must be between 0 and cardinality, got sparsity="
        + sparsity + ", cardinality=" + cardinality);
  }
  Random r = RandomUtils.getRandom();
  this.cardinality = cardinality;
  this.sparsity = sparsity;
  this.numVectors = numVectors;
  this.loop = loop;
  this.opsPerUnit = opsPerUnit;
  for (int i = 0; i < numVectors; i++) {
    // Second argument sizes the vector for `sparsity` non-zeros (presumably an initial
    // capacity -- NOTE(review): confirm against SequentialAccessSparseVector).
    Vector v = new SequentialAccessSparseVector(cardinality, sparsity);
    BitSet featureSpace = new BitSet(cardinality);
    int[] indexes = new int[sparsity];
    double[] values = new double[sparsity];
    int j = 0;
    while (j < sparsity) {
      // The value is drawn before the index; reordering these draws changes the generated data.
      double value = r.nextGaussian();
      int index = r.nextInt(cardinality);
      // Reject already-used positions so each vector has exactly `sparsity` distinct non-zeros.
      if (!featureSpace.get(index)) {
        featureSpace.set(index);
        indexes[j] = index;
        values[j++] = value;
        v.set(index, value);
      }
    }
    randomVectorIndices.add(indexes);
    randomVectorValues.add(values);
    randomVectors.add(v);
  }
  // One row per implementation (dense, random-access sparse, sequential-access sparse),
  // populated by createBenchmark().
  vectors = new Vector[3][numVectors];
}
/**
 * Records timing results together with extra descriptive text, using a multiplier of 1.
 *
 * @param stats accumulated timings for the run
 * @param benchmarkName benchmark label
 * @param implName implementation label
 * @param content free text logged alongside the statistics
 */
private void printStats(TimingStatistics stats, String benchmarkName, String implName, String content) {
  int unitMultiplier = 1;
  printStats(stats, benchmarkName, implName, content, unitMultiplier);
}
/**
 * Records timing results with no extra descriptive text and a multiplier of 1.
 *
 * @param stats accumulated timings for the run
 * @param benchmarkName benchmark label
 * @param implName implementation label
 */
private void printStats(TimingStatistics stats, String benchmarkName, String implName) {
  String noContent = "";
  // The four-argument overload supplies the default multiplier of 1.
  printStats(stats, benchmarkName, implName, noContent);
}
/**
 * Logs the timing results of one benchmark/implementation pair and stores them in
 * {@code statsMap} for the tabular report.
 *
 * <p>Two rates are computed:
 * <ul>
 *   <li>{@code opsPerSec} counts {@code loop * numVectors} timed units.</li>
 *   <li>{@code speed} is an MB/s estimate at 12 bytes per non-zero entry (presumably a 4-byte
 *       index plus an 8-byte value -- TODO confirm), scaled by {@code multiplier}.</li>
 * </ul>
 * Both formulas assume {@code stats.getSumTime()} is in nanoseconds (NOTE(review): inferred
 * from the 1e9 factor -- confirm).
 *
 * @param stats accumulated timings for the run
 * @param benchmarkName benchmark label; key into {@code statsMap}
 * @param implName implementation label; the first time a name is seen it is assigned the next
 *     column id in {@code implType}
 * @param content free text logged alongside the statistics
 * @param multiplier scale factor applied to the MB/s estimate
 */
private void printStats(TimingStatistics stats,
                        String benchmarkName,
                        String implName,
                        String content,
                        int multiplier) {
  // Widen to long before multiplying: the int product of these factors can overflow for large
  // configurations and silently yield a negative or garbage rate. Without overflow the result
  // is identical to the previous int arithmetic.
  long processedEntries = (long) multiplier * loop * numVectors * sparsity;
  float speed = processedEntries * 1000.0f * 12 / stats.getSumTime();
  long processedUnits = (long) loop * numVectors;
  float opsPerSec = processedUnits * 1000000000.0f / stats.getSumTime();
  log.info("{} {} \n{} {} \nSpeed: {} UnitsProcessed/sec {} MBytes/sec ",
      new Object[] {benchmarkName, implName, content, stats.toString(), opsPerSec, speed});
  // Tab-separated so the row can be split into one report cell per statistic.
  String info = stats.toString().replace('\n', '\t') + "\tSpeed = " + opsPerSec + " /sec\tRate = "
      + speed + " MB/s";
  if (!implType.containsKey(implName)) {
    implType.put(implName, implType.size());
  }
  int implId = implType.get(implName);
  List<String[]> implStats = statsMap.get(benchmarkName);
  if (implStats == null) {
    implStats = new ArrayList<String[]>();
    statsMap.put(benchmarkName, implStats);
  }
  // Pad with empty rows so the list can be indexed by implId even when this benchmark has not
  // yet been reported for lower-numbered implementations.
  while (implStats.size() <= implId) {
    implStats.add(new String[0]);
  }
  implStats.set(implId, info.split("\t"));
}
/**
 * Benchmarks copy-construction of each vector implementation from the random source vectors.
 *
 * <p>In each of {@code loop} passes, every source vector is copied into:
 * <ul>
 *   <li>a DenseVector ({@code vectors[0]}),</li>
 *   <li>a RandomAccessSparseVector ({@code vectors[1]}),</li>
 *   <li>a SequentialAccessSparseVector ({@code vectors[2]}).</li>
 * </ul>
 * Every copy is timed as its own call. All passes for one implementation finish before the
 * next implementation starts. Each implementation is reported under the "Create (copy)"
 * benchmark.
 */
public void createBenchmark() {
  String benchmark = "Create (copy)";

  TimingStatistics denseStats = new TimingStatistics();
  for (int pass = 0; pass < loop; pass++) {
    for (int n = 0; n < numVectors; n++) {
      TimingStatistics.Call timer = denseStats.newCall();
      vectors[0][n] = new DenseVector(randomVectors.get(n));
      timer.end();
    }
  }
  printStats(denseStats, benchmark, "DenseVector");

  TimingStatistics randomAccessStats = new TimingStatistics();
  for (int pass = 0; pass < loop; pass++) {
    for (int n = 0; n < numVectors; n++) {
      TimingStatistics.Call timer = randomAccessStats.newCall();
      vectors[1][n] = new RandomAccessSparseVector(randomVectors.get(n));
      timer.end();
    }
  }
  printStats(randomAccessStats, benchmark, "RandSparseVector");

  TimingStatistics sequentialStats = new TimingStatistics();
  for (int pass = 0; pass < loop; pass++) {
    for (int n = 0; n < numVectors; n++) {
      TimingStatistics.Call timer = sequentialStats.newCall();
      vectors[2][n] = new SequentialAccessSparseVector(randomVectors.get(n));
      timer.end();
    }
  }
  printStats(sequentialStats, benchmark, "SeqSparseVector");
}
/**
 * NOTE(review): this span is corrupted. Text between "i" and "e : implType.entrySet()" on the
 * for-loop line below was lost when the file was captured. That splices the beginning of this
 * method onto the middle of a report-building method (presumably summarize(), declared by
 * Summarizable -- confirm). The rest of this method, any intervening members, and the start of
 * the report method are missing, so this span cannot compile as-is. Restore it from the
 * upstream Mahout source.
 *
 * <p>From its parameters and first lines, this method presumably re-inserts the stored
 * indices/values of the source vector at {@code randomIndex} into {@code v}, in some shuffled
 * order ({@code randomOrder}), using setQuick when {@code useSetQuick} is true and timing the
 * work in {@code stats} -- TODO confirm against upstream.
 */
private void buildVectorIncrementally(TimingStatistics stats, int randomIndex, Vector v, boolean useSetQuick) {
int[] indexes = randomVectorIndices.get(randomIndex);
double[] values = randomVectorValues.get(randomIndex);
List randomOrder = new ArrayList();
// ---- corrupted line: the remainder of buildVectorIncrementally and the start of the report
// ---- method are missing here. What follows is the tail of the report method.
// Header row: for each column id i, emit the implementation name whose implType id equals i,
// padded/truncated to exactly `pad` characters (`sb` and `pad` are declared in the lost text).
for(int i=0; i e : implType.entrySet()) {
if (e.getValue() == i) {
sb.append(StringUtils.rightPad(e.getKey(), pad).substring(0, pad));
break;
}
}
}
sb.append('\n');
// One section per benchmark, in sorted benchmark-name order.
List keys = new ArrayList(statsMap.keySet());
Collections.sort(keys);
for (String benchmarkName : keys) {
List implTokenizedStats = statsMap.get(benchmarkName);
// The section needs as many rows as the longest tokenized stats array of any implementation.
int maxStats = 0;
for (String[] stat : implTokenizedStats) {
maxStats = Math.max(maxStats, stat.length);
}
for (int i = 0; i < maxStats; i++) {
// Leading column: the benchmark name on the first row, blank padding on later rows.
boolean printedName = false;
for (String[] stats : implTokenizedStats) {
if (i == 0 && !printedName) {
sb.append(StringUtils.rightPad(benchmarkName, pad));
printedName = true;
} else if (!printedName) {
printedName = true;
sb.append(StringUtils.rightPad("", pad));
}
// One cell per implementation: its i-th token, or blank padding when it has fewer tokens.
if (stats.length > i) {
sb.append(StringUtils.rightPad(stats[i], pad));
} else {
sb.append(StringUtils.rightPad("", pad));
}
}
sb.append('\n');
}
// Blank line between benchmark sections.
sb.append('\n');
}
return sb.toString();
}
}
© 2015 - 2024 Weber Informatics LLC | Privacy Policy