All downloads are free. The search and download functionality uses the official Maven repository.

edu.stanford.nlp.neural.VectorMap Maven / Gradle / Ivy

Go to download

Stanford CoreNLP provides a set of natural language analysis tools which can take raw English language text input and give the base forms of words, their parts of speech, whether they are names of companies, people, etc., normalize dates, times, and numeric quantities, mark up the structure of sentences in terms of phrases and word dependencies, and indicate which noun phrases refer to the same entities. It provides the foundational building blocks for higher level text understanding applications.

There is a newer version: 4.5.7
Show newest version
package edu.stanford.nlp.neural;

import edu.stanford.nlp.io.IOUtils;
import edu.stanford.nlp.math.ArrayMath;

import java.io.*;
import java.nio.charset.StandardCharsets;
import java.util.HashMap;
import java.util.Locale;
import java.util.Map;
import java.util.zip.GZIPOutputStream;

/**
 * A map from words to their vectors, together with a compact binary
 * serialization format for reading / writing those vectors.
 * This is used to read word2vec in hcoref, and is primarily here
 * for its efficient serialization / deserialization protocol, which
 * saves/loads the vectors as 16 bit floats.
 *
 * <p>The binary layout written by {@link #serialize(OutputStream)} is:
 * three 32 bit integers (maximum key length in bytes, vector dimension,
 * number of entries), followed by one record per entry consisting of the
 * key length (8, 16 or 32 bits, chosen from the maximum key length), the
 * key bytes (UTF-8), and one 16 bit half-precision float per vector
 * component.</p>
 *
 * @author Gabor Angeli
 */
public class VectorMap extends HashMap<String, float[]> {

  /**
   * The integer type (i.e., number of bits per integer) used to store key lengths.
   */
  private enum itype {
    INT8,
    INT16,
    INT32;

    /**
     * Get the minimum integer type that will fit this number.
     *
     * @param num The largest value that has to be representable.
     * @return The narrowest type whose signed range strictly exceeds {@code num}.
     */
    static itype getType(int num) {
      if (num < Byte.MAX_VALUE) {
        return INT8;
      }
      if (num < Short.MAX_VALUE) {
        return INT16;
      }
      return INT32;
    }

    /**
     * Read an integer of this type from the given input stream.
     *
     * @param in The stream to read from.
     * @return The (sign-extended) integer that was read.
     * @throws IOException Thrown if the stream could not be read.
     */
    public int read(DataInputStream in) throws IOException {
      switch (this) {
        case INT8:
          return in.readByte();
        case INT16:
          return in.readShort();
        case INT32:
          return in.readInt();
        default:
          throw new RuntimeException("Unknown itype: " + this);
      }
    }

    /**
     * Write an integer of this type to the given output stream.
     *
     * @param out The stream to write to.
     * @param value The value to write; only the low bits that fit this type are kept.
     * @throws IOException Thrown if the stream could not be written to.
     */
    public void write(DataOutputStream out, int value) throws IOException {
      switch (this) {
        case INT8:
          out.writeByte(value);
          break;
        case INT16:
          out.writeShort(value);
          break;
        case INT32:
          out.writeInt(value);
          break;
        default:
          throw new RuntimeException("Unknown itype: " + this);
      }
    }

  }

  /**
   * Create an empty word vector storage.
   */
  public VectorMap() {
    super(1024);
  }

  /**
   * Initialize word vectors from a given map.
   * @param vectors The word vectors as a simple map.
   */
  public VectorMap(Map<String, float[]> vectors) {
    super(vectors);
  }


  /**
   * Write the word vectors to a file. If the file name ends in {@code .gz},
   * the output is gzip-compressed.
   *
   * @param file The file to write to.
   * @throws IOException Thrown if the file could not be written to.
   */
  public void serialize(String file) throws IOException {
    try (OutputStream output = new BufferedOutputStream(new FileOutputStream(new File(file)))) {
      if (file.endsWith(".gz")) {
        // Closing the gzip stream writes the gzip trailer before the file is closed
        try (GZIPOutputStream gzip = new GZIPOutputStream(output)) {
          serialize(gzip);
        }
      } else {
        serialize(output);
      }
    }
  }

  /**
   * Write the word vectors to an output stream. The stream is not closed on finishing
   * the function.
   *
   * @param out The stream to write to.
   * @throws IOException Thrown if the stream could not be written to.
   * @throws IllegalStateException Thrown if the vectors do not all have the same
   *         dimension; the format stores a single dimension for the whole map, so
   *         such a map could not be read back.
   */
  public void serialize(OutputStream out) throws IOException {
    DataOutputStream dataOut = new DataOutputStream(out);

    // Collect the length statistics for the header (and validate before writing anything)
    int maxKeyLength = 0;
    int vectorLength = 0;
    boolean seenVector = false;
    for (Map.Entry<String, float[]> entry : this.entrySet()) {
      maxKeyLength = Math.max(entry.getKey().getBytes(StandardCharsets.UTF_8).length, maxKeyLength);
      int length = entry.getValue().length;
      if (seenVector && length != vectorLength) {
        throw new IllegalStateException("Cannot serialize vectors of differing dimensions: "
            + vectorLength + " vs. " + length + " (key: " + entry.getKey() + ")");
      }
      vectorLength = length;
      seenVector = true;
    }
    itype keyIntType = itype.getType(maxKeyLength);
    // Write the key length
    dataOut.writeInt(maxKeyLength);
    // Write the vector dim
    dataOut.writeInt(vectorLength);
    // Write the size of the dataset
    dataOut.writeInt(this.size());

    for (Map.Entry<String, float[]> entry : this.entrySet()) {
      // Write the length of the key, followed by the key itself
      byte[] key = entry.getKey().getBytes(StandardCharsets.UTF_8);
      keyIntType.write(dataOut, key.length);
      dataOut.write(key);
      // Write the vector, one half-precision float per component
      for (float v : entry.getValue()) {
        dataOut.writeShort(fromFloat(v));
      }
    }
  }


  /**
   * Read word vectors from a file or classpath or url.
   *
   * @param file The file to read from.
   * @return The vectors in the file.
   * @throws IOException Thrown if we could not read from the resource
   */
  public static VectorMap deserialize(String file) throws IOException {
    try (InputStream input = IOUtils.getInputStreamFromURLOrClasspathOrFileSystem(file)) {
      return deserialize(input);
    }
  }

  /**
   * Read word vectors from an input stream. The stream is not closed on finishing the function.
   *
   * @param in The stream to read from. This is not closed.
   * @return The word vectors encoded on the stream.
   * @throws IOException Thrown if we could not read from the stream, or if the
   *         stream ends early or contains an invalid header / key length.
   */
  public static VectorMap deserialize(InputStream in) throws IOException {
    DataInputStream dataIn = new DataInputStream(in);

    // Read the max key length
    int maxKeyLength = dataIn.readInt();
    itype keyIntType = itype.getType(maxKeyLength);
    // Read the vector dimensionality
    int dim = dataIn.readInt();
    // Read the size of the dataset
    int size = dataIn.readInt();
    if (maxKeyLength < 0 || dim < 0 || size < 0) {
      throw new IOException("Corrupt vector map header: maxKeyLength=" + maxKeyLength
          + ", dim=" + dim + ", size=" + size);
    }

    // Read the vectors
    VectorMap vectors = new VectorMap();
    for (int i = 0; i < size; ++i) {
      // Read the key
      int strlen = keyIntType.read(dataIn);
      if (strlen < 0 || strlen > maxKeyLength) {
        throw new IOException("Corrupt vector map entry: key length " + strlen
            + " is outside [0, " + maxKeyLength + "]");
      }
      byte[] buffer = new byte[strlen];
      try {
        // A plain read() may legally return fewer bytes than requested; readFully() loops
        dataIn.readFully(buffer);
      } catch (EOFException e) {
        throw new IOException("Could not read string buffer fully!", e);
      }
      String key = new String(buffer, StandardCharsets.UTF_8);
      // Read the vector
      float[] vector = new float[dim];
      for (int k = 0; k < vector.length; ++k) {
        vector[k] = toFloat(dataIn.readShort());
      }
      // Add the key/value
      vectors.put(key, vector);
    }
    return vectors;
  }


  /**
   * Read the Word2Vec word vector flat txt file.
   *
   * <p>Each line is lowercased and split on whitespace; the first token is the
   * word and the remaining tokens are the vector components. Lines with fewer
   * than 100 tokens are skipped (presumably to drop the word2vec header line;
   * note this also drops any vector with fewer than 99 dimensions). Every
   * vector is L2-normalized before being stored.</p>
   *
   * @param file The word2vec text file.
   * @return The word vectors in the file.
   */
  public static VectorMap readWord2Vec(String file) {
    VectorMap vectors = new VectorMap();
    int dim = -1;
    for (String line : IOUtils.readLines(file)) {
      // Locale.ROOT keeps the lowercasing independent of the JVM's default locale
      String[] split = line.toLowerCase(Locale.ROOT).split("\\s+");
      if (split.length < 100) {
        continue;
      }
      float[] vector = new float[split.length - 1];
      if (dim == -1) {
        dim = vector.length;
      }
      assert dim == vector.length;
      for (int i = 1; i < split.length; i++) {
        vector[i - 1] = Float.parseFloat(split[i]);
      }
      ArrayMath.L2normalize(vector);
      vectors.put(split[0], vector);
    }

    return vectors;
  }

  /**
   * Compare this map against another map, treating vectors as equal when
   * all of their components are "close enough" (see {@link #sameFloat(float, float)}).
   * This tolerance is what allows a map to compare equal to its own
   * half-precision round trip.
   *
   * @param other The object to compare against.
   * @return True if {@code other} is a map with the same keys whose values
   *         are float arrays approximately equal to the ones in this map.
   */
  @Override
  public boolean equals(Object other) {
    if (this == other) {
      return true;
    }
    if (!(other instanceof Map)) {
      return false;
    }
    Map<?, ?> otherMap = (Map<?, ?>) other;
    // Key sets have the same size
    if (this.size() != otherMap.size()) {
      return false;
    }
    // Entries are the same
    for (Map.Entry<String, float[]> entry : this.entrySet()) {
      String key = entry.getKey();
      float[] value = entry.getValue();
      Object otherRaw;
      boolean present;
      try {
        otherRaw = otherMap.get(key);
        present = otherRaw != null || otherMap.containsKey(key);
      } catch (ClassCastException | NullPointerException ignored) {
        // The other map cannot even hold this kind of key, so it cannot be equal
        return false;
      }
      if (!present) {
        return false;
      }
      // Null checks
      if (value == null || otherRaw == null) {
        if (value != null || otherRaw != null) {
          return false;
        }
        continue;
      }
      if (!(otherRaw instanceof float[])) {
        return false;
      }
      float[] otherValue = (float[]) otherRaw;
      // Vectors are the same length
      if (value.length != otherValue.length) {
        return false;
      }
      // Vectors are the same value
      for (int i = 0; i < otherValue.length; ++i) {
        if (!sameFloat(value[i], otherValue[i])) {
          return false;
        }
      }
    }
    return true;
  }

  /**
   * Hash on the keys only, so that maps which are equal under the approximate
   * value comparison of {@link #equals(Object)} also share a hash code.
   */
  @Override
  public int hashCode() {
    return keySet().hashCode();
  }

  @Override
  public String toString() {
    return "VectorMap[" + this.size() + "]";
  }

  /**
   * The check to see if two floats are "close enough": their absolute
   * difference is tiny, or is below 1% of the larger magnitude, or both
   * values are themselves close to zero.
   */
  private static boolean sameFloat(float a, float b) {
    float absDiff = Math.abs(a - b);
    float absA = Math.abs(a);
    float absB = Math.abs(b);
    return absDiff < 1e-10 ||
           absDiff < Math.max(absA, absB) / 100.0f ||
           (absA < 1e-5 && absB < 1e-5);
  }

  /**
   * Decode a 16 bit half-precision float into a regular float.
   * From  http://stackoverflow.com/questions/6162651/half-precision-floating-point-in-java
   *
   * @param hbits The half-precision bit pattern.
   * @return The decoded float.
   */
  private static float toFloat( short hbits ) {
    int mant = hbits & 0x03ff;            // 10 bits mantissa
    int exp =  hbits & 0x7c00;            // 5 bits exponent
    if( exp == 0x7c00 )                   // NaN/Inf
      exp = 0x3fc00;                    // -> NaN/Inf
    else if( exp != 0 )                   // normalized value
    {
      exp += 0x1c000;                   // exp - 15 + 127
      if( mant == 0 && exp > 0x1c400 )  // smooth transition
        return Float.intBitsToFloat( ( hbits & 0x8000 ) << 16
            | exp << 13 | 0x3ff );
    }
    else if( mant != 0 )                  // && exp==0 -> subnormal
    {
      exp = 0x1c400;                    // make it normal
      do {
        mant <<= 1;                   // mantissa * 2
        exp -= 0x400;                 // decrease exp by 1
      } while( ( mant & 0x400 ) == 0 ); // while not normal
      mant &= 0x3ff;                    // discard subnormal bit
    }                                     // else +/-0 -> +/-0
    return Float.intBitsToFloat(          // combine all parts
        ( hbits & 0x8000 ) << 16          // sign  << ( 31 - 15 )
            | ( exp | mant ) << 13 );         // value << ( 23 - 10 )
  }

  /**
   * Encode a regular float as a 16 bit half-precision float, with rounding.
   * From  http://stackoverflow.com/questions/6162651/half-precision-floating-point-in-java
   *
   * @param fval The float to encode.
   * @return The half-precision bit pattern.
   */
  private static short fromFloat( float fval ) {
    int fbits = Float.floatToIntBits( fval );
    int sign = fbits >>> 16 & 0x8000;          // sign only
    int val = ( fbits & 0x7fffffff ) + 0x1000; // rounded value

    if( val >= 0x47800000 )               // might be or become NaN/Inf
    {                                     // avoid Inf due to rounding
      if( ( fbits & 0x7fffffff ) >= 0x47800000 )
      {                                 // is or must become NaN/Inf
        if( val < 0x7f800000 )        // was value but too large
          return (short) (sign | 0x7c00);     // make it +/-Inf
        return (short) (sign | 0x7c00 |        // remains +/-Inf or NaN
            ( fbits & 0x007fffff ) >>> 13); // keep NaN (and Inf) bits
      }
      return (short) (sign | 0x7bff);             // unrounded not quite Inf
    }
    if( val >= 0x38800000 )               // remains normalized value
      return (short) (sign | val - 0x38000000 >>> 13); // exp - 127 + 15
    if( val < 0x33000000 )                // too small for subnormal
      return (short) sign;                      // becomes +/-0
    val = ( fbits & 0x7fffffff ) >>> 23;  // tmp exp for subnormal calc
    return (short) (sign | ( ( fbits & 0x7fffff | 0x800000 ) // add subnormal bit
        + ( 0x800000 >>> val - 102 )     // round depending on cut off
        >>> 126 - val ));   // div by 2^(1-(exp-127+15)) and >> 13 | exp=0
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy