All downloads are free. Search and download functionality uses the official Maven repository.

brickhouse.udf.collect.UnionUDAF Maven / Gradle / Ivy

package brickhouse.udf.collect;
/**
 * Copyright 2012 Klout, Inc
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *
 **/


import java.util.HashMap;
import java.util.Map;

import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.parse.SemanticException;
import org.apache.hadoop.hive.ql.udf.generic.AbstractGenericUDAFResolver;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator;
import org.apache.hadoop.hive.serde2.objectinspector.MapObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector.Category;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils;
import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.StandardMapObjectInspector;
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo;


/**
 * Hive UDAF resolver for {@code union}: aggregates a column of maps into one map that
 * contains every key seen across the aggregated rows.
 *
 * <p>Only map arguments are accepted; any other argument category is rejected by
 * {@link #getEvaluator(TypeInfo[])}. When the same key occurs in more than one input map,
 * the value that is processed last replaces the earlier one.
 *
 * <p>Adapted from Hive's CollectSet UDAF.
 */
@Description(name="union",
value = "_FUNC_(x) - Returns a map which contains the union of an aggregation of maps " 
)
public class UnionUDAF extends AbstractGenericUDAFResolver {

	/**
	 * Selects the evaluator for the given argument types.
	 *
	 * @param parameters type information of the call's arguments; exactly one is expected
	 * @return a {@link MapUnionUDAFEvaluator} when the single argument is a map
	 * @throws SemanticException (as {@link UDFArgumentTypeException}) when the argument
	 *         count is not one, or the argument is not a map
	 */
	@Override
	public GenericUDAFEvaluator getEvaluator(TypeInfo[] parameters)
			throws SemanticException {
		if (parameters.length != 1) {
			throw new UDFArgumentTypeException(parameters.length - 1,
					"One argument is expected; either a map or an array.");
		}
		if (parameters[0].getCategory() != Category.MAP) {
			// Array union is not implemented; only maps are supported for now.
			throw new UDFArgumentTypeException(0, " Only maps supported for now ");
		}
		return new MapUnionUDAFEvaluator();
	}


	/**
	 * Evaluator that folds maps into a single standard Java map, copying every key and
	 * value to its standard-object form before storing it.
	 */
	public static class MapUnionUDAFEvaluator extends GenericUDAFEvaluator {
		// For PARTIAL1 and COMPLETE: inspector for the original input maps.
		private MapObjectInspector inputMapOI;
		// For PARTIAL2 and FINAL: inspector for the partial aggregations (maps).
		private MapObjectInspector internalMergeOI;
		// Key/value inspectors of whichever map inspector is active in the current mode;
		// used to copy entries into the aggregation buffer.
		private ObjectInspector inputKeyOI;
		private ObjectInspector inputValOI;


		/** Aggregation state: the union of all maps seen so far. */
		static class MapAggBuffer implements AggregationBuffer {
			HashMap<Object, Object> collectMap = new HashMap<Object, Object>();
		}

		/**
		 * Captures the object inspectors for the current mode and declares the output type.
		 *
		 * <p>In PARTIAL1 and COMPLETE the argument describes the original input maps (consumed
		 * by {@link #iterate}); in PARTIAL2 and FINAL it describes partial aggregations
		 * (consumed by {@link #merge}). The previous implementation only recognised PARTIAL1
		 * as raw-input mode, which left {@code inputMapOI} unset in COMPLETE mode, and it
		 * cast a map inspector to a primitive inspector when the argument was not a
		 * {@code StandardMapObjectInspector}.
		 *
		 * @param m          the aggregation mode
		 * @param parameters inspectors for the arguments; {@code parameters[0]} must be a
		 *                   map inspector (guaranteed by {@code getEvaluator} for raw input)
		 * @return a standard map inspector over the standard forms of the key and value
		 *         inspectors; the same for partial and final results
		 * @throws HiveException propagated from the superclass initialisation
		 */
		@Override
		public ObjectInspector init(Mode m, ObjectInspector[] parameters)
				throws HiveException {
			super.init(m, parameters);

			MapObjectInspector activeOI = (MapObjectInspector) parameters[0];
			if (m == Mode.PARTIAL1 || m == Mode.COMPLETE) {
				inputMapOI = activeOI;
			} else {
				internalMergeOI = activeOI;
			}
			inputKeyOI = activeOI.getMapKeyObjectInspector();
			inputValOI = activeOI.getMapValueObjectInspector();

			// Output (partial or final) is always a standard map of standard objects,
			// matching what putIntoSet stores in the buffer.
			return ObjectInspectorFactory.getStandardMapObjectInspector(
					ObjectInspectorUtils.getStandardObjectInspector(inputKeyOI),
					ObjectInspectorUtils.getStandardObjectInspector(inputValOI));
		}

		/**
		 * Creates an empty aggregation buffer.
		 *
		 * @return a new {@link MapAggBuffer} holding an empty map
		 */
		@Override
		public AggregationBuffer getNewAggregationBuffer() throws HiveException {
			// The field initializer already provides a fresh, empty map.
			return new MapAggBuffer();
		}

		/**
		 * Adds every entry of one input map to the buffer. A null row is ignored.
		 *
		 * @param agg        the buffer to update; must be a {@link MapAggBuffer}
		 * @param parameters the row's arguments; {@code parameters[0]} is the map object
		 */
		@Override
		public void iterate(AggregationBuffer agg, Object[] parameters)
				throws HiveException {
			Object mapObj = parameters[0];
			if (mapObj == null) {
				return;
			}
			addAll(inputMapOI.getMap(mapObj), (MapAggBuffer) agg);
		}

		/**
		 * Folds a partial aggregation into the buffer. A null partial is ignored.
		 *
		 * @param agg     the buffer to update; must be a {@link MapAggBuffer}
		 * @param partial a partial result as described by the inspector passed to
		 *                {@link #init} in PARTIAL2/FINAL mode
		 */
		@Override
		public void merge(AggregationBuffer agg, Object partial)
				throws HiveException {
			if (partial == null) {
				return;
			}
			// Work against the Map interface: the inspector is free to return any Map
			// implementation, so casting to HashMap is not safe.
			addAll(internalMergeOI.getMap(partial), (MapAggBuffer) agg);
		}

		/**
		 * Clears the buffer by replacing its map with a new empty one.
		 *
		 * @param buff the buffer to reset; must be a {@link MapAggBuffer}
		 */
		@Override
		public void reset(AggregationBuffer buff) throws HiveException {
			((MapAggBuffer) buff).collectMap = new HashMap<Object, Object>();
		}

		/**
		 * Produces the final result.
		 *
		 * @param agg the buffer to read; must be a {@link MapAggBuffer}
		 * @return a copy of the accumulated map, so later buffer changes do not affect it
		 */
		@Override
		public Object terminate(AggregationBuffer agg) throws HiveException {
			return snapshot((MapAggBuffer) agg);
		}

		/**
		 * Produces the partial result handed to {@link #merge} in a later stage.
		 *
		 * @param agg the buffer to read; must be a {@link MapAggBuffer}
		 * @return a copy of the accumulated map
		 */
		@Override
		public Object terminatePartial(AggregationBuffer agg) throws HiveException {
			return snapshot((MapAggBuffer) agg);
		}

		/** Copies every entry of {@code source} into the buffer; a null map adds nothing. */
		private void addAll(Map<?, ?> source, MapAggBuffer myagg) {
			if (source == null) {
				return;
			}
			for (Map.Entry<?, ?> entry : source.entrySet()) {
				putIntoSet(entry.getKey(), entry.getValue(), myagg);
			}
		}

		/**
		 * Stores one entry in the buffer after copying key and value to standard objects,
		 * so the buffer does not keep references to objects owned by the caller.
		 * An existing value for the same key is overwritten.
		 */
		private void putIntoSet(Object key, Object val, MapAggBuffer myagg) {
			Object keyCopy = ObjectInspectorUtils.copyToStandardObject(key, this.inputKeyOI);
			Object valCopy = ObjectInspectorUtils.copyToStandardObject(val, this.inputValOI);
			myagg.collectMap.put(keyCopy, valCopy);
		}

		/** Returns an independent copy of the buffer's current contents. */
		private static HashMap<Object, Object> snapshot(MapAggBuffer myagg) {
			return new HashMap<Object, Object>(myagg.collectMap);
		}
	}


}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy