All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.apache.joshua.util.encoding.Analyzer Maven / Gradle / Ivy

/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *  http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
package org.apache.joshua.util.encoding;

import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.TreeMap;

import org.apache.joshua.util.io.LineReader;

public class Analyzer {

  private final TreeMap histogram;
  private int total;

  public Analyzer() {
    histogram = new TreeMap<>();
    initialize();
  }

  public void initialize() {
    histogram.clear();
    // TODO: drop zero bucket; we won't encode zero-valued features anyway.
    histogram.put(0.0f, 0);
    total = 0;
  }

  public void add(float key) {
    if (histogram.containsKey(key))
      histogram.put(key, histogram.get(key) + 1);
    else
      histogram.put(key, 1);
    total++;
  }

  public float[] quantize(int num_bits) {
    float[] buckets = new float[1 << num_bits];

    // We make sure that 0.0f always has its own bucket, so the bucket
    // size is determined excluding the zero values.
    int size = (total - histogram.get(0.0f)) / (buckets.length - 1);
    buckets[0] = 0.0f;

    int old_size = -1;
    while (old_size != size) {
      int sum = 0;
      int count = buckets.length - 1;
      for (float key : histogram.keySet()) {
        int entry_count = histogram.get(key);
        if (entry_count < size && key != 0)
          sum += entry_count;
        else
          count--;
      }
      old_size = size;
      size = sum / count;
    }

    float last_key = Float.MAX_VALUE;

    int index = 1;
    int count = 0;
    float sum = 0.0f;

    int value;
    for (float key : histogram.keySet()) {
      value = histogram.get(key);
      // Special bucket termination cases: zero boundary and histogram spikes.
      if (key == 0 || (last_key < 0 && key > 0) || (value >= size)) {
        // If the count is not 0, i.e. there were negative values, we should
        // not bucket them with the positive ones. Close out the bucket now.
        if (count != 0 && index < buckets.length - 2) {
          buckets[index++] = sum / count;
          count = 0;
          sum = 0;
        }
        if (key == 0)
          continue;
      }
      count += value;
      sum += key * value;
      // Check if the bucket is full.
      if (count >= size && index < buckets.length - 2) {
        buckets[index++] = sum / count;
        count = 0;
        sum = 0;
      }
      last_key = key;
    }
    if (count > 0 && index < buckets.length - 1)
      buckets[index++] = sum / count;

    float[] shortened = new float[index];
    System.arraycopy(buckets, 0, shortened, 0, shortened.length);
    return shortened;
  }

  public boolean isBoolean() {
    for (float value : histogram.keySet())
      if (value != 0 && value != 1)
        return false;
    return true;
  }

  public boolean isByte() {
    for (float value : histogram.keySet())
      if (Math.ceil(value) != value || value < Byte.MIN_VALUE || value > Byte.MAX_VALUE)
        return false;
    return true;
  }

  public boolean isShort() {
    for (float value : histogram.keySet())
      if (Math.ceil(value) != value || value < Short.MIN_VALUE || value > Short.MAX_VALUE)
        return false;
    return true;
  }

  public boolean isChar() {
    for (float value : histogram.keySet())
      if (Math.ceil(value) != value || value < Character.MIN_VALUE || value > Character.MAX_VALUE)
        return false;
    return true;
  }

  public boolean isInt() {
    for (float value : histogram.keySet())
      if (Math.ceil(value) != value)
        return false;
    return true;
  }

  public boolean is8Bit() {
    return (histogram.keySet().size() <= 256);
  }

  public FloatEncoder inferUncompressedType() {
    if (isBoolean())
      return PrimitiveFloatEncoder.BOOLEAN;
    if (isByte())
      return PrimitiveFloatEncoder.BYTE;
    if (is8Bit())
      return new EightBitQuantizer(this.quantize(8));
    if (isChar())
      return PrimitiveFloatEncoder.CHAR;
    if (isShort())
      return PrimitiveFloatEncoder.SHORT;
    if (isInt())
      return PrimitiveFloatEncoder.INT;
    return PrimitiveFloatEncoder.FLOAT;
  }

  public FloatEncoder inferType(int bits) {
    if (isBoolean())
      return PrimitiveFloatEncoder.BOOLEAN;
    if (isByte())
      return PrimitiveFloatEncoder.BYTE;
    if (bits == 8 || is8Bit())
      return new EightBitQuantizer(this.quantize(8));
    // TODO: Could add sub-8-bit encoding here (or larger).
    if (isChar())
      return PrimitiveFloatEncoder.CHAR;
    if (isShort())
      return PrimitiveFloatEncoder.SHORT;
    if (isInt())
      return PrimitiveFloatEncoder.INT;
    return PrimitiveFloatEncoder.FLOAT;
  }

  public String toString(String label) {
    StringBuilder sb = new StringBuilder();
    for (float val : histogram.keySet())
      sb.append(label).append("\t").append(String.format("%.5f", val)).append("\t")
          .append(histogram.get(val)).append("\n");
    return sb.toString();
  }

  public static void main(String[] args) throws IOException {
    try (LineReader reader = new LineReader(args[0])) {
      ArrayList s = new ArrayList<>();

      System.out.println("Initialized.");
      while (reader.hasNext())
        s.add(Float.parseFloat(reader.next().trim()));
      System.out.println("Data read.");
      int n = s.size();
      byte[] c = new byte[n];
      ByteBuffer b = ByteBuffer.wrap(c);
      Analyzer q = new Analyzer();

      q.initialize();
      for (Float value1 : s)
        q.add(value1);
      EightBitQuantizer eb = new EightBitQuantizer(q.quantize(8));
      System.out.println("Quantizer learned.");

      for (Float value : s)
        eb.write(b, value);
      b.rewind();
      System.out.println("Quantization complete.");

      float avg_error = 0;
      float error = 0;
      int count = 0;
      for (int i = -4; i < n - 4; i++) {
        float coded = eb.read(b, i);
        if (s.get(i + 4) != 0) {
          error = Math.abs(s.get(i + 4) - coded);
          avg_error += error;
          count++;
        }
      }
      avg_error /= count;
      System.out.println("Evaluation complete.");

      System.out.println("Average quanitization error over " + n + " samples is: " + avg_error);
    }
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy