All downloads are free. The search and download functionality uses the official Maven repository.

org.apache.mahout.math.hadoop.similarity.cooccurrence.Vectors Maven / Gradle / Ivy

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.mahout.math.hadoop.similarity.cooccurrence;

import java.io.DataInput;
import java.io.IOException;
import java.util.Iterator;

import com.google.common.base.Preconditions;
import com.google.common.io.Closeables;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FSDataInputStream;
import org.apache.hadoop.fs.FSDataOutputStream;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.mahout.common.iterator.FixedSizeSamplingIterator;
import org.apache.mahout.math.RandomAccessSparseVector;
import org.apache.mahout.math.Varint;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.Vector.Element;
import org.apache.mahout.math.VectorWritable;
import org.apache.mahout.math.function.Functions;
import org.apache.mahout.math.map.OpenIntIntHashMap;

public final class Vectors {

  /** Private so the class cannot be instantiated from outside; all of its helpers are static. */
  private Vectors() {
  }

  /**
   * Returns {@code original} unchanged when it has at most {@code sampleSize} non-default elements;
   * otherwise copies the elements produced by a {@link FixedSizeSamplingIterator} over the non-zero
   * entries into a new sparse vector.
   *
   * @param original   the vector to sample from; this method never modifies it
   * @param sampleSize threshold for sampling, also passed to the sampling iterator as its size
   * @return {@code original} itself, or a new {@link RandomAccessSparseVector} with the same cardinality
   */
  public static Vector maybeSample(Vector original, int sampleSize) {
    if (original.getNumNondefaultElements() <= sampleSize) {
      return original;
    }
    Vector sample = new RandomAccessSparseVector(original.size(), sampleSize);
    // The element type has to be declared: with a raw Iterator, next() returns Object and the
    // assignment to Element below does not compile.
    Iterator<Element> sampledElements =
        new FixedSizeSamplingIterator<>(sampleSize, original.nonZeroes().iterator());
    while (sampledElements.hasNext()) {
      Element elem = sampledElements.next();
      sample.setQuick(elem.index(), elem.get());
    }
    return sample;
  }

  /**
   * Reduces a vector to the entries retained by a {@link TopElementsQueue} of capacity {@code k}.
   * <p>
   * Each non-zero element is compared with the queue's current {@code top()}; when its value is
   * strictly greater, the top slot is overwritten with the element's index and value and the queue
   * is told to re-order itself. The queue's resulting elements are then copied into a new vector.
   *
   * @param k        the number of entries to keep
   * @param original the vector to reduce; this method never modifies it
   * @return {@code original} itself when it has at most {@code k} non-default elements, otherwise a
   *         new {@link RandomAccessSparseVector} with the same cardinality
   */
  public static Vector topKElements(int k, Vector original) {
    boolean nothingToDrop = original.getNumNondefaultElements() <= k;
    if (nothingToDrop) {
      return original;
    }

    TopElementsQueue queue = new TopElementsQueue(k);
    for (Element candidate : original.nonZeroes()) {
      MutableElement slot = queue.top();
      double value = candidate.get();
      if (value > slot.get()) {
        // Reuse the queue's own slot instead of allocating a new element per candidate.
        slot.setIndex(candidate.index());
        slot.set(value);
        queue.updateTop();
      }
    }

    Vector result = new RandomAccessSparseVector(original.size(), k);
    for (Element kept : queue.getTopElements()) {
      result.setQuick(kept.index(), kept.get());
    }
    return result;
  }

  /**
   * Merges partial vectors by copying the non-zero entries of every later vector into the first one.
   * <p>
   * The vector wrapped by the first element is modified in place and returned. Entries are written
   * with {@code setQuick}, so when several partial vectors contain the same index the value from the
   * one iterated last is kept. {@code null} elements after the first are skipped; the first element
   * itself is dereferenced without a null check.
   *
   * @param partialVectors the partial vectors to merge; must yield at least one element
   * @return the first element's vector, updated with the entries of all following vectors
   * @throws java.util.NoSuchElementException if {@code partialVectors} yields no elements (standard
   *         {@link Iterator#next()} contract)
   */
  public static Vector merge(Iterable<VectorWritable> partialVectors) {
    // The type argument is required: with raw Iterable/Iterator, next() returns Object and
    // neither next().get() nor the assignment to VectorWritable compiles.
    Iterator<VectorWritable> vectors = partialVectors.iterator();
    Vector accumulator = vectors.next().get();
    while (vectors.hasNext()) {
      VectorWritable v = vectors.next();
      if (v != null) {
        for (Element nonZeroElement : v.get().nonZeroes()) {
          accumulator.setQuick(nonZeroElement.index(), nonZeroElement.get());
        }
      }
    }
    return accumulator;
  }

  /**
   * Adds up all vectors produced by the iterator.
   * <p>
   * The vector wrapped by the first element serves as the accumulator: every following vector is
   * combined into it with {@link Functions#PLUS}, and that same instance is returned.
   *
   * @param vectors the vectors to add; must yield at least one element
   * @return the first element's vector after all following vectors have been added to it
   * @throws java.util.NoSuchElementException if {@code vectors} yields no elements (standard
   *         {@link Iterator#next()} contract)
   */
  public static Vector sum(Iterator<VectorWritable> vectors) {
    // The type argument is required: with a raw Iterator, next() returns Object and .get() does not compile.
    Vector total = vectors.next().get();
    while (vectors.hasNext()) {
      total.assign(vectors.next().get(), Functions.PLUS);
    }
    return total;
  }

  /**
   * Standalone {@link Vector.Element} that stores its own index and value in plain fields.
   * The index is fixed at construction; the value can be replaced through {@link #set(double)}.
   */
  static class TemporaryElement implements Vector.Element {

    private final int position;
    private double current;

    /** Creates an element holding a copy of the index and value of {@code toClone}. */
    TemporaryElement(Vector.Element toClone) {
      this(toClone.index(), toClone.get());
    }

    /** Creates an element with the given index and initial value. */
    TemporaryElement(int index, double value) {
      position = index;
      current = value;
    }

    @Override
    public int index() {
      return position;
    }

    @Override
    public double get() {
      return current;
    }

    @Override
    public void set(double value) {
      current = value;
    }
  }

  /**
   * Copies the non-zero elements of the wrapped vector into an array of standalone elements.
   * <p>
   * The array length is the vector's non-default element count, and it is filled from index 0 in
   * the iteration order of {@code nonZeroes()}.
   * NOTE(review): if the two counts can differ (e.g. explicitly stored zeros), trailing slots would
   * stay {@code null} — confirm against the {@link Vector} implementations in use.
   *
   * @param vectorWritable the writable whose vector is copied
   * @return a new array of {@link TemporaryElement} copies
   */
  public static Vector.Element[] toArray(VectorWritable vectorWritable) {
    Vector vector = vectorWritable.get();
    Vector.Element[] copies = new Vector.Element[vector.getNumNondefaultElements()];
    int next = 0;
    for (Element source : vector.nonZeroes()) {
      copies[next] = new TemporaryElement(source);
      next++;
    }
    return copies;
  }

  /**
   * Writes {@code vector} to {@code path} with lax precision disabled.
   *
   * @param vector the vector to serialize
   * @param path   the destination file
   * @param conf   the configuration used to resolve the file system
   * @throws IOException if the underlying write fails
   */
  public static void write(Vector vector, Path path, Configuration conf) throws IOException {
    boolean laxPrecision = false;
    write(vector, path, conf, laxPrecision);
  }

  /**
   * Serializes {@code vector} as a {@link VectorWritable} into the file at {@code path}.
   *
   * @param vector       the vector to serialize
   * @param path         the destination file, created via {@code FileSystem.create(path)}
   * @param conf         the configuration used to resolve the file system for {@code path}
   * @param laxPrecision forwarded to {@link VectorWritable#setWritesLaxPrecision(boolean)}
   * @throws IOException if creating, writing, or closing the output stream fails
   */
  public static void write(Vector vector, Path path, Configuration conf, boolean laxPrecision) throws IOException {
    FileSystem fs = FileSystem.get(path.toUri(), conf);
    // try-with-resources still reports a failing close() after a successful write, but if the write
    // itself fails, that exception is the one propagated and a close() failure is attached as
    // suppressed rather than replacing it.
    try (FSDataOutputStream out = fs.create(path)) {
      VectorWritable vectorWritable = new VectorWritable(vector);
      vectorWritable.setWritesLaxPrecision(laxPrecision);
      vectorWritable.write(out);
    }
  }

  /**
   * Opens the file at {@code path} and loads the serialized vector it contains as an int-to-int map.
   *
   * @param path the file to read
   * @param conf the configuration used to resolve the file system for {@code path}
   * @return a map from element index to the element value cast to {@code int}
   * @throws IOException if opening or reading the file fails; failures while closing are swallowed
   */
  public static OpenIntIntHashMap readAsIntMap(Path path, Configuration conf) throws IOException {
    FSDataInputStream stream = FileSystem.get(path.toUri(), conf).open(path);
    OpenIntIntHashMap result;
    try {
      result = readAsIntMap(stream);
    } finally {
      // swallowIOException = true: a close failure on a read-only stream is not propagated
      Closeables.close(stream, true);
    }
    return result;
  }

  /**
   * Ugly optimization for loading sparse vectors containing ints only: decodes the serialized form
   * directly into an {@link OpenIntIntHashMap} instead of materializing a {@link Vector} first.
   * <p>
   * Only the non-dense, non-sequential encoding is accepted. Values are read as {@code float} or
   * {@code double} depending on the lax-precision flag and then cast to {@code int}.
   *
   * @param in the input positioned at the start of a serialized vector
   * @return a map from element index to the element value cast to {@code int}
   * @throws IOException              if reading from {@code in} fails
   * @throws IllegalArgumentException if flag bits at or above {@code VectorWritable.NUM_FLAGS} are set
   * @throws IllegalStateException    if the flags mark the vector as dense or sequential
   */
  private static OpenIntIntHashMap readAsIntMap(DataInput in) throws IOException {
    int flags = in.readByte();
    // Guava's Preconditions substitutes only %s placeholders; with %d the template was emitted
    // verbatim and the argument merely appended in brackets.
    Preconditions.checkArgument(flags >> VectorWritable.NUM_FLAGS == 0,
                                "Unknown flags set: %s", Integer.toString(flags, 2));
    boolean dense = (flags & VectorWritable.FLAG_DENSE) != 0;
    boolean sequential = (flags & VectorWritable.FLAG_SEQUENTIAL) != 0;
    boolean laxPrecision = (flags & VectorWritable.FLAG_LAX_PRECISION) != 0;
    Preconditions.checkState(!dense && !sequential, "Only for reading sparse vectors!");

    // Read and discard the next varint (presumably the vector's cardinality, which a map does not
    // need); the call is still required to advance the stream.
    Varint.readUnsignedVarInt(in);

    OpenIntIntHashMap values = new OpenIntIntHashMap();
    int numNonDefaultElements = Varint.readUnsignedVarInt(in);
    for (int i = 0; i < numNonDefaultElements; i++) {
      int index = Varint.readUnsignedVarInt(in);
      double value = laxPrecision ? in.readFloat() : in.readDouble();
      values.put(index, (int) value);
    }
    return values;
  }

  /**
   * Opens the file at {@code path} and deserializes the vector it contains.
   *
   * @param path the file to read
   * @param conf the configuration used to resolve the file system for {@code path}
   * @return the vector produced by {@link VectorWritable#readVector}
   * @throws IOException if opening or reading the file fails; failures while closing are swallowed
   */
  public static Vector read(Path path, Configuration conf) throws IOException {
    FSDataInputStream stream = FileSystem.get(path.toUri(), conf).open(path);
    Vector result;
    try {
      result = VectorWritable.readVector(stream);
    } finally {
      // swallowIOException = true: a close failure on a read-only stream is not propagated
      Closeables.close(stream, true);
    }
    return result;
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy