All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.apache.hadoop.hive.ql.udf.generic.GenericUDFTrunc Maven / Gradle / Ivy

The newest version!
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.hadoop.hive.ql.udf.generic;

import java.math.BigDecimal;
import java.math.RoundingMode;

import org.apache.hadoop.hive.common.type.Date;
import org.apache.hadoop.hive.common.type.HiveDecimal;
import org.apache.hadoop.hive.common.type.Timestamp;
import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.exec.UDFArgumentLengthException;
import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.serde2.io.ByteWritable;
import org.apache.hadoop.hive.serde2.io.DateWritableV2;
import org.apache.hadoop.hive.serde2.io.DoubleWritable;
import org.apache.hadoop.hive.serde2.io.HiveDecimalWritable;
import org.apache.hadoop.hive.serde2.io.ShortWritable;
import org.apache.hadoop.hive.serde2.io.TimestampWritableV2;
import org.apache.hadoop.hive.serde2.objectinspector.ConstantObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector.Category;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorConverters;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorConverters.Converter;
import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector.PrimitiveCategory;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorConverter.TimestampConverter;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils.PrimitiveGrouping;
import org.apache.hadoop.io.FloatWritable;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.LongWritable;
import org.apache.hadoop.io.Text;

/**
 * GenericUDFTrunc.
 *
 * <p>Implements {@code trunc} for two families of arguments; the family is chosen once, in
 * {@link #initialize(ObjectInspector[])}, from the argument types:
 * <ul>
 * <li>{@code trunc(date, fmt)}: the first argument is a string, date or timestamp and the second
 * a string format. The date is moved to the first day of its month, quarter or year and returned
 * as a string; an unparsable date string or an unsupported format yields NULL.</li>
 * <li>{@code trunc(N [, D])}: both arguments are numeric. N is truncated towards zero to D
 * decimal places (D defaults to 0; a negative D zeroes digits left of the decimal point) and is
 * returned with the type of N.</li>
 * </ul>
 */
@Description(name = "trunc", value = "_FUNC_(date, fmt) / _FUNC_(N,D) - If input is a date, returns the date with the time portion of the day truncated "
    + "to the unit specified by the format model fmt. "
    + "It currently only supports 'MONTH'/'MON'/'MM', 'QUARTER'/'Q' and 'YEAR'/'YYYY'/'YY' as format. "
    + "If input is a number, returns N truncated to D decimal places. If D is omitted, then N is truncated to 0 places. "
    + "D can be negative to truncate (make zero) D digits left of the decimal point."
    , extended = "date is a string in the format 'yyyy-MM-dd HH:mm:ss' or 'yyyy-MM-dd'."
        + " The time part of date is ignored.\n" + "Example:\n "
        + " > SELECT _FUNC_('2009-02-12', 'MM');\n" + "OK\n" + " '2009-02-01'" + "\n"
        + " > SELECT _FUNC_('2017-03-15', 'Q');\n" + "OK\n" + " '2017-01-01'" + "\n"
        + " > SELECT _FUNC_('2015-10-27', 'YEAR');\n" + "OK\n" + " '2015-01-01'" + "\n"
        + " > SELECT _FUNC_(1234567891.1234567891,4);\n" + "OK\n" + " 1234567891.1234" + "\n"
        + " > SELECT _FUNC_(1234567891.1234567891,-4);\n" + "OK\n" + " 1234560000" + "\n"
        + " > SELECT _FUNC_(1234567891.1234567891,0);\n" + "OK\n" + " 1234567891" + "\n"
        + " > SELECT _FUNC_(1234567891.1234567891);\n" + "OK\n" + " 1234567891")
public class GenericUDFTrunc extends GenericUDF {

  private transient TimestampConverter timestampConverter;
  private transient Converter textConverter1;
  private transient Converter textConverter2;
  private transient Converter dateWritableConverter;
  private transient Converter byteConverter;
  private transient Converter shortConverter;
  private transient Converter intConverter;
  private transient Converter longConverter;
  private transient PrimitiveCategory inputType1;
  private transient PrimitiveCategory inputType2;
  // Scratch date mutated by evalDate() and rendered into output; reused across rows.
  private final Date date = new Date();
  // Reused result holder for the date variant.
  private final Text output = new Text();
  private transient String fmtInput;
  private transient PrimitiveObjectInspector inputOI;
  private transient PrimitiveObjectInspector inputScaleOI;
  // Scale taken from a constant second argument (0 when the argument is omitted).
  private int scale = 0;
  private boolean inputScaleConst;
  private boolean dateTypeArg;

  /**
   * Validates the arguments and selects the date or the numeric variant.
   *
   * @param arguments object inspectors of the one or two call arguments
   * @return a writable string inspector for the date variant, otherwise the writable inspector
   *         of the first argument's primitive category
   * @throws UDFArgumentException if the argument count is not one or two, an argument is not a
   *         primitive, or the type combination is not supported
   */
  @Override
  public ObjectInspector initialize(ObjectInspector[] arguments) throws UDFArgumentException {
    if (arguments.length != 1 && arguments.length != 2) {
      throw new UDFArgumentException(
          "TRUNC requires one or two argument, got " + arguments.length);
    }

    // Reject complex types before any cast to PrimitiveObjectInspector; otherwise the caller
    // would see a ClassCastException instead of a proper argument error.
    for (int i = 0; i < arguments.length; i++) {
      if (arguments[i].getCategory() != Category.PRIMITIVE) {
        throw new UDFArgumentTypeException(i,
            "TRUNC only takes primitive types, got " + arguments[i].getTypeName());
      }
    }

    inputType1 = ((PrimitiveObjectInspector) arguments[0]).getPrimitiveCategory();
    PrimitiveGrouping group1 = PrimitiveObjectInspectorUtils.getPrimitiveGrouping(inputType1);

    if (arguments.length == 1) {
      if (group1 != PrimitiveGrouping.NUMERIC_GROUP) {
        throw new UDFArgumentException(
            "Only numeric type arguments are accepted, when arguments length is one, got "
                + arguments[0].getTypeName());
      }
      dateTypeArg = false;
      return initializeNumber(arguments);
    }

    inputType2 = ((PrimitiveObjectInspector) arguments[1]).getPrimitiveCategory();
    PrimitiveGrouping group2 = PrimitiveObjectInspectorUtils.getPrimitiveGrouping(inputType2);

    boolean dateLikeFirst =
        group1 == PrimitiveGrouping.DATE_GROUP || group1 == PrimitiveGrouping.STRING_GROUP;
    if (dateLikeFirst && group2 == PrimitiveGrouping.STRING_GROUP) {
      dateTypeArg = true;
      return initializeDate(arguments);
    }
    if (group1 == PrimitiveGrouping.NUMERIC_GROUP && group2 == PrimitiveGrouping.NUMERIC_GROUP) {
      dateTypeArg = false;
      return initializeNumber(arguments);
    }
    throw new UDFArgumentException("Got wrong argument types : first argument type : "
        + arguments[0].getTypeName() + ", second argument type : " + arguments[1].getTypeName());
  }

  /**
   * Prepares the numeric variant. The caller has already verified that there are one or two
   * arguments and that all of them are primitives.
   *
   * @param arguments object inspectors of N and, optionally, D
   * @return the writable object inspector matching the primitive category of N
   * @throws UDFArgumentException if D is a constant that is not an int, D is a non-constant of a
   *         non-integer type, or N has an unsupported type
   */
  private ObjectInspector initializeNumber(ObjectInspector[] arguments)
      throws UDFArgumentException {
    inputOI = (PrimitiveObjectInspector) arguments[0];

    // Start from a clean state so a re-initialised instance never inherits an earlier scale.
    scale = 0;
    inputScaleConst = false;
    inputScaleOI = null;

    if (arguments.length == 2) {
      inputScaleOI = (PrimitiveObjectInspector) arguments[1];
      inputScaleConst = arguments[1] instanceof ConstantObjectInspector;
      if (inputScaleConst) {
        Object constant = ((ConstantObjectInspector) arguments[1]).getWritableConstantValue();
        String scaleText = constant != null ? constant.toString() : null;
        try {
          // parseInt(null) also raises NumberFormatException, so a NULL constant is rejected.
          scale = Integer.parseInt(scaleText);
        } catch (NumberFormatException e) {
          throw new UDFArgumentException(
              "TRUNC input only takes integer values, got " + scaleText);
        }
      } else {
        switch (inputScaleOI.getPrimitiveCategory()) {
        case BYTE:
          byteConverter = ObjectInspectorConverters.getConverter(arguments[1],
              PrimitiveObjectInspectorFactory.writableByteObjectInspector);
          break;
        case SHORT:
          shortConverter = ObjectInspectorConverters.getConverter(arguments[1],
              PrimitiveObjectInspectorFactory.writableShortObjectInspector);
          break;
        case INT:
          intConverter = ObjectInspectorConverters.getConverter(arguments[1],
              PrimitiveObjectInspectorFactory.writableIntObjectInspector);
          break;
        case LONG:
          longConverter = ObjectInspectorConverters.getConverter(arguments[1],
              PrimitiveObjectInspectorFactory.writableLongObjectInspector);
          break;
        default:
          throw new UDFArgumentTypeException(1,
              getFuncName().toUpperCase() + " second argument only takes integer values");
        }
      }
    }

    inputType1 = inputOI.getPrimitiveCategory();
    switch (inputType1) {
    case VOID:
    case BYTE:
    case SHORT:
    case INT:
    case LONG:
    case FLOAT:
    case DOUBLE:
    case DECIMAL:
      return PrimitiveObjectInspectorFactory.getPrimitiveWritableObjectInspector(inputType1);
    default:
      throw new UDFArgumentTypeException(0,
          "Only numeric or string group data types are allowed for TRUNC function. Got "
              + inputType1.name());
    }
  }

  /**
   * Prepares the date variant. The caller has already verified that there are exactly two
   * primitive arguments.
   *
   * @param arguments object inspectors of the date and the format
   * @return the writable string object inspector
   * @throws UDFArgumentTypeException if either argument has an unsupported type
   */
  private ObjectInspector initializeDate(ObjectInspector[] arguments)
      throws UDFArgumentTypeException {
    inputType1 = ((PrimitiveObjectInspector) arguments[0]).getPrimitiveCategory();
    switch (inputType1) {
    case STRING:
    case VARCHAR:
    case CHAR:
    case VOID:
      // All string-like inputs are handled through one text converter.
      inputType1 = PrimitiveCategory.STRING;
      textConverter1 = ObjectInspectorConverters.getConverter(arguments[0],
          PrimitiveObjectInspectorFactory.writableStringObjectInspector);
      break;
    case TIMESTAMP:
      timestampConverter = new TimestampConverter((PrimitiveObjectInspector) arguments[0],
          PrimitiveObjectInspectorFactory.writableTimestampObjectInspector);
      break;
    case DATE:
      dateWritableConverter = ObjectInspectorConverters.getConverter(arguments[0],
          PrimitiveObjectInspectorFactory.writableDateObjectInspector);
      break;
    default:
      throw new UDFArgumentTypeException(0,
          "TRUNC() only takes STRING/TIMESTAMP/DATEWRITABLE types as first argument, got "
              + inputType1);
    }

    inputType2 = ((PrimitiveObjectInspector) arguments[1]).getPrimitiveCategory();
    PrimitiveGrouping fmtGroup = PrimitiveObjectInspectorUtils.getPrimitiveGrouping(inputType2);
    if (fmtGroup != PrimitiveGrouping.STRING_GROUP && fmtGroup != PrimitiveGrouping.VOID_GROUP) {
      throw new UDFArgumentTypeException(1,
          "trunc() only takes STRING/CHAR/VARCHAR types as second argument, got " + inputType2);
    }
    inputType2 = PrimitiveCategory.STRING;

    // Drop state from any earlier initialisation so the two format sources cannot mix.
    fmtInput = null;
    textConverter2 = null;
    if (arguments[1] instanceof ConstantObjectInspector) {
      // A constant format is read once here; no per-row conversion is needed.
      Object constant = ((ConstantObjectInspector) arguments[1]).getWritableConstantValue();
      fmtInput = constant != null ? constant.toString() : null;
    } else {
      textConverter2 = ObjectInspectorConverters.getConverter(arguments[1],
          PrimitiveObjectInspectorFactory.writableStringObjectInspector);
    }
    return PrimitiveObjectInspectorFactory.writableStringObjectInspector;
  }

  /**
   * Evaluates one row using the variant selected by {@link #initialize(ObjectInspector[])}.
   */
  @Override
  public Object evaluate(DeferredObject[] arguments) throws HiveException {
    return dateTypeArg ? evaluateDate(arguments) : evaluateNumber(arguments);
  }

  /**
   * Truncates a date, timestamp or date string to the start of its month, quarter or year.
   *
   * @return the truncated date as text, or null when an argument is null, the date string cannot
   *         be parsed, or the format is not supported
   */
  private Object evaluateDate(DeferredObject[] arguments) throws UDFArgumentLengthException,
      HiveException, UDFArgumentTypeException, UDFArgumentException {
    if (arguments.length != 2) {
      throw new UDFArgumentLengthException("trunc() requires 2 argument, got " + arguments.length);
    }

    // Read each deferred argument once.
    Object dateArg = arguments[0].get();
    Object fmtArg = arguments[1].get();
    if (dateArg == null || fmtArg == null) {
      return null;
    }

    if (textConverter2 != null) {
      fmtInput = textConverter2.convert(fmtArg).toString();
    }

    Date d;
    switch (inputType1) {
    case STRING:
      String dateString = textConverter1.convert(dateArg).toString();
      try {
        d = Date.valueOf(dateString);
      } catch (IllegalArgumentException e) {
        // An unparsable date yields NULL rather than failing the query.
        return null;
      }
      break;
    case TIMESTAMP:
      Timestamp ts = ((TimestampWritableV2) timestampConverter.convert(dateArg)).getTimestamp();
      d = Date.ofEpochMilli(ts.toEpochMilli());
      break;
    case DATE:
      DateWritableV2 dw = (DateWritableV2) dateWritableConverter.convert(dateArg);
      d = dw.get();
      break;
    default:
      throw new UDFArgumentTypeException(0,
          "TRUNC() only takes STRING/TIMESTAMP/DATEWRITABLE types, got " + inputType1);
    }

    if (evalDate(d) == null) {
      return null;
    }

    output.set(date.toString());
    return output;
  }

  /**
   * Truncates a number to the requested number of decimal places.
   *
   * @return the truncated value in a writable of the input's type, or null when the input or a
   *         non-constant scale is null
   */
  private Object evaluateNumber(DeferredObject[] arguments)
      throws HiveException, UDFArgumentTypeException {
    if (arguments[0] == null) {
      return null;
    }

    Object input = arguments[0].get();
    if (input == null) {
      return null;
    }

    // Resolve the scale per row into a local so a NULL scale can never silently reuse the value
    // read for a previous row.
    int effectiveScale = scale;
    if (arguments.length == 2 && !inputScaleConst) {
      Object scaleArg = arguments[1] != null ? arguments[1].get() : null;
      if (scaleArg == null) {
        return null;
      }
      effectiveScale = readScale(scaleArg);
    }

    switch (inputType1) {
    case VOID:
      return null;
    case DECIMAL:
      HiveDecimalWritable decimalWritable =
          (HiveDecimalWritable) inputOI.getPrimitiveWritableObject(input);
      HiveDecimal dec = trunc(decimalWritable.getHiveDecimal(), effectiveScale);
      if (dec == null) {
        return null;
      }
      return new HiveDecimalWritable(dec);
    case BYTE:
      // Integer types have no fractional digits, so only a negative scale changes the value.
      ByteWritable byteWritable = (ByteWritable) inputOI.getPrimitiveWritableObject(input);
      if (effectiveScale >= 0) {
        return byteWritable;
      }
      return new ByteWritable((byte) trunc(byteWritable.get(), effectiveScale));
    case SHORT:
      ShortWritable shortWritable = (ShortWritable) inputOI.getPrimitiveWritableObject(input);
      if (effectiveScale >= 0) {
        return shortWritable;
      }
      return new ShortWritable((short) trunc(shortWritable.get(), effectiveScale));
    case INT:
      IntWritable intWritable = (IntWritable) inputOI.getPrimitiveWritableObject(input);
      if (effectiveScale >= 0) {
        return intWritable;
      }
      return new IntWritable((int) trunc(intWritable.get(), effectiveScale));
    case LONG:
      LongWritable longWritable = (LongWritable) inputOI.getPrimitiveWritableObject(input);
      if (effectiveScale >= 0) {
        return longWritable;
      }
      return new LongWritable(trunc(longWritable.get(), effectiveScale));
    case FLOAT:
      float f = ((FloatWritable) inputOI.getPrimitiveWritableObject(input)).get();
      return new FloatWritable((float) trunc(f, effectiveScale));
    case DOUBLE:
      return trunc((DoubleWritable) inputOI.getPrimitiveWritableObject(input), effectiveScale);
    default:
      throw new UDFArgumentTypeException(0,
          "Only numeric or string group data types are allowed for TRUNC function. Got "
              + inputType1.name());
    }
  }

  /**
   * Converts a non-null, non-constant scale argument to an int using the converter prepared for
   * its primitive category.
   *
   * @throws UDFArgumentException if a BIGINT scale does not fit in an int
   */
  private int readScale(Object scaleArg) throws UDFArgumentException {
    switch (inputScaleOI.getPrimitiveCategory()) {
    case BYTE:
      return ((ByteWritable) byteConverter.convert(scaleArg)).get();
    case SHORT:
      return ((ShortWritable) shortConverter.convert(scaleArg)).get();
    case INT:
      return ((IntWritable) intConverter.convert(scaleArg)).get();
    case LONG:
      long l = ((LongWritable) longConverter.convert(scaleArg)).get();
      if (l < Integer.MIN_VALUE || l > Integer.MAX_VALUE) {
        throw new UDFArgumentException(
            getFuncName().toUpperCase() + " scale argument out of allowed range");
      }
      return (int) l;
    default:
      // initializeNumber() rejects every other category for a non-constant scale.
      throw new UDFArgumentTypeException(1,
          getFuncName().toUpperCase() + " second argument only takes integer values");
    }
  }

  @Override
  public String getDisplayString(String[] children) {
    return getStandardDisplayString("trunc", children);
  }

  /**
   * Copies {@code d} into the shared scratch date and moves it to the first day of the unit
   * named by {@code fmtInput}.
   *
   * @return the scratch date, or null when the format is not one of the supported names
   */
  private Date evalDate(Date d) {
    date.setTimeInDays(d.toEpochDay());
    if ("MONTH".equals(fmtInput) || "MON".equals(fmtInput) || "MM".equals(fmtInput)) {
      date.setDayOfMonth(1);
      return date;
    } else if ("QUARTER".equals(fmtInput) || "Q".equals(fmtInput)) {
      // Map the month to the first month of its quarter: 1, 4, 7 or 10.
      int month = date.getMonth() - 1;
      int quarter = month / 3;
      int monthToSet = quarter * 3 + 1;
      date.setMonth(monthToSet);
      date.setDayOfMonth(1);
      return date;
    } else if ("YEAR".equals(fmtInput) || "YYYY".equals(fmtInput) || "YY".equals(fmtInput)) {
      date.setMonth(1);
      date.setDayOfMonth(1);
      return date;
    } else {
      return null;
    }
  }

  /**
   * Truncates a Hive decimal.
   *
   * @return the truncated decimal, or null when {@code input} is null or the result cannot be
   *         represented as a HiveDecimal
   */
  protected HiveDecimal trunc(HiveDecimal input, int scale) {
    if (input == null) {
      return null;
    }
    return HiveDecimal.create(trunc(input.bigDecimalValue(), scale));
  }

  /**
   * Truncates an integer value; only a negative {@code scale} changes it.
   */
  protected long trunc(long input, int scale) {
    return trunc(BigDecimal.valueOf(input), scale).longValue();
  }

  /**
   * Truncates a double based on its shortest decimal representation. NaN and infinite values
   * have no decimal expansion and are returned unchanged.
   */
  protected double trunc(double input, int scale) {
    if (Double.isNaN(input) || Double.isInfinite(input)) {
      return input;
    }
    // BigDecimal.valueOf uses Double.toString, so 0.3 stays 0.3 instead of 0.29999...
    return trunc(BigDecimal.valueOf(input), scale).doubleValue();
  }

  /**
   * Truncates the value held by {@code input} with the same rules as
   * {@link #trunc(double, int)} and wraps the result in a new writable.
   */
  protected DoubleWritable trunc(DoubleWritable input, int scale) {
    return new DoubleWritable(trunc(input.get(), scale));
  }

  /**
   * Truncates {@code input} towards zero at the decimal position given by {@code scale}.
   *
   * @param input the value to truncate
   * @param scale number of fractional digits to keep; a negative value zeroes that many digits
   *        to the left of the decimal point
   * @return the truncated value, never in a negative-scale (exponent) representation
   */
  protected BigDecimal trunc(BigDecimal input, int scale) {
    if (scale >= input.scale()) {
      // No digit lies to the right of the requested position, so nothing is cut off. This also
      // avoids padding with a huge number of zeros for very large scales.
      return input;
    }
    if (scale < 0) {
      long integerDigits = (long) input.precision() - input.scale();
      if (integerDigits <= -(long) scale) {
        // Every digit of the value would be zeroed.
        return BigDecimal.ZERO;
      }
    }
    BigDecimal truncated = input.setScale(scale, RoundingMode.DOWN);
    // Expand a negative scale (e.g. 1.23456E+9) back to a plain integer representation.
    return truncated.scale() < 0 ? truncated.setScale(0) : truncated;
  }

}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy