All downloads are free. The search and download features use the official Maven repository.

org.apache.mahout.classifier.df.mapreduce.inmem.InMemInputFormat Maven / Gradle / Ivy

/**
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.mahout.classifier.df.mapreduce.inmem;

import java.io.DataInput;
import java.io.DataOutput;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Locale;
import java.util.Random;

import com.google.common.base.Preconditions;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.NullWritable;
import org.apache.hadoop.io.Writable;
import org.apache.hadoop.mapreduce.InputFormat;
import org.apache.hadoop.mapreduce.InputSplit;
import org.apache.hadoop.mapreduce.JobContext;
import org.apache.hadoop.mapreduce.RecordReader;
import org.apache.hadoop.mapreduce.TaskAttemptContext;
import org.apache.mahout.classifier.df.mapreduce.Builder;
import org.apache.mahout.common.RandomUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Custom InputFormat that generates InputSplits given the desired number of trees.
* each input split contains a subset of the trees.
* The number of splits is equal to the number of requested splits */ public class InMemInputFormat extends InputFormat { private static final Logger log = LoggerFactory.getLogger(InMemInputSplit.class); private Random rng; private Long seed; private boolean isSingleSeed; /** * Used for DEBUG purposes only. if true and a seed is available, all the mappers use the same seed, thus * all the mapper should take the same time to build their trees. */ private static boolean isSingleSeed(Configuration conf) { return conf.getBoolean("debug.mahout.rf.single.seed", false); } @Override public RecordReader createRecordReader(InputSplit split, TaskAttemptContext context) throws IOException, InterruptedException { Preconditions.checkArgument(split instanceof InMemInputSplit); return new InMemRecordReader((InMemInputSplit) split); } @Override public List getSplits(JobContext context) throws IOException, InterruptedException { Configuration conf = context.getConfiguration(); int numSplits = conf.getInt("mapred.map.tasks", -1); return getSplits(conf, numSplits); } public List getSplits(Configuration conf, int numSplits) { int nbTrees = Builder.getNbTrees(conf); int splitSize = nbTrees / numSplits; seed = Builder.getRandomSeed(conf); isSingleSeed = isSingleSeed(conf); if (rng != null && seed != null) { log.warn("getSplits() was called more than once and the 'seed' is set, " + "this can lead to no-repeatable behavior"); } rng = seed == null || isSingleSeed ? 
null : RandomUtils.getRandom(seed); int id = 0; List splits = new ArrayList<>(numSplits); for (int index = 0; index < numSplits - 1; index++) { splits.add(new InMemInputSplit(id, splitSize, nextSeed())); id += splitSize; } // take care of the remainder splits.add(new InMemInputSplit(id, nbTrees - id, nextSeed())); return splits; } /** * @return the seed for the next InputSplit */ private Long nextSeed() { if (seed == null) { return null; } else if (isSingleSeed) { return seed; } else { return rng.nextLong(); } } public static class InMemRecordReader extends RecordReader { private final InMemInputSplit split; private int pos; private IntWritable key; private NullWritable value; public InMemRecordReader(InMemInputSplit split) { this.split = split; } @Override public float getProgress() throws IOException { return pos == 0 ? 0.0f : (float) (pos - 1) / split.nbTrees; } @Override public IntWritable getCurrentKey() throws IOException, InterruptedException { return key; } @Override public NullWritable getCurrentValue() throws IOException, InterruptedException { return value; } @Override public void initialize(InputSplit arg0, TaskAttemptContext arg1) throws IOException, InterruptedException { key = new IntWritable(); value = NullWritable.get(); } @Override public boolean nextKeyValue() throws IOException, InterruptedException { if (pos < split.nbTrees) { key.set(split.firstId + pos); pos++; return true; } else { return false; } } @Override public void close() throws IOException { } } /** * Custom InputSplit that indicates how many trees are built by each mapper */ public static class InMemInputSplit extends InputSplit implements Writable { private static final String[] NO_LOCATIONS = new String[0]; /** Id of the first tree of this split */ private int firstId; private int nbTrees; private Long seed; public InMemInputSplit() { } public InMemInputSplit(int firstId, int nbTrees, Long seed) { this.firstId = firstId; this.nbTrees = nbTrees; this.seed = seed; } /** * @return 
the Id of the first tree of this split */ public int getFirstId() { return firstId; } /** * @return the number of trees */ public int getNbTrees() { return nbTrees; } /** * @return the random seed or null if no seed is available */ public Long getSeed() { return seed; } @Override public long getLength() throws IOException { return nbTrees; } @Override public String[] getLocations() throws IOException { return NO_LOCATIONS; } @Override public boolean equals(Object obj) { if (this == obj) { return true; } if (!(obj instanceof InMemInputSplit)) { return false; } InMemInputSplit split = (InMemInputSplit) obj; if (firstId != split.firstId || nbTrees != split.nbTrees) { return false; } if (seed == null) { return split.seed == null; } else { return seed.equals(split.seed); } } @Override public int hashCode() { return firstId + nbTrees + (seed == null ? 0 : seed.intValue()); } @Override public String toString() { return String.format(Locale.ENGLISH, "[firstId:%d, nbTrees:%d, seed:%d]", firstId, nbTrees, seed); } @Override public void readFields(DataInput in) throws IOException { firstId = in.readInt(); nbTrees = in.readInt(); boolean isSeed = in.readBoolean(); seed = isSeed ? in.readLong() : null; } @Override public void write(DataOutput out) throws IOException { out.writeInt(firstId); out.writeInt(nbTrees); out.writeBoolean(seed != null); if (seed != null) { out.writeLong(seed); } } public static InMemInputSplit read(DataInput in) throws IOException { InMemInputSplit split = new InMemInputSplit(); split.readFields(in); return split; } } }




© 2015 - 2024 Weber Informatics LLC | Privacy Policy