All downloads are free. Search and download functionality uses the official Maven repository.

org.apache.mahout.classifier.sgd.SimpleCsvExamples Maven / Gradle / Ivy

The newest version!
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.mahout.classifier.sgd;

import com.google.common.base.Joiner;
import com.google.common.base.Splitter;
import com.google.common.collect.Lists;
import com.google.common.io.Closeables;
import com.google.common.io.Files;
import org.apache.commons.io.Charsets;
import org.apache.mahout.common.RandomUtils;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.list.IntArrayList;
import org.apache.mahout.math.stats.OnlineSummarizer;
import org.apache.mahout.vectorizer.encoders.ConstantValueEncoder;
import org.apache.mahout.vectorizer.encoders.FeatureVectorEncoder;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.BufferedReader;
import java.io.Closeable;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStreamWriter;
import java.io.PrintWriter;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.List;
import java.util.Random;

/**
 * Shows how different encoding choices can make big speed differences.
 * 

* Run with command line options --generate 1000000 test.csv to generate a million data lines in * test.csv. *

* Run with command line options --parser test.csv to time how long it takes to parse and encode * those million data points *

* Run with command line options --fast test.csv to time how long it takes to parse and encode those * million data points using byte-level parsing and direct value encoding. *

* This doesn't demonstrate text encoding which is subject to somewhat different tricks. The basic * idea of caching hash locations and byte level parsing still very much applies to text, however. */ public final class SimpleCsvExamples { public static final char SEPARATOR_CHAR = '\t'; private static final int FIELDS = 100; private static final Logger log = LoggerFactory.getLogger(SimpleCsvExamples.class); private SimpleCsvExamples() {} public static void main(String[] args) throws IOException { FeatureVectorEncoder[] encoder = new FeatureVectorEncoder[FIELDS]; for (int i = 0; i < FIELDS; i++) { encoder[i] = new ConstantValueEncoder("v" + 1); } OnlineSummarizer[] s = new OnlineSummarizer[FIELDS]; for (int i = 0; i < FIELDS; i++) { s[i] = new OnlineSummarizer(); } long t0 = System.currentTimeMillis(); Vector v = new DenseVector(1000); if ("--generate".equals(args[0])) { try (PrintWriter out = new PrintWriter(new OutputStreamWriter(new FileOutputStream(new File(args[2])), Charsets.UTF_8))) { int n = Integer.parseInt(args[1]); for (int i = 0; i < n; i++) { Line x = Line.generate(); out.println(x); } } } else if ("--parse".equals(args[0])) { try (BufferedReader in = Files.newReader(new File(args[1]), Charsets.UTF_8)){ String line = in.readLine(); while (line != null) { v.assign(0); Line x = new Line(line); for (int i = 0; i < FIELDS; i++) { s[i].add(x.getDouble(i)); encoder[i].addToVector(x.get(i), v); } line = in.readLine(); } } String separator = ""; for (int i = 0; i < FIELDS; i++) { System.out.printf("%s%.3f", separator, s[i].getMean()); separator = ","; } } else if ("--fast".equals(args[0])) { try (FastLineReader in = new FastLineReader(new FileInputStream(args[1]))){ FastLine line = in.read(); while (line != null) { v.assign(0); for (int i = 0; i < FIELDS; i++) { double z = line.getDouble(i); s[i].add(z); encoder[i].addToVector((byte[]) null, z, v); } line = in.read(); } } String separator = ""; for (int i = 0; i < FIELDS; i++) { System.out.printf("%s%.3f", 
separator, s[i].getMean()); separator = ","; } } System.out.printf("\nElapsed time = %.3f%n", (System.currentTimeMillis() - t0) / 1000.0); } private static final class Line { private static final Splitter ON_TABS = Splitter.on(SEPARATOR_CHAR).trimResults(); public static final Joiner WITH_COMMAS = Joiner.on(SEPARATOR_CHAR); public static final Random RAND = RandomUtils.getRandom(); private final List data; private Line(CharSequence line) { data = Lists.newArrayList(ON_TABS.split(line)); } private Line() { data = new ArrayList<>(); } public double getDouble(int field) { return Double.parseDouble(data.get(field)); } /** * Generate a random line with 20 fields each with integer values. * * @return A new line with data. */ public static Line generate() { Line r = new Line(); for (int i = 0; i < FIELDS; i++) { double mean = ((i + 1) * 257) % 50 + 1; r.data.add(Integer.toString(randomValue(mean))); } return r; } /** * Returns a random exponentially distributed integer with a particular mean value. This is * just a way to create more small numbers than big numbers. 
* * @param mean mean of the distribution * @return random exponentially distributed integer with the specific mean */ private static int randomValue(double mean) { return (int) (-mean * Math.log1p(-RAND.nextDouble())); } @Override public String toString() { return WITH_COMMAS.join(data); } public String get(int field) { return data.get(field); } } private static final class FastLine { private final ByteBuffer base; private final IntArrayList start = new IntArrayList(); private final IntArrayList length = new IntArrayList(); private FastLine(ByteBuffer base) { this.base = base; } public static FastLine read(ByteBuffer buf) { FastLine r = new FastLine(buf); r.start.add(buf.position()); int offset = buf.position(); while (offset < buf.limit()) { int ch = buf.get(); offset = buf.position(); switch (ch) { case '\n': r.length.add(offset - r.start.get(r.length.size()) - 1); return r; case SEPARATOR_CHAR: r.length.add(offset - r.start.get(r.length.size()) - 1); r.start.add(offset); break; default: // nothing to do for now } } throw new IllegalArgumentException("Not enough bytes in buffer"); } public double getDouble(int field) { int offset = start.get(field); int size = length.get(field); switch (size) { case 1: return base.get(offset) - '0'; case 2: return (base.get(offset) - '0') * 10 + base.get(offset + 1) - '0'; default: double r = 0; for (int i = 0; i < size; i++) { r = 10 * r + base.get(offset + i) - '0'; } return r; } } } private static final class FastLineReader implements Closeable { private final InputStream in; private final ByteBuffer buf = ByteBuffer.allocate(100000); private FastLineReader(InputStream in) throws IOException { this.in = in; buf.limit(0); fillBuffer(); } public FastLine read() throws IOException { fillBuffer(); if (buf.remaining() > 0) { return FastLine.read(buf); } else { return null; } } private void fillBuffer() throws IOException { if (buf.remaining() < 10000) { buf.compact(); int n = in.read(buf.array(), buf.position(), buf.remaining()); 
if (n == -1) { buf.flip(); } else { buf.limit(buf.position() + n); buf.position(0); } } } @Override public void close() { try { Closeables.close(in, true); } catch (IOException e) { log.error(e.getMessage(), e); } } } }





© 2015 - 2025 Weber Informatics LLC | Privacy Policy