All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.apache.hadoop.hive.ql.udf.generic.GenericUDFRound Maven / Gradle / Ivy

There is a newer version: 1.21.1
Show newest version
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.hadoop.hive.ql.udf.generic;

import org.apache.hadoop.hive.common.type.HiveDecimal;
import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.exec.UDFArgumentLengthException;
import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException;
import org.apache.hadoop.hive.ql.exec.vector.VectorizedExpressions;
import org.apache.hadoop.hive.ql.exec.vector.expressions.FuncRoundWithNumDigitsDecimalToDecimal;
import org.apache.hadoop.hive.ql.exec.vector.expressions.RoundWithNumDigitsDoubleToDouble;
import org.apache.hadoop.hive.ql.exec.vector.expressions.gen.FuncRoundDecimalToDecimal;
import org.apache.hadoop.hive.ql.exec.vector.expressions.gen.FuncRoundDoubleToDouble;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.serde2.io.ByteWritable;
import org.apache.hadoop.hive.serde2.io.DoubleWritable;
import org.apache.hadoop.hive.serde2.io.HiveDecimalWritable;
import org.apache.hadoop.hive.serde2.io.ShortWritable;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector.Category;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorConverters;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorConverters.Converter;
import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector.PrimitiveCategory;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.WritableConstantByteObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.WritableConstantIntObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.WritableConstantLongObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.WritableConstantShortObjectInspector;
import org.apache.hadoop.hive.serde2.typeinfo.DecimalTypeInfo;
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoFactory;
import org.apache.hadoop.io.FloatWritable;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.LongWritable;

/**
 * GenericUDFRound: implements {@code ROUND(x[, d])}, rounding {@code x} to {@code d} decimal
 * places ({@code d} defaults to 0). Decimal inputs are rounded with
 * {@code HiveDecimal.ROUND_HALF_UP}; other numeric and string-group inputs are delegated to
 * {@code RoundUtils} through the overridable {@code round(...)} helpers below.
 *
 * Note: rounding function permits rounding off integer digits in decimal numbers, which essentially
 * downgrades the scale to negative territory. However, Hive decimal only allows non-negative
 * scales. This can cause confusion, for example, when a decimal number 1234.567 of type
 * decimal(7,3) is rounded with -2 scale, which produces a decimal number, 1200. The type of the
 * output is decimal(5,0), which does not exactly represent what the number means. Thus, users
 * should be aware of this, and use negative rounding for decimal with caution. However, Hive is in
 * line with the behavior that MySQL demonstrates.
 *
 * At a certain point, we should probably support negative scale for decimal type.
 */
// NOTE(review): the trailing ' in the example output ("12.3'") looks like a typo; it is left as-is
// because DESCRIBE FUNCTION golden outputs may depend on the exact text — confirm before changing.
@Description(name = "round",
  value = "_FUNC_(x[, d]) - round x to d decimal places",
  extended = "Example:\n"
    + "  > SELECT _FUNC_(12.3456, 1) FROM src LIMIT 1;\n" + "  12.3'")
@VectorizedExpressions({FuncRoundDoubleToDouble.class, RoundWithNumDigitsDoubleToDouble.class,
    FuncRoundWithNumDigitsDecimalToDecimal.class, FuncRoundDecimalToDecimal.class})
public class GenericUDFRound extends GenericUDF {
  private transient PrimitiveObjectInspector inputOI;
  // Rounding scale captured from a constant second argument in initialize() (0 when absent).
  // Unlike the other fields it is not transient — presumably so it survives plan serialization;
  // confirm. It is never written during evaluate(); per-row scales are kept in locals.
  private int scale = 0;

  private transient PrimitiveCategory inputType;
  // Only set for STRING/VARCHAR/CHAR input: converts the input to a DoubleWritable.
  private transient Converter converterFromString;
  // False when the second argument is not a constant, in which case the scale is read per row.
  private transient boolean constantScale = true;
  private transient PrimitiveObjectInspector scaleOI;

  /**
   * Validates the arguments and resolves the output object inspector.
   *
   * The optional second argument must be a primitive of category VOID, BYTE, SHORT, INT or LONG.
   * When it is a constant, its value is captured in {@code scale}; otherwise the scale is read per
   * row in {@link #evaluate}. A constant LONG scale outside the int range is rejected here.
   * Decimal input requires a constant scale because the output decimal type depends on it.
   *
   * @param arguments inspectors for {@code x} and the optional {@code d}
   * @return a DECIMAL inspector sized by {@link #getOutputTypeInfo} for decimal input, a DOUBLE
   *     inspector for string-group input, and an inspector of the input's own category otherwise
   * @throws UDFArgumentException on a wrong argument count, non-primitive or unsupported argument
   *     types, an out-of-range constant scale, or a non-constant scale for decimal input
   */
  @Override
  public ObjectInspector initialize(ObjectInspector[] arguments) throws UDFArgumentException {
    if (arguments.length < 1 || arguments.length > 2) {
      throw new UDFArgumentLengthException(
          "ROUND requires one or two argument, got " + arguments.length);
    }

    if (arguments[0].getCategory() != Category.PRIMITIVE) {
      throw new UDFArgumentTypeException(0,
          "ROUND input only takes primitive types, got " + arguments[0].getTypeName());
    }
    inputOI = (PrimitiveObjectInspector) arguments[0];

    if (arguments.length == 2) {
      if (arguments[1].getCategory() != Category.PRIMITIVE) {
        throw new UDFArgumentTypeException(1,
            "ROUND second argument only takes primitive types, got " + arguments[1].getTypeName());
      }
      scaleOI = (PrimitiveObjectInspector) arguments[1];
      switch (scaleOI.getPrimitiveCategory()) {
      case VOID:
        // A NULL scale: evaluate() returns NULL for every row, so the scale value is irrelevant.
        break;
      case BYTE:
        if (scaleOI instanceof WritableConstantByteObjectInspector) {
          scale = ((WritableConstantByteObjectInspector)scaleOI).getWritableConstantValue().get();
        } else {
          constantScale = false;
        }
        break;
      case SHORT:
        if (scaleOI instanceof WritableConstantShortObjectInspector) {
          scale = ((WritableConstantShortObjectInspector)scaleOI).getWritableConstantValue().get();
        } else {
          constantScale = false;
        }
        break;
      case INT:
        if (scaleOI instanceof WritableConstantIntObjectInspector) {
          scale = ((WritableConstantIntObjectInspector)scaleOI).getWritableConstantValue().get();
        } else {
          constantScale = false;
        }
        break;
      case LONG:
        if (scaleOI instanceof WritableConstantLongObjectInspector) {
          long l = ((WritableConstantLongObjectInspector)scaleOI).getWritableConstantValue().get();
          if (l < Integer.MIN_VALUE || l > Integer.MAX_VALUE) {
            throw new UDFArgumentException(getFuncName().toUpperCase()
                + " scale argument out of allowed range");
          }
          scale = (int)l;
        } else {
          constantScale = false;
        }
        break;
      default:
        throw new UDFArgumentTypeException(1, getFuncName().toUpperCase()
            + " second argument only takes numeric type");
      }
    }

    inputType = inputOI.getPrimitiveCategory();
    ObjectInspector outputOI = null;
    switch (inputType) {
    case DECIMAL:
      DecimalTypeInfo inputTypeInfo = (DecimalTypeInfo) inputOI.getTypeInfo();
      DecimalTypeInfo typeInfo = getOutputTypeInfo(inputTypeInfo, scale);
      outputOI = PrimitiveObjectInspectorFactory.getPrimitiveWritableObjectInspector(typeInfo);
      // The output decimal type is derived from the scale, so it must be known at compile time.
      if (!constantScale) {
        throw new UDFArgumentTypeException(1,getFuncName().toUpperCase() + " scale argument for "
            + "decimal must be constant");
      }
      break;
    case VOID:
    case BYTE:
    case SHORT:
    case INT:
    case LONG:
    case FLOAT:
    case DOUBLE:
      outputOI = PrimitiveObjectInspectorFactory.getPrimitiveWritableObjectInspector(inputType);
      break;
    case STRING:
    case VARCHAR:
    case CHAR:
      outputOI = PrimitiveObjectInspectorFactory.getPrimitiveWritableObjectInspector(PrimitiveCategory.DOUBLE);
      converterFromString = ObjectInspectorConverters.getConverter(inputOI, outputOI);
      break;
    default:
      throw new UDFArgumentTypeException(0,
          "Only numeric or string group data types are allowed for ROUND function. Got "
              + inputType.name());
    }

    return outputOI;
  }

  /**
   * Computes the decimal type produced by rounding a value of {@code inputTypeInfo} to
   * {@code dec} decimal places.
   *
   * The scale becomes {@code dec}, clamped to [0, HiveDecimal.MAX_SCALE]. One extra integer digit
   * is reserved whenever {@code dec} is below the input scale, because rounding may carry (e.g.
   * 9.99 -> 10.0). The precision is capped at HiveDecimal.MAX_PRECISION.
   */
  private static DecimalTypeInfo getOutputTypeInfo(DecimalTypeInfo inputTypeInfo, int dec) {
    int prec = inputTypeInfo.precision();
    int scale = inputTypeInfo.scale();
    int intParts = prec - scale;

    // If we are rounding, we may introduce one more integer digit.
    int newIntParts = dec < scale ? intParts + 1 : intParts;

    int newScale = dec < 0 ? 0 : Math.min(dec, HiveDecimal.MAX_SCALE);

    int newPrec = Math.min(newIntParts + newScale, HiveDecimal.MAX_PRECISION);

    return TypeInfoFactory.getDecimalTypeInfo(newPrec, newScale);
  }

  /**
   * Rounds the first argument to the requested scale.
   *
   * Returns NULL when either argument is NULL. Integral inputs rounded to a non-negative scale
   * are returned unchanged; float, double and string-group inputs are rounded as doubles, and
   * string values that cannot be converted yield NULL.
   *
   * @throws HiveException if a non-constant scale value does not fit in an int
   */
  @Override
  public Object evaluate(DeferredObject[] arguments) throws HiveException {
    if (arguments.length == 2 && (arguments[1] == null || arguments[1].get() == null)) {
      return null;
    }

    if (arguments[0] == null) {
      return null;
    }

    Object input = arguments[0].get();
    if (input == null) {
      return null;
    }

    switch (inputType) {
    case VOID:
      return null;
    case DECIMAL: {
      // The getPrimitiveWritableObject method returns a new writable.
      HiveDecimalWritable decimalWritable =
          (HiveDecimalWritable) inputOI.getPrimitiveWritableObject(input);
      // initialize() guarantees the scale is constant for decimal input.
      // Call the different round flavor.
      return round(decimalWritable, scale);
    }
    case BYTE: {
      ByteWritable byteWritable = (ByteWritable) inputOI.getPrimitiveWritableObject(input);
      int rowScale = resolveScale(arguments);
      // An integral value has no fractional digits, so a non-negative scale is a no-op.
      if (rowScale >= 0) {
        return byteWritable;
      }
      return new ByteWritable((byte) round(byteWritable.get(), rowScale));
    }
    case SHORT: {
      ShortWritable shortWritable = (ShortWritable) inputOI.getPrimitiveWritableObject(input);
      int rowScale = resolveScale(arguments);
      if (rowScale >= 0) {
        return shortWritable;
      }
      return new ShortWritable((short) round(shortWritable.get(), rowScale));
    }
    case INT: {
      IntWritable intWritable = (IntWritable) inputOI.getPrimitiveWritableObject(input);
      int rowScale = resolveScale(arguments);
      if (rowScale >= 0) {
        return intWritable;
      }
      return new IntWritable((int) round(intWritable.get(), rowScale));
    }
    case LONG: {
      LongWritable longWritable = (LongWritable) inputOI.getPrimitiveWritableObject(input);
      int rowScale = resolveScale(arguments);
      if (rowScale >= 0) {
        return longWritable;
      }
      return new LongWritable(round(longWritable.get(), rowScale));
    }
    case FLOAT: {
      float f = ((FloatWritable) inputOI.getPrimitiveWritableObject(input)).get();
      int rowScale = resolveScale(arguments);
      // The float is widened to double, rounded, then narrowed back.
      return new FloatWritable((float) round(f, rowScale));
    }
    case DOUBLE: {
      int rowScale = resolveScale(arguments);
      return round((DoubleWritable) inputOI.getPrimitiveWritableObject(input), rowScale);
    }
    case STRING:
    case VARCHAR:
    case CHAR: {
      DoubleWritable doubleValue = (DoubleWritable) converterFromString.convert(input);
      if (doubleValue == null) {
        return null;
      }
      int rowScale = resolveScale(arguments);
      return round(doubleValue, rowScale);
    }
    default:
      throw new UDFArgumentTypeException(0,
          "Only numeric or string group data types are allowed for ROUND function. Got "
              + inputType.name());
    }
  }

  /**
   * Returns the scale to apply to the current row.
   *
   * For a constant (or absent) scale argument this is the value captured in initialize();
   * otherwise the second argument is read. Callers must already have checked that the second
   * argument is non-null. The value is kept in a local rather than written to {@code scale}, so
   * per-row values never leak into the UDF's state.
   *
   * @throws HiveException if a non-constant scale does not fit in an int, mirroring the range
   *     check initialize() applies to a constant LONG scale (previously the value silently wrapped)
   */
  private int resolveScale(DeferredObject[] arguments) throws HiveException {
    if (constantScale) {
      return scale;
    }
    long rowScale = ((Number) scaleOI.getPrimitiveJavaObject(arguments[1].get())).longValue();
    if (rowScale < Integer.MIN_VALUE || rowScale > Integer.MAX_VALUE) {
      throw new HiveException(getFuncName().toUpperCase()
          + " scale argument out of allowed range");
    }
    return (int) rowScale;
  }

  /**
   * Returns a copy of {@code inputDecWritable} rounded to {@code scale} with
   * {@code HiveDecimal.ROUND_HALF_UP}; the input writable is not modified.
   */
  protected HiveDecimalWritable round(HiveDecimalWritable inputDecWritable, int scale) {
    HiveDecimalWritable result = new HiveDecimalWritable(inputDecWritable);
    result.mutateSetScale(scale, HiveDecimal.ROUND_HALF_UP);
    return result;
  }

  /** Rounds an integral value to {@code scale} digits by delegating to {@code RoundUtils}. */
  protected long round(long input, int scale) {
    return RoundUtils.round(input, scale);
  }

  /** Rounds a double to {@code scale} digits by delegating to {@code RoundUtils}. */
  protected double round(double input, int scale) {
    return RoundUtils.round(input, scale);
  }

  /**
   * Returns a new writable holding {@code input} rounded to {@code scale} digits; NaN and
   * infinite values are copied through unchanged.
   */
  protected DoubleWritable round(DoubleWritable input, int scale) {
    double d = input.get();
    if (Double.isNaN(d) || Double.isInfinite(d)) {
      return new DoubleWritable(d);
    } else {
      return new DoubleWritable(RoundUtils.round(d, scale));
    }
  }

  @Override
  public String getDisplayString(String[] children) {
    return getStandardDisplayString("round", children);
  }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy