All Downloads are FREE. Search and download functionalities are using the official Maven repository.

hivemall.sketch.hll.ApproxCountDistinctUDAF Maven / Gradle / Ivy

The newest version!
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
package hivemall.sketch.hll;

import hivemall.UDAFEvaluatorWithOptions;
import hivemall.utils.hadoop.HiveUtils;
import hivemall.utils.lang.Preconditions;
import hivemall.utils.lang.Primitives;

import java.io.IOException;

import javax.annotation.Nonnegative;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;

import org.apache.commons.cli.CommandLine;
import org.apache.commons.cli.Options;
import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.parse.SemanticException;
import org.apache.hadoop.hive.ql.udf.generic.AbstractGenericUDAFResolver;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator.AbstractAggregationBuffer;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator.AggregationType;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.BinaryObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo;
import org.apache.hadoop.io.LongWritable;

import com.clearspring.analytics.stream.cardinality.CardinalityMergeException;
import com.clearspring.analytics.stream.cardinality.HyperLogLogPlus;

/**
 * Resolver for the {@code approx_count_distinct} aggregate, which estimates
 * {@code count(DISTINCT x)} using a {@link HyperLogLogPlus} sketch.
 *
 * <p>
 * Arguments: the expression to count, plus an optional constant string of options
 * ({@code -p}, {@code -sp}) that configure the sketch's register sizes.
 */
@Description(name = "approx_count_distinct", value = "_FUNC_(expr x [, const string options])"
        + " - Returns an approximation of count(DISTINCT x) using HyperLogLogPlus algorithm")
public final class ApproxCountDistinctUDAF extends AbstractGenericUDAFResolver {

    /**
     * Validates the argument types and returns the HLL evaluator.
     *
     * @param typeInfo argument types; one or two entries, where the second must be a string type
     * @return a new {@link HLLEvaluator}
     * @throws SemanticException if there are not exactly one or two arguments, or if the second
     *         argument is not a string type
     */
    @Override
    public GenericUDAFEvaluator getEvaluator(@Nonnull TypeInfo[] typeInfo)
            throws SemanticException {
        if (typeInfo.length != 1 && typeInfo.length != 2) {
            throw new UDFArgumentTypeException(typeInfo.length - 1,
                "_FUNC_ takes one or two arguments");
        }
        if (typeInfo.length == 2 && !HiveUtils.isStringTypeInfo(typeInfo[1])) {
            throw new UDFArgumentTypeException(1,
                "The second argument type expected to be const string: " + typeInfo[1]);
        }

        return new HLLEvaluator();
    }

    /**
     * Evaluator that accumulates input values into a {@link HyperLogLogPlus} sketch.
     *
     * <p>
     * In PARTIAL1/COMPLETE modes it reads raw values from the first argument. In PARTIAL2/FINAL
     * modes it reads the serialized sketches emitted by {@link #terminatePartial}.
     */
    public static final class HLLEvaluator extends UDAFEvaluatorWithOptions {

        /**
         * Sketch parameters {@code {p, sp}} set by {@link #processOptions}. They are only
         * assigned in PARTIAL1/COMPLETE modes, so they remain null when this evaluator only
         * merges partials.
         */
        @Nullable
        private int[] params;

        /** Inspector for raw input values (PARTIAL1/COMPLETE modes). */
        private ObjectInspector origInputOI;
        /** Inspector for serialized partial sketches (PARTIAL2/FINAL modes). */
        private BinaryObjectInspector mergeInputOI;

        /**
         * Declares the supported options.
         *
         * @return options {@code -p} (normal-set register size) and {@code -sp} (sparse-set
         *         register size)
         */
        @Override
        protected Options getOptions() {
            Options opts = new Options();
            opts.addOption("p", true,
                "The size of registers for the normal set. `p` MUST be in the range [4,sp] and 15 by the default");
            opts.addOption("sp", true,
                "The size of registers for the sparse set. `sp` MUST be in the range [4,32] and 25 by the default");
            return opts;
        }

        /**
         * Parses the optional constant option string and stores the sketch parameters.
         *
         * <p>
         * Without a second argument, the defaults {@code p = 15} and {@code sp = 25} are used.
         *
         * @param argOIs argument inspectors; the second one, if present, must be a constant string
         * @return the parsed command line, or null when no option string was given
         * @throws UDFArgumentException if the second argument is not a constant string, or if the
         *         parsed values fail {@code validateArguments}
         */
        @Override
        protected CommandLine processOptions(@Nonnull ObjectInspector[] argOIs)
                throws UDFArgumentException {
            CommandLine cl = null;

            int p = 15, sp = 25;
            if (argOIs.length == 2) {
                if (!HiveUtils.isConstString(argOIs[1])) {
                    throw new UDFArgumentException(
                        "The second argument type expected to be const string: " + argOIs[1]);
                }
                cl = parseOptions(HiveUtils.getConstString(argOIs[1]));

                // missing options fall back to the defaults above
                p = Primitives.parseInt(cl.getOptionValue("p"), p);
                sp = Primitives.parseInt(cl.getOptionValue("sp"), sp);
                validateArguments(p, sp);
            }

            this.params = new int[] {p, sp};

            return cl;
        }

        /**
         * Configures the input and output inspectors for the given aggregation mode.
         *
         * @param mode Hive aggregation mode
         * @param parameters input inspectors: raw arguments in PARTIAL1/COMPLETE modes, or the
         *        binary partial in PARTIAL2/FINAL modes
         * @return a {@code byte[]} inspector for partial output (PARTIAL1/PARTIAL2), otherwise a
         *         writable long inspector for the final count
         * @throws HiveException if option processing or inspector conversion fails
         */
        @Override
        public ObjectInspector init(@Nonnull Mode mode, @Nonnull ObjectInspector[] parameters)
                throws HiveException {
            assert (parameters.length == 1 || parameters.length == 2) : parameters.length;
            super.init(mode, parameters);

            // initialize input
            if (mode == Mode.PARTIAL1 || mode == Mode.COMPLETE) {// from original data
                processOptions(parameters);
                this.origInputOI = parameters[0];
            } else {// from partial aggregation
                this.mergeInputOI = HiveUtils.asBinaryOI(parameters[0]);
            }

            // initialize output
            final ObjectInspector outputOI;
            if (mode == Mode.PARTIAL1 || mode == Mode.PARTIAL2) {// terminatePartial
                outputOI = PrimitiveObjectInspectorFactory.javaByteArrayObjectInspector;
            } else {// terminate
                outputOI = PrimitiveObjectInspectorFactory.writableLongObjectInspector;
            }
            return outputOI;
        }

        /**
         * Creates a new aggregation buffer.
         *
         * <p>
         * When sketch parameters are known, the buffer starts with an empty sketch. Otherwise its
         * sketch is null, and {@link #merge} adopts the first partial it receives.
         */
        @Override
        public HLLBuffer getNewAggregationBuffer() throws HiveException {
            HLLBuffer buf = new HLLBuffer();
            if (params != null) {
                buf.reset(params[0], params[1]);
            }
            return buf;
        }

        /**
         * Resets the buffer. It gets a fresh empty sketch when parameters are known, otherwise
         * its sketch is set to null.
         */
        @SuppressWarnings("deprecation")
        @Override
        public void reset(@Nonnull AggregationBuffer agg) throws HiveException {
            HLLBuffer buf = (HLLBuffer) agg;
            if (params != null) {
                buf.reset(params[0], params[1]);
            } else {
                buf.hll = null;
            }
        }

        /**
         * Offers one raw input value to the sketch. Null values are ignored.
         *
         * @throws HiveException if the buffer has no sketch
         */
        @SuppressWarnings("deprecation")
        @Override
        public void iterate(@Nonnull AggregationBuffer agg, @Nonnull Object[] parameters)
                throws HiveException {
            if (parameters[0] == null) {
                return;
            }

            HLLBuffer buf = (HLLBuffer) agg;
            // copy to a standard Java object before offering it to the sketch
            Object value =
                    ObjectInspectorUtils.copyToStandardJavaObject(parameters[0], origInputOI);
            Preconditions.checkNotNull(buf.hll, HiveException.class);
            buf.hll.offer(value);
        }

        /**
         * Serializes the current sketch.
         *
         * @return the sketch bytes, or null if the buffer has no sketch
         * @throws HiveException wrapping any serialization failure
         */
        @SuppressWarnings("deprecation")
        @Override
        @Nullable
        public byte[] terminatePartial(@Nonnull AggregationBuffer agg) throws HiveException {
            HLLBuffer buf = (HLLBuffer) agg;
            if (buf.hll == null) {
                return null;
            }
            try {
                return buf.hll.getBytes();
            } catch (IOException e) {
                throw new HiveException(e);
            }
        }

        /**
         * Merges a serialized partial sketch into the buffer. Null partials are ignored.
         *
         * <p>
         * If the buffer has no sketch yet, it adopts the deserialized partial directly.
         *
         * @throws HiveException if the partial cannot be deserialized or merged; the underlying
         *         exception is attached as the cause
         */
        @SuppressWarnings("deprecation")
        @Override
        public void merge(@Nonnull AggregationBuffer agg, @Nullable Object partial)
                throws HiveException {
            if (partial == null) {
                return;
            }

            final byte[] data = mergeInputOI.getPrimitiveJavaObject(partial);
            final HyperLogLogPlus otherHLL;
            try {
                otherHLL = HyperLogLogPlus.Builder.build(data);
            } catch (IOException e) {
                // keep the cause so that corrupt or incompatible partials are diagnosable
                throw new HiveException("Failed to build other HLL", e);
            }

            final HLLBuffer buf = (HLLBuffer) agg;
            if (buf.hll == null) {
                buf.hll = otherHLL;
            } else {
                try {
                    buf.hll.addAll(otherHLL);
                } catch (CardinalityMergeException e) {
                    // keep the cause (e.g. mismatched p/sp between partials)
                    throw new HiveException("Failed to merge HLL", e);
                }
            }
        }

        /**
         * Returns the estimated distinct count. The result is 0 when the buffer has no sketch.
         */
        @SuppressWarnings("deprecation")
        @Override
        public LongWritable terminate(@Nonnull AggregationBuffer agg) throws HiveException {
            HLLBuffer buf = (HLLBuffer) agg;

            long cardinality = (buf.hll == null) ? 0L : buf.hll.cardinality();
            return new LongWritable(cardinality);
        }

    }

    /**
     * Checks the sketch parameters against the documented ranges.
     *
     * @param p normal-set register size; must satisfy {@code 4 <= p <= sp}
     * @param sp sparse-set register size; must not exceed 32
     * @throws UDFArgumentException if either constraint is violated
     */
    private static void validateArguments(final int p, final int sp) throws UDFArgumentException {
        if (p < 4 || p > sp) {
            throw new UDFArgumentException("p must be between 4 and sp (inclusive)");
        }
        if (sp > 32) {
            throw new UDFArgumentException("sp values greater than 32 not supported");
        }
    }

    /** Aggregation buffer holding a single, possibly absent, {@link HyperLogLogPlus} sketch. */
    @AggregationType(estimable = true)
    static final class HLLBuffer extends AbstractAggregationBuffer {

        /** The sketch; null until it is created by {@link #reset} or adopted during a merge. */
        @Nullable
        private HyperLogLogPlus hll;

        HLLBuffer() {}

        /** Estimated memory footprint as reported by the sketch, or 0 when there is no sketch. */
        @Override
        public int estimate() {
            return (hll == null) ? 0 : hll.sizeof();
        }

        /** Replaces the sketch with a new, empty one using the given register sizes. */
        void reset(@Nonnegative int p, @Nonnegative int sp) {
            this.hll = new HyperLogLogPlus(p, sp);
        }

    }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy