All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.apache.hadoop.hive.ql.udf.UDAFPercentile Maven / Gradle / Ivy

/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.hadoop.hive.ql.udf;

import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;

import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDAF;
import org.apache.hadoop.hive.ql.exec.UDAFEvaluator;
import org.apache.hadoop.hive.serde2.io.DoubleWritable;
import org.apache.hadoop.hive.shims.ShimLoader;
import org.apache.hadoop.io.LongWritable;

/**
 * UDAF for calculating the percentile values.
 * There are several definitions of percentile, and we take the method recommended by
 * NIST.
 * @see 
 *      Percentile references
 */
@Description(name = "percentile",
    value = "_FUNC_(expr, pc) - Returns the percentile(s) of expr at pc (range: [0,1])."
      + "pc can be a double or double array")
public class UDAFPercentile extends UDAF {

  private static final Comparator COMPARATOR;

  static {
    COMPARATOR = ShimLoader.getHadoopShims().getLongComparator();
  }

  /**
   * A state class to store intermediate aggregation results.
   */
  public static class State {
    private Map counts;
    private List percentiles;
  }

  /**
   * A comparator to sort the entries in order.
   */
  public static class MyComparator implements Comparator> {
    @Override
    public int compare(Map.Entry o1,
        Map.Entry o2) {
      return COMPARATOR.compare(o1.getKey(), o2.getKey());
    }
  }

  /**
   * Increment the State object with o as the key, and i as the count.
   */
  private static void increment(State s, LongWritable o, long i) {
    if (s.counts == null) {
      s.counts = new HashMap();
    }
    LongWritable count = s.counts.get(o);
    if (count == null) {
      // We have to create a new object, because the object o belongs
      // to the code that creates it and may get its value changed.
      LongWritable key = new LongWritable();
      key.set(o.get());
      s.counts.put(key, new LongWritable(i));
    } else {
      count.set(count.get() + i);
    }
  }

  /**
   * Get the percentile value.
   */
  private static double getPercentile(List> entriesList,
      double position) {
    // We may need to do linear interpolation to get the exact percentile
    long lower = (long)Math.floor(position);
    long higher = (long)Math.ceil(position);

    // Linear search since this won't take much time from the total execution anyway
    // lower has the range of [0 .. total-1]
    // The first entry with accumulated count (lower+1) corresponds to the lower position.
    int i = 0;
    while (entriesList.get(i).getValue().get() < lower + 1) {
      i++;
    }

    long lowerKey = entriesList.get(i).getKey().get();
    if (higher == lower) {
      // no interpolation needed because position does not have a fraction
      return lowerKey;
    }

    if (entriesList.get(i).getValue().get() < higher + 1) {
      i++;
    }
    long higherKey = entriesList.get(i).getKey().get();

    if (higherKey == lowerKey) {
      // no interpolation needed because lower position and higher position has the same key
      return lowerKey;
    }

    // Linear interpolation to get the exact percentile
    return (higher - position) * lowerKey + (position - lower) * higherKey;
  }


  /**
   * The evaluator for percentile computation based on long.
   */
  public static class PercentileLongEvaluator implements UDAFEvaluator {

    private final State state;

    public PercentileLongEvaluator() {
      state = new State();
    }

    public void init() {
      if (state.counts != null) {
        // We reuse the same hashmap to reduce new object allocation.
        // This means counts can be empty when there is no input data.
        state.counts.clear();
      }
    }

    /** Note that percentile can be null in a global aggregation with
     *  0 input rows:  "select percentile(col, 0.5) from t where false"
     *  In that case, iterate(null, null) will be called once.
     */
    public boolean iterate(LongWritable o, Double percentile) {
      if (o == null && percentile == null) {
        return false;
      }
      if (state.percentiles == null) {
        if (percentile < 0.0 || percentile > 1.0) {
          throw new RuntimeException("Percentile value must be within the range of 0 to 1.");
        }
        state.percentiles = new ArrayList(1);
        state.percentiles.add(new DoubleWritable(percentile.doubleValue()));
      }
      if (o != null) {
        increment(state, o, 1);
      }
      return true;
    }

    public State terminatePartial() {
      return state;
    }

    public boolean merge(State other) {
      if (other == null || other.counts == null || other.percentiles == null) {
        return false;
      }

      if (state.percentiles == null) {
        state.percentiles = new ArrayList(other.percentiles);
      }

      for (Map.Entry e: other.counts.entrySet()) {
        increment(state, e.getKey(), e.getValue().get());
      }
      return true;
    }

    private DoubleWritable result;

    public DoubleWritable terminate() {
      // No input data.
      if (state.counts == null || state.counts.size() == 0) {
        return null;
      }

      // Get all items into an array and sort them.
      Set> entries = state.counts.entrySet();
      List> entriesList =
        new ArrayList>(entries);
      Collections.sort(entriesList, new MyComparator());

      // Accumulate the counts.
      long total = 0;
      for (int i = 0; i < entriesList.size(); i++) {
        LongWritable count = entriesList.get(i).getValue();
        total += count.get();
        count.set(total);
      }

      // Initialize the result.
      if (result == null) {
        result = new DoubleWritable();
      }

      // maxPosition is the 1.0 percentile
      long maxPosition = total - 1;
      double position = maxPosition * state.percentiles.get(0).get();
      result.set(getPercentile(entriesList, position));
      return result;
    }
  }

  /**
   * The evaluator for percentile computation based on long for an array of percentiles.
   */
  public static class PercentileLongArrayEvaluator implements UDAFEvaluator {

    private final State state;

    public PercentileLongArrayEvaluator() {
      state = new State();
    }

    public void init() {
      if (state.counts != null) {
        // We reuse the same hashmap to reduce new object allocation.
        // This means counts can be empty when there is no input data.
        state.counts.clear();
      }
    }

    public boolean iterate(LongWritable o, List percentiles) {
      if (state.percentiles == null) {
        if(percentiles != null) {
          for (int i = 0; i < percentiles.size(); i++) {
            if (percentiles.get(i).get() < 0.0 || percentiles.get(i).get() > 1.0) {
              throw new RuntimeException("Percentile value must be within the range of 0 to 1.");
            }
          }
          state.percentiles = new ArrayList(percentiles);
        }
        else {
          state.percentiles = new ArrayList();
        }
      }
      if (o != null) {
        increment(state, o, 1);
      }
      return true;
    }

    public State terminatePartial() {
      return state;
    }

    public boolean merge(State other) {
      if (other == null || other.counts == null || other.percentiles == null) {
        return true;
      }

      if (state.percentiles == null) {
        state.percentiles = new ArrayList(other.percentiles);
      }

      for (Map.Entry e: other.counts.entrySet()) {
        increment(state, e.getKey(), e.getValue().get());
      }
      return true;
    }


    private List results;

    public List terminate() {
      // No input data
      if (state.counts == null || state.counts.size() == 0) {
        return null;
      }

      // Get all items into an array and sort them
      Set> entries = state.counts.entrySet();
      List> entriesList =
        new ArrayList>(entries);
      Collections.sort(entriesList, new MyComparator());

      // accumulate the counts
      long total = 0;
      for (int i = 0; i < entriesList.size(); i++) {
        LongWritable count = entriesList.get(i).getValue();
        total += count.get();
        count.set(total);
      }

      // maxPosition is the 1.0 percentile
      long maxPosition = total - 1;

      // Initialize the results
      if (results == null) {
        results = new ArrayList();
        for (int i = 0; i < state.percentiles.size(); i++) {
          results.add(new DoubleWritable());
        }
      }
      // Set the results
      for (int i = 0; i < state.percentiles.size(); i++) {
        double position = maxPosition * state.percentiles.get(i).get();
        results.get(i).set(getPercentile(entriesList, position));
      }
      return results;
    }
  }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy