All downloads are free. Search and download functionality uses the official Maven repository.

org.apache.mahout.math.VectorWritable Maven / Gradle / Ivy

The newest version!
/**
 * Licensed to the Apache Software Foundation (ASF) under one or more contributor license
 * agreements. See the NOTICE file distributed with this work for additional information regarding
 * copyright ownership. The ASF licenses this file to You under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance with the License. You may obtain a
 * copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software distributed under the License
 * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
 * or implied. See the License for the specific language governing permissions and limitations under
 * the License.
 */

package org.apache.mahout.math;

import java.io.DataInput;
import java.io.DataOutput;
import java.io.IOException;
import java.util.Iterator;

import org.apache.hadoop.conf.Configured;
import org.apache.hadoop.io.Writable;
import org.apache.mahout.math.Vector.Element;

import com.google.common.base.Preconditions;

/**
 * Hadoop {@link Writable} wrapper that serializes a {@link Vector}.
 *
 * <p>Wire format, as produced by {@link #writeVector(DataOutput, Vector, boolean)} and consumed by
 * {@link #readFields(DataInput)}:
 * <ol>
 *   <li>one byte of {@code FLAG_*} bits;</li>
 *   <li>the vector size as an unsigned varint;</li>
 *   <li>the contents: for a dense vector, {@code size} values; otherwise an unsigned varint count
 *       followed by that many (index, value) pairs, where indices are unsigned varints and are
 *       delta-coded against the previous index when {@link #FLAG_SEQUENTIAL} is set;</li>
 *   <li>if {@link #FLAG_NAMED} is set, the name in {@link DataOutput#writeUTF(String)} format.</li>
 * </ol>
 * Values are written as floats when {@link #FLAG_LAX_PRECISION} is set and as doubles otherwise.
 */
public final class VectorWritable extends Configured implements Writable {

  /** Flag bit: contents are written as {@code size} consecutive values. */
  public static final int FLAG_DENSE = 0x01;
  /** Flag bit: sparse indices are delta-coded against the previously written index. */
  public static final int FLAG_SEQUENTIAL = 0x02;
  /** Flag bit: a UTF-encoded name follows the contents. */
  public static final int FLAG_NAMED = 0x04;
  /** Flag bit: values are encoded as floats rather than doubles. */
  public static final int FLAG_LAX_PRECISION = 0x08;
  /** Number of flag bits defined above; any higher bit set in a flags byte is rejected on read. */
  public static final int NUM_FLAGS = 4;

  private Vector vector;
  private boolean writesLaxPrecision;

  /** Creates an empty writable; the vector stays {@code null} until set or read. */
  public VectorWritable() {}

  /**
   * Creates an empty writable with the given precision setting.
   *
   * @param writesLaxPrecision see {@link #isWritesLaxPrecision()}
   */
  public VectorWritable(boolean writesLaxPrecision) {
    setWritesLaxPrecision(writesLaxPrecision);
  }

  /**
   * Creates a writable wrapping {@code vector}; full precision is used when writing.
   *
   * @param vector the vector to write
   */
  public VectorWritable(Vector vector) {
    this.vector = vector;
  }

  /**
   * Creates a writable wrapping {@code vector} with the given precision setting.
   *
   * @param vector the vector to write
   * @param writesLaxPrecision see {@link #isWritesLaxPrecision()}
   */
  public VectorWritable(Vector vector, boolean writesLaxPrecision) {
    this(vector);
    setWritesLaxPrecision(writesLaxPrecision);
  }

  /**
   * @return {@link org.apache.mahout.math.Vector} that this is to write, or has
   *  just read
   */
  public Vector get() {
    return vector;
  }

  /**
   * Replaces the wrapped vector.
   *
   * @param vector the vector to write
   */
  public void set(Vector vector) {
    this.vector = vector;
  }

  /**
   * @return true if this is allowed to encode {@link org.apache.mahout.math.Vector}
   *  values using fewer bytes, possibly losing precision. In particular this means
   *  that floating point values will be encoded as floats, not doubles.
   */
  public boolean isWritesLaxPrecision() {
    return writesLaxPrecision;
  }

  /**
   * @param writesLaxPrecision see {@link #isWritesLaxPrecision()}
   */
  public void setWritesLaxPrecision(boolean writesLaxPrecision) {
    this.writesLaxPrecision = writesLaxPrecision;
  }

  /** Writes the wrapped vector using this instance's precision setting. */
  @Override
  public void write(DataOutput out) throws IOException {
    writeVector(out, this.vector, this.writesLaxPrecision);
  }

  /** Reads the flags byte and size header, then the vector contents, replacing the wrapped vector. */
  @Override
  public void readFields(DataInput in) throws IOException {
    int flags = in.readByte();
    int size = Varint.readUnsignedVarInt(in);
    readFields(in, (byte) flags, size);
  }

  /**
   * Reads the vector contents that follow an already-consumed flags/size header.
   *
   * @param in source positioned just after the header
   * @param flags the {@code FLAG_*} bits from the header
   * @param size the vector cardinality from the header
   * @throws IllegalArgumentException if {@code flags} has bits set beyond the {@link #NUM_FLAGS} known ones
   */
  private void readFields(DataInput in, byte flags, int size) throws IOException {

    // Guava's Preconditions message templates only substitute %s placeholders.
    Preconditions.checkArgument(flags >> NUM_FLAGS == 0, "Unknown flags set: %s", Integer.toString(flags, 2));
    boolean dense = (flags & FLAG_DENSE) != 0;
    boolean sequential = (flags & FLAG_SEQUENTIAL) != 0;
    boolean named = (flags & FLAG_NAMED) != 0;
    boolean laxPrecision = (flags & FLAG_LAX_PRECISION) != 0;

    Vector v;
    if (dense) {
      double[] values = new double[size];
      for (int i = 0; i < size; i++) {
        values[i] = readValue(in, laxPrecision);
      }
      v = new DenseVector(values);
    } else {
      int numNonDefaultElements = Varint.readUnsignedVarInt(in);
      v = sequential
          ? new SequentialAccessSparseVector(size, numNonDefaultElements)
          : new RandomAccessSparseVector(size, numNonDefaultElements);
      if (sequential) {
        // Indices were delta-coded on write; rebuild absolute positions.
        int lastIndex = 0;
        for (int i = 0; i < numNonDefaultElements; i++) {
          int delta = Varint.readUnsignedVarInt(in);
          int index = lastIndex + delta;
          lastIndex = index;
          v.setQuick(index, readValue(in, laxPrecision));
        }
      } else {
        for (int i = 0; i < numNonDefaultElements; i++) {
          int index = Varint.readUnsignedVarInt(in);
          v.setQuick(index, readValue(in, laxPrecision));
        }
      }
    }
    if (named) {
      String name = in.readUTF();
      v = new NamedVector(v, name);
    }
    vector = v;
  }

  /** Reads one element value: a float when {@code laxPrecision} is set, otherwise a double. */
  private static double readValue(DataInput in, boolean laxPrecision) throws IOException {
    return laxPrecision ? in.readFloat() : in.readDouble();
  }

  /** Writes one element value: as a float when {@code laxPrecision} is set, otherwise as a double. */
  private static void writeValue(DataOutput out, double value, boolean laxPrecision) throws IOException {
    if (laxPrecision) {
      out.writeFloat((float) value);
    } else {
      out.writeDouble(value);
    }
  }

  /** Write the vector to the output */
  public static void writeVector(DataOutput out, Vector vector) throws IOException {
    writeVector(out, vector, false);
  }

  /**
   * Computes the header flags byte for {@code vector}.
   *
   * @param vector the vector about to be written
   * @param laxPrecision whether values will be encoded as floats
   * @return the bitwise OR of the applicable {@code FLAG_*} constants
   */
  public static byte flags(Vector vector, boolean laxPrecision) {
    boolean dense = vector.isDense();
    boolean sequential = vector.isSequentialAccess();
    boolean named = vector instanceof NamedVector;

    return (byte) ((dense ? FLAG_DENSE : 0)
            | (sequential ? FLAG_SEQUENTIAL : 0)
            | (named ? FLAG_NAMED : 0)
            | (laxPrecision ? FLAG_LAX_PRECISION : 0));
  }

  /** Write out type information and size of the vector */
  public static void writeVectorFlagsAndSize(DataOutput out, byte flags, int size) throws IOException {
    out.writeByte(flags);
    Varint.writeUnsignedVarInt(size, out);
  }

  /**
   * Writes the header (flags and size) followed by the contents of {@code vector}.
   *
   * @param out destination
   * @param vector the vector to write
   * @param laxPrecision whether values are encoded as floats instead of doubles
   */
  public static void writeVector(DataOutput out, Vector vector, boolean laxPrecision) throws IOException {

    byte flags = flags(vector, laxPrecision);

    writeVectorFlagsAndSize(out, flags, vector.size());
    writeVectorContents(out, vector, flags);
  }

  /** Write out contents of the vector */
  public static void writeVectorContents(DataOutput out, Vector vector, byte flags) throws IOException {

    boolean dense = (flags & FLAG_DENSE) != 0;
    boolean sequential = (flags & FLAG_SEQUENTIAL) != 0;
    boolean named = (flags & FLAG_NAMED) != 0;
    boolean laxPrecision = (flags & FLAG_LAX_PRECISION) != 0;

    if (dense) {
      for (Element element : vector.all()) {
        writeValue(out, element.get(), laxPrecision);
      }
    } else {
      // NOTE(review): the count written here must match the number of pairs emitted below, which
      // presumes getNumNonZeroElements() excludes the explicit zeros that the loops skip - confirm.
      Varint.writeUnsignedVarInt(vector.getNumNonZeroElements(), out);
      Iterator<Element> iter = vector.nonZeroes().iterator();
      if (sequential) {
        int lastIndex = 0;
        while (iter.hasNext()) {
          Element element = iter.next();
          double value = element.get();
          if (value == 0) {
            continue;
          }
          int thisIndex = element.index();
          // Delta-code indices:
          Varint.writeUnsignedVarInt(thisIndex - lastIndex, out);
          lastIndex = thisIndex;
          writeValue(out, value, laxPrecision);
        }
      } else {
        while (iter.hasNext()) {
          Element element = iter.next();
          double value = element.get();
          if (value == 0) {
            // TODO(robinanil): Fix the damn iterator for the zero element.
            continue;
          }
          Varint.writeUnsignedVarInt(element.index(), out);
          writeValue(out, value, laxPrecision);
        }
      }
    }
    if (named) {
      String name = ((NamedVector) vector).getName();
      // writeUTF cannot take null, so a missing name is written as the empty string.
      out.writeUTF(name == null ? "" : name);
    }
  }

  /**
   * Reads a complete vector (header and contents) from {@code in}.
   *
   * @return the vector that was read
   */
  public static Vector readVector(DataInput in) throws IOException {
    VectorWritable v = new VectorWritable();
    v.readFields(in);
    return v.get();
  }

  /**
   * Reads vector contents whose header has already been consumed by the caller.
   *
   * @param in source positioned just after the header
   * @param vectorFlags the {@code FLAG_*} bits from the header
   * @param size the vector cardinality from the header
   * @return the vector that was read
   */
  public static Vector readVector(DataInput in, byte vectorFlags, int size) throws IOException {
    VectorWritable v = new VectorWritable();
    v.readFields(in, vectorFlags, size);
    return v.get();
  }

  /**
   * Merges the given writables and wraps the result; see {@link #mergeToVector(Iterator)}.
   *
   * @param vectors writables to merge; must yield at least one element
   * @return a new writable wrapping the merged vector
   */
  public static VectorWritable merge(Iterator<VectorWritable> vectors) {
    return new VectorWritable(mergeToVector(vectors));
  }

  /**
   * Merges the given writables into the vector held by the first one. That vector is modified in
   * place: every non-zero element of each later vector is copied into it, overwriting whatever was
   * at that index. {@code null} writables after the first are skipped.
   *
   * @param vectors writables to merge; must yield at least one element, and the first must not be
   *  {@code null}
   * @return the first writable's vector, after merging
   */
  public static Vector mergeToVector(Iterator<VectorWritable> vectors) {
    Vector accumulator = vectors.next().get();
    while (vectors.hasNext()) {
      VectorWritable v = vectors.next();
      if (v != null) {
        for (Element nonZeroElement : v.get().nonZeroes()) {
          accumulator.setQuick(nonZeroElement.index(), nonZeroElement.get());
        }
      }
    }
    return accumulator;
  }

  /** Two writables are equal when their wrapped vectors are equal (or both are {@code null}). */
  @Override
  public boolean equals(Object o) {
    if (!(o instanceof VectorWritable)) {
      return false;
    }
    Vector other = ((VectorWritable) o).get();
    return vector == null ? other == null : vector.equals(other);
  }

  /** Hash of the wrapped vector, or 0 when no vector has been set. */
  @Override
  public int hashCode() {
    return vector == null ? 0 : vector.hashCode();
  }

  /** String form of the wrapped vector, or {@code "null"} when no vector has been set. */
  @Override
  public String toString() {
    return String.valueOf(vector);
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy