All downloads are free. Search and download functionality uses the official Maven repository.

brickhouse.udf.timeseries.VectorUnionSumUDAF Maven / Gradle / Ivy

The newest version!
package brickhouse.udf.timeseries;
/**
 * Copyright 2012 Klout, Inc
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *
 **/


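/**
 * Aggregate that adds numeric vectors together: array arguments are handled by
 * VectorArraySumUDAFEvaluator, and map arguments are summed key by key by
 * VectorMapSumUDAFEvaluator. Results are always returned with double values.
 */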

import java.util.ArrayList;
import java.util.HashMap;
import java.util.Map;

import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.parse.SemanticException;
import org.apache.hadoop.hive.ql.udf.generic.AbstractGenericUDAFResolver;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator.Mode;
import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.MapObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils;
import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.StandardListObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.StandardMapObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector.Category;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo;


@Description(name="union_vector_sum",
value = "_FUNC_(x) - Aggregate adding vectors together " 
)
public class VectorUnionSumUDAF extends AbstractGenericUDAFResolver {


	/// Adapted from Hive's CollectSet UDAF.

	/**
	 * Chooses the evaluator for the single argument's type: array arguments are
	 * handled by VectorArraySumUDAFEvaluator, map arguments by
	 * VectorMapSumUDAFEvaluator.
	 *
	 * @param parameters the argument types; exactly one is expected
	 * @return an evaluator for the array or map argument
	 * @throws SemanticException (as a UDFArgumentTypeException) if the number of
	 *         arguments is not one, or the argument is neither an array nor a map
	 */
	@Override
	public GenericUDAFEvaluator getEvaluator(TypeInfo[] parameters)
			throws SemanticException {
		if (parameters.length != 1) {
			throw new UDFArgumentTypeException(parameters.length - 1,
					"Vector sum takes one argument");
		}
		Category category = parameters[0].getCategory();
		if (category == Category.LIST) {
			return new VectorArraySumUDAFEvaluator();
		} else if (category == Category.MAP) {
			return new VectorMapSumUDAFEvaluator();
		}
		// Name must match the registered function name in @Description above.
		throw new UDFArgumentTypeException(0, "union_vector_sum aggregates either arrays or maps");
	}

	public static class VectorArraySumUDAFEvaluator extends GenericUDAFEvaluator {
		// Inspector for parameters[0] as passed to init(). In PARTIAL1/COMPLETE this is
		// the original array argument; in PARTIAL2/FINAL it is the partial aggregate
		// (presumably the list of doubles emitted by terminatePartial — confirm).
		// Used by both iterate() and merge().
		private ListObjectInspector inputOI;
		// Output inspector returned by init() in every mode: a standard list of
		// Java doubles holding the summed vector.
		private StandardListObjectInspector stdListOI;


		/** Per-group state: the running vector sum, filled in by addVector(). */
		static class VectorArrayAggBuffer implements AggregationBuffer {
			ArrayList sumArray;

			VectorArrayAggBuffer() {
				this.sumArray = new ArrayList();
			}
		}

		/**
		 * Validates the argument and declares the output type.
		 *
		 * @param m the aggregation mode
		 * @param parameters a single list inspector: the raw array argument in
		 *        PARTIAL1/COMPLETE, or the partial aggregate in PARTIAL2/FINAL
		 * @return a standard list inspector over Java doubles, used as the result
		 *         type in every mode
		 * @throws HiveException if the list elements are not numeric primitives
		 */
		@Override
		public ObjectInspector init(Mode m, ObjectInspector[] parameters)
				throws HiveException {
			super.init(m, parameters);
			inputOI = (ListObjectInspector) parameters[0];
			ObjectInspector elementOI = inputOI.getListElementObjectInspector();
			if (elementOI.getCategory() != Category.PRIMITIVE
					|| !NumericUtil.isNumericCategory(
							((PrimitiveObjectInspector) elementOI).getPrimitiveCategory())) {
				throw new HiveException("Vector values must be numeric.");
			}
			// Always return a standard list of doubles, whatever the numeric
			// type of the input elements.
			stdListOI = ObjectInspectorFactory
					.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.javaDoubleObjectInspector);
			return stdListOI;
		}

		/** Creates a per-group buffer and initialises it through reset(). */
		@Override
		public AggregationBuffer getNewAggregationBuffer() throws HiveException {
			VectorArrayAggBuffer fresh = new VectorArrayAggBuffer();
			reset(fresh);
			return fresh;
		}

		/** Adds one input array into the running sum; null rows are skipped. */
		@Override
		public void iterate(AggregationBuffer agg, Object[] parameters)
				throws HiveException {
			Object vector = parameters[0];
			if (vector == null) {
				return;
			}
			addVector(vector, (VectorArrayAggBuffer) agg, inputOI);
		}

		/**
		 * Folds a partial aggregate into {@code agg} using the inspector set in init().
		 * A null partial contributes nothing, matching how iterate() treats null rows.
		 */
		@Override
		public void merge(AggregationBuffer agg, Object partial)
				throws HiveException {
			if (partial == null) {
				return;
			}
			VectorArrayAggBuffer myagg = (VectorArrayAggBuffer) agg;
			addVector(partial, myagg, this.inputOI);
		}

		/** Clears the buffer by swapping in a new, empty sum list. */
		@Override
		public void reset(AggregationBuffer buff) throws HiveException {
			((VectorArrayAggBuffer) buff).sumArray = new ArrayList();
		}

		/** Returns the buffer's accumulated sum list itself (not a copy). */
		@Override
		public Object terminate(AggregationBuffer agg) throws HiveException {
			return ((VectorArrayAggBuffer) agg).sumArray;
		}

		private void addVector(Object listObj, VectorArrayAggBuffer myagg, ListObjectInspector inputOI) {
			int listLen = inputOI.getListLength(listObj);
			if( listLen > myagg.sumArray.size() )
				myagg.sumArray.ensureCapacity(listLen);
			
			for(int i=0; i sumMap = new HashMap();
		}

		/**
		 * Validates the map argument and declares the output type.
		 *
		 * @param m the aggregation mode
		 * @param parameters a single map inspector: the raw map argument in
		 *        PARTIAL1/COMPLETE, or the partial aggregate in PARTIAL2/FINAL
		 * @return a standard map inspector whose keys use the standard Java form of
		 *         the input key type and whose values are Java doubles
		 * @throws HiveException if the keys are not primitive, or the values are not
		 *         numeric primitives
		 */
		@Override
		public ObjectInspector init(Mode m, ObjectInspector[] parameters)
				throws HiveException {
			super.init(m, parameters);
			inputOI = (MapObjectInspector) parameters[0];
			ObjectInspector keyOI = inputOI.getMapKeyObjectInspector();
			ObjectInspector valueOI = inputOI.getMapValueObjectInspector();
			if (keyOI.getCategory() != Category.PRIMITIVE) {
				throw new HiveException("Vector map keys must be a primitive.");
			}
			if (valueOI.getCategory() != Category.PRIMITIVE
					|| !NumericUtil.isNumericCategory(
							((PrimitiveObjectInspector) valueOI).getPrimitiveCategory())) {
				throw new HiveException("Vector values must be numeric.");
			}
			// XXX: values are always returned as doubles; ideally the result would
			// use the numeric type of the input values instead.
			stdMapOI = ObjectInspectorFactory.getStandardMapObjectInspector(
					ObjectInspectorUtils.getStandardObjectInspector(keyOI,
							ObjectInspectorUtils.ObjectInspectorCopyOption.JAVA),
					PrimitiveObjectInspectorFactory.javaDoubleObjectInspector);
			return stdMapOI;
		}

		/** Creates a per-group buffer and initialises it through reset(). */
		@Override
		public AggregationBuffer getNewAggregationBuffer() throws HiveException {
			VectorMapAggBuffer fresh = new VectorMapAggBuffer();
			reset(fresh);
			return fresh;
		}

		/** Adds one input map into the running per-key sums; null rows are skipped. */
		@Override
		public void iterate(AggregationBuffer agg, Object[] parameters)
				throws HiveException {
			Object vectorMap = parameters[0];
			if (vectorMap == null) {
				return;
			}
			addVectorMap(vectorMap, (VectorMapAggBuffer) agg, inputOI);
		}

		/**
		 * Folds a partial aggregate (a map of per-key sums) into {@code agg}.
		 * A null partial contributes nothing; previously it was handed to
		 * inputOI.getMap(), whose null result was then dereferenced.
		 */
		@Override
		public void merge(AggregationBuffer agg, Object partial)
				throws HiveException {
			if (partial == null) {
				return;
			}
			VectorMapAggBuffer myagg = (VectorMapAggBuffer) agg;
			addVectorMap(partial, myagg, this.inputOI);
		}

		/** Discards all accumulated sums by giving the buffer a new, empty map. */
		@Override
		public void reset(AggregationBuffer buff) throws HiveException {
			VectorMapAggBuffer mapBuff = (VectorMapAggBuffer) buff;
			// Typed instantiation instead of a raw HashMap: the buffer stores
			// Object keys mapped to Double sums.
			mapBuff.sumMap = new HashMap<Object, Double>();
		}

		/** Returns the buffer's per-key sum map itself (not a copy) as the final result. */
		@Override
		public Object terminate(AggregationBuffer agg) throws HiveException {
			return ((VectorMapAggBuffer) agg).sumMap;
		}

		/**
		 * Adds every entry of {@code mapObj} into the running per-key sums.
		 *
		 * Keys are copied to standard Java objects, matching the standard Java key
		 * inspector declared by init(); values are converted to double through
		 * NumericUtil.getNumericValue. A map that the inspector reports as null
		 * contributes nothing.
		 *
		 * @param mapObj the map to add, in the representation described by inputOI
		 * @param myagg the buffer holding the running sums
		 * @param inputOI inspector describing mapObj
		 */
		private void addVectorMap(Object mapObj, VectorMapAggBuffer myagg, MapObjectInspector inputOI) {
			Map<?, ?> uninspMap = inputOI.getMap(mapObj);
			if (uninspMap == null) {
				return;
			}
			// Loop-invariant inspectors; init() has already checked the value
			// inspector is a primitive.
			ObjectInspector keyOI = inputOI.getMapKeyObjectInspector();
			PrimitiveObjectInspector valueOI = (PrimitiveObjectInspector) inputOI.getMapValueObjectInspector();
			for (Map.Entry<?, ?> entry : uninspMap.entrySet()) {
				Object stdKey = ObjectInspectorUtils.copyToStandardJavaObject(entry.getKey(), keyOI);
				// NOTE(review): handling of a null map value depends on
				// NumericUtil.getNumericValue, which is not visible here — confirm.
				double stdVal = NumericUtil.getNumericValue(valueOI, entry.getValue());
				// Stored sums are never null, so a null lookup means "key not seen yet".
				Double prevVal = myagg.sumMap.get(stdKey);
				myagg.sumMap.put(stdKey, prevVal == null ? stdVal : prevVal + stdVal);
			}
		}
			

		/** Emits the buffer's per-key sum map itself (not a copy) as the partial result. */
		@Override
		public Object terminatePartial(AggregationBuffer agg) throws HiveException {
			return ((VectorMapAggBuffer) agg).sumMap;
		}
	}


}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy