All downloads are FREE. Search and download functionality uses the official Maven repository.

org.apache.mahout.benchmark.VectorBenchmarks Maven / Gradle / Ivy

There is a newer version: 0.5
Show newest version
/**
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.mahout.benchmark;

import java.util.ArrayList;
import java.util.BitSet;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.Map.Entry;

import org.apache.commons.cli2.CommandLine;
import org.apache.commons.cli2.Group;
import org.apache.commons.cli2.Option;
import org.apache.commons.cli2.OptionException;
import org.apache.commons.cli2.builder.ArgumentBuilder;
import org.apache.commons.cli2.builder.DefaultOptionBuilder;
import org.apache.commons.cli2.builder.GroupBuilder;
import org.apache.commons.cli2.commandline.Parser;
import org.apache.commons.lang.StringUtils;
import org.apache.mahout.common.CommandLineUtil;
import org.apache.mahout.common.RandomUtils;
import org.apache.mahout.common.Summarizable;
import org.apache.mahout.common.TimingStatistics;
import org.apache.mahout.common.commandline.DefaultOptionCreator;
import org.apache.mahout.common.distance.CosineDistanceMeasure;
import org.apache.mahout.common.distance.DistanceMeasure;
import org.apache.mahout.common.distance.EuclideanDistanceMeasure;
import org.apache.mahout.common.distance.SquaredEuclideanDistanceMeasure;
import org.apache.mahout.common.distance.TanimotoDistanceMeasure;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.RandomAccessSparseVector;
import org.apache.mahout.math.SequentialAccessSparseVector;
import org.apache.mahout.math.Vector;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class VectorBenchmarks implements Summarizable {

  private static final Logger log = LoggerFactory.getLogger(VectorBenchmarks.class);

  private final Vector[][] vectors;
  private final List randomVectors = new ArrayList();
  private final List randomVectorIndices = new ArrayList();
  private final List randomVectorValues = new ArrayList();
  private final int cardinality;
  private final int sparsity;
  private final int numVectors;
  private final int loop;
  private final int opsPerUnit;
  private final Map implType = new HashMap();
  private final Map> statsMap = new HashMap>();
  
  /**
   * Generates the random sparse vectors that the benchmarks operate on.
   *
   * <p>Each of the {@code numVectors} vectors receives exactly {@code sparsity} distinct, randomly
   * chosen indices in {@code [0, cardinality)}, each paired with a Gaussian random value. The drawn
   * indices and values are additionally kept as plain arrays, in the order they were drawn.
   *
   * @param cardinality dimension of every generated vector
   * @param sparsity number of entries set in every generated vector; must lie in
   *        {@code [0, cardinality]}
   * @param numVectors number of vectors to generate
   * @param loop number of timed repetitions per benchmark
   * @param opsPerUnit stored as given
   * @throws IllegalArgumentException if {@code sparsity} is negative or exceeds {@code cardinality}
   */
  public VectorBenchmarks(int cardinality, int sparsity, int numVectors, int loop, int opsPerUnit) {
    // The sampling loop below keeps drawing until it has found `sparsity` distinct indices out of
    // `cardinality` possible ones. If there are not enough distinct indices it would never
    // terminate, so reject such arguments up front.
    if (sparsity < 0 || sparsity > cardinality) {
      throw new IllegalArgumentException("sparsity must be between 0 and cardinality (" + cardinality
                                         + "), but was " + sparsity);
    }
    Random r = RandomUtils.getRandom();
    this.cardinality = cardinality;
    this.sparsity = sparsity;
    this.numVectors = numVectors;
    this.loop = loop;
    this.opsPerUnit = opsPerUnit;
    for (int i = 0; i < numVectors; i++) {
      // NOTE(review): the second argument is presumably the initial capacity - confirm against
      // SequentialAccessSparseVector.
      Vector v = new SequentialAccessSparseVector(cardinality, sparsity);
      // Marks the indices already used by this vector so that every index is drawn at most once.
      BitSet featureSpace = new BitSet(cardinality);
      int[] indexes = new int[sparsity];
      double[] values = new double[sparsity];
      int j = 0;
      while (j < sparsity) {
        // The value is drawn before the index, even for rejected indices, so the random sequence
        // is consumed in the same order as before.
        double value = r.nextGaussian();
        int index = r.nextInt(cardinality);
        if (!featureSpace.get(index)) {
          featureSpace.set(index);
          indexes[j] = index;
          values[j++] = value;
          v.set(index, value);
        }
      }
      randomVectorIndices.add(indexes);
      randomVectorValues.add(values);
      randomVectors.add(v);
    }
    // One row per benchmarked implementation; the rows are filled later by createBenchmark().
    vectors = new Vector[3][numVectors];
    
  }
  
  /**
   * Reports the statistics of one benchmark/implementation pair together with a free-form
   * {@code content} string, using a throughput multiplier of 1.
   */
  private void printStats(TimingStatistics stats, String benchmarkName, String implName, String content) {
    final int multiplier = 1;
    printStats(stats, benchmarkName, implName, content, multiplier);
  }
  
  /**
   * Reports the statistics of one benchmark/implementation pair with an empty {@code content}
   * string and a throughput multiplier of 1.
   */
  private void printStats(TimingStatistics stats, String benchmarkName, String implName) {
    final String content = "";
    final int multiplier = 1;
    printStats(stats, benchmarkName, implName, content, multiplier);
  }
  
  /**
   * Logs the timing results of one benchmark run and records them for the summary.
   *
   * <p>The statistics text, a calls-per-second figure and a data-rate figure are written to the log.
   * The same information, split on tabs, is stored in {@code statsMap} under
   * {@code benchmarkName} at the position assigned to {@code implName} in {@code implType}
   * (a new position is assigned the first time an implementation name is seen).
   *
   * @param stats timings collected for the benchmark
   * @param benchmarkName name of the benchmark; key of the stored entry
   * @param implName name of the benchmarked implementation
   * @param content free-form text included in the log message only
   * @param multiplier factor applied to the data-rate figure
   */
  private void printStats(TimingStatistics stats,
                          String benchmarkName,
                          String implName,
                          String content,
                          int multiplier) {
    // Count in long: multiplying these ints directly overflows for large runs and silently
    // corrupts the reported figures. For products that fit in an int the result is unchanged.
    long calls = (long) loop * numVectors;
    long elements = multiplier * calls * sparsity;
    // NOTE(review): 12 looks like bytes per element (4-byte int index + 8-byte double value), and
    // the 1e9 factor below suggests getSumTime() is in nanoseconds - confirm against TimingStatistics.
    float speed = elements * 1000.0f * 12 / stats.getSumTime();
    float opsPerSec = calls * 1000000000.0f / stats.getSumTime();
    String statsText = stats.toString();
    log.info("{} {} \n{} {} \nSpeed: {} UnitsProcessed/sec {} MBytes/sec                                   ",
      new Object[] {benchmarkName, implName, content, statsText, opsPerSec, speed});
    // One tab-separated record per run; it is split into cells again when stored below.
    String info = statsText.replaceAll("\n", "\t") + "\tSpeed = " + opsPerSec + " /sec\tRate = "
                  + speed + " MB/s";
    if (!implType.containsKey(implName)) {
      implType.put(implName, implType.size());
    }
    int implId = implType.get(implName);
    List<String[]> implStats = statsMap.get(benchmarkName);
    if (implStats == null) {
      implStats = new ArrayList<String[]>();
      statsMap.put(benchmarkName, implStats);
    }
    // Pad with empty rows so that the slot for this implementation exists.
    while (implStats.size() <= implId) {
      implStats.add(new String[] {});
    }
    implStats.set(implId, info.split("\t"));
  }
  
  /**
   * Times copy-construction of each vector implementation from the pre-generated random vectors.
   *
   * <p>For every implementation the full set of source vectors is copied {@code loop} times; only
   * the constructor call itself sits between the start and end of each timed call. The copies from
   * the last pass remain in {@code vectors} (row 0 dense, row 1 random-access sparse, row 2
   * sequential-access sparse), and the timings are reported under the name "Create (copy)".
   */
  public void createBenchmark() {
    final String benchmarkName = "Create (copy)";

    TimingStatistics denseStats = new TimingStatistics();
    for (int pass = 0; pass < loop; pass++) {
      for (int n = 0; n < numVectors; n++) {
        TimingStatistics.Call timer = denseStats.newCall();
        vectors[0][n] = new DenseVector(randomVectors.get(n));
        timer.end();
      }
    }
    printStats(denseStats, benchmarkName, "DenseVector");

    TimingStatistics randomAccessStats = new TimingStatistics();
    for (int pass = 0; pass < loop; pass++) {
      for (int n = 0; n < numVectors; n++) {
        TimingStatistics.Call timer = randomAccessStats.newCall();
        vectors[1][n] = new RandomAccessSparseVector(randomVectors.get(n));
        timer.end();
      }
    }
    printStats(randomAccessStats, benchmarkName, "RandSparseVector");

    TimingStatistics sequentialStats = new TimingStatistics();
    for (int pass = 0; pass < loop; pass++) {
      for (int n = 0; n < numVectors; n++) {
        TimingStatistics.Call timer = sequentialStats.newCall();
        vectors[2][n] = new SequentialAccessSparseVector(randomVectors.get(n));
        timer.end();
      }
    }
    printStats(sequentialStats, benchmarkName, "SeqSparseVector");
  }

  private void buildVectorIncrementally(TimingStatistics stats, int randomIndex, Vector v, boolean useSetQuick) {
    int[] indexes = randomVectorIndices.get(randomIndex);
    double[] values = randomVectorValues.get(randomIndex);
    List randomOrder = new ArrayList();
    for(int i=0; i e : implType.entrySet()) {
        if (e.getValue() == i) {
          sb.append(StringUtils.rightPad(e.getKey(), pad).substring(0, pad));
          break;
        }
      }
    }
    sb.append('\n');
    List keys = new ArrayList(statsMap.keySet());
    Collections.sort(keys);
    for (String benchmarkName : keys) {
      List implTokenizedStats = statsMap.get(benchmarkName);
      int maxStats = 0;
      for (String[] stat : implTokenizedStats) {
        maxStats = Math.max(maxStats, stat.length);
      }
      
      for (int i = 0; i < maxStats; i++) {
        boolean printedName = false;
        for (String[] stats : implTokenizedStats) {
          if (i == 0 && !printedName) {
            sb.append(StringUtils.rightPad(benchmarkName, pad));
            printedName = true;
          } else if (!printedName) {
            printedName = true;
            sb.append(StringUtils.rightPad("", pad));
          }
          if (stats.length > i) {
            sb.append(StringUtils.rightPad(stats[i], pad));
          } else {
            sb.append(StringUtils.rightPad("", pad));
          }

        }
        sb.append('\n');
      }
      sb.append('\n');
    }
    return sb.toString();
  }
  
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy