All downloads are free. Search and download functionality uses the official Maven repository.

com.aliyun.odps.io.WritableUtils Maven / Gradle / Ivy

The newest version!
/**
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.aliyun.odps.io;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataInput;
import java.io.DataOutput;
import java.io.IOException;
import java.util.zip.GZIPInputStream;
import java.util.zip.GZIPOutputStream;

import com.aliyun.odps.conf.Configuration;
import com.aliyun.odps.utils.ReflectionUtils;

/**
 * {@link Writable} 工具类.
 */
public final class WritableUtils {

  /**
   * 读取压缩的二进制输入流
   *
   * @param in
   *     二进制输入流
   * @return 读取的二进制数据
   * @throws IOException
   */
  public static byte[] readCompressedByteArray(DataInput in) throws IOException {
    int length = in.readInt();
    if (length == -1) {
      return null;
    }
    byte[] buffer = new byte[length];
    in.readFully(buffer); // could/should use readFully(buffer,0,length)?
    GZIPInputStream gzi = new GZIPInputStream(new ByteArrayInputStream(buffer,
                                                                       0, buffer.length));
    byte[] outbuf = new byte[length];
    ByteArrayOutputStream bos = new ByteArrayOutputStream();
    int len;
    while ((len = gzi.read(outbuf, 0, outbuf.length)) != -1) {
      bos.write(outbuf, 0, len);
    }
    byte[] decompressed = bos.toByteArray();
    bos.close();
    gzi.close();
    return decompressed;
  }

  public static void skipCompressedByteArray(DataInput in) throws IOException {
    int length = in.readInt();
    if (length != -1) {
      skipFully(in, length);
    }
  }

  /**
   * 压缩写入byte数组到二进制输出流中
   *
   * @param out
   *     二进制输出流
   * @param s
   *     写入的byte数组
   * @throws IOException
   */
  public static int writeCompressedByteArray(DataOutput out, byte[] bytes)
      throws IOException {
    if (bytes != null) {
      ByteArrayOutputStream bos = new ByteArrayOutputStream();
      GZIPOutputStream gzout = new GZIPOutputStream(bos);
      gzout.write(bytes, 0, bytes.length);
      gzout.close();
      byte[] buffer = bos.toByteArray();
      int len = buffer.length;
      out.writeInt(len);
      out.write(buffer, 0, len);
      /* debug only! Once we have confidence, can lose this. */
      return ((bytes.length != 0) ? (100 * buffer.length) / bytes.length : 0);
    } else {
      out.writeInt(-1);
      return -1;
    }
  }

  /**
   * 从二进制输入流中读取压缩的字符串
   *
   * @param in
   *     二进制输入流
   * @return 读取的字符串
   * @throws IOException
   */
  public static String readCompressedString(DataInput in) throws IOException {
    byte[] bytes = readCompressedByteArray(in);
    if (bytes == null) {
      return null;
    }
    return new String(bytes, "UTF-8");
  }

  /**
   * 压缩写入字符串到二进制输出流中
   *
   * @param out
   *     二进制输出流
   * @param s
   *     写入的字符串
   * @throws IOException
   */
  public static int writeCompressedString(DataOutput out, String s)
      throws IOException {
    return writeCompressedByteArray(out, (s != null) ? s.getBytes("UTF-8")
                                                     : null);
  }

  /**
   * 写入字符串到二进制输出流中
   *
   * @param out
   *     二进制输出流
   * @param s
   *     写入的字符串
   * @throws IOException
   */
  public static void writeString(DataOutput out, String s) throws IOException {
    if (s != null) {
      byte[] buffer = s.getBytes("UTF-8");
      int len = buffer.length;
      out.writeInt(len);
      out.write(buffer, 0, len);
    } else {
      out.writeInt(-1);
    }
  }

  /**
   * 从二进制输入流中读取字符串
   *
   * @param in
   *     二进制输入流
   * @return 读取的字符串
   * @throws IOException
   */
  public static String readString(DataInput in) throws IOException {
    int length = in.readInt();
    if (length == -1) {
      return null;
    }
    byte[] buffer = new byte[length];
    in.readFully(buffer); // could/should use readFully(buffer,0,length)?
    return new String(buffer, "UTF-8");
  }

  /**
   * 写入字符串数组到二进制输出流中
   *
   * @param out
   *     二进制输出流
   * @param s
   *     字符串数组
   * @throws IOException
   */
  public static void writeStringArray(DataOutput out, String[] s)
      throws IOException {
    out.writeInt(s.length);
    for (int i = 0; i < s.length; i++) {
      writeString(out, s[i]);
    }
  }

  /**
   * 压缩写入字符串数组到二进制输出流中
   *
   * @param out
   *     二进制输出流
   * @param s
   *     字符串数组
   * @throws IOException
   */
  public static void writeCompressedStringArray(DataOutput out, String[] s)
      throws IOException {
    if (s == null) {
      out.writeInt(-1);
      return;
    }
    out.writeInt(s.length);
    for (int i = 0; i < s.length; i++) {
      writeCompressedString(out, s[i]);
    }
  }

  /**
   * 从二进制输入流中读取字符串数组
   *
   * @param in
   *     二进制输入流
   * @return 字符串数组
   * @throws IOException
   */
  public static String[] readStringArray(DataInput in) throws IOException {
    int len = in.readInt();
    if (len == -1) {
      return null;
    }
    String[] s = new String[len];
    for (int i = 0; i < len; i++) {
      s[i] = readString(in);
    }
    return s;
  }

  /**
   * 从二进制输入流中读取压缩的字符串数组
   *
   * @param in
   *     二进制输入流
   * @return 字符串数组
   * @throws IOException
   */
  public static String[] readCompressedStringArray(DataInput in)
      throws IOException {
    int len = in.readInt();
    if (len == -1) {
      return null;
    }
    String[] s = new String[len];
    for (int i = 0; i < len; i++) {
      s[i] = readCompressedString(in);
    }
    return s;
  }

  /**
   * 打印出byte数组的字符串形式
   *
   * @param record
   *     byte数组
   */
  public static void displayByteArray(byte[] record) {
    int i;
    for (i = 0; i < record.length - 1; i++) {
      if (i % 16 == 0) {
        System.out.println();
      }
      System.out.print(Integer.toHexString(record[i] >> 4 & 0x0F));
      System.out.print(Integer.toHexString(record[i] & 0x0F));
      System.out.print(",");
    }
    System.out.print(Integer.toHexString(record[i] >> 4 & 0x0F));
    System.out.print(Integer.toHexString(record[i] & 0x0F));
    System.out.println();
  }


  /**
   * vint方式序列化一个int到二进制输出流中
   *
   * @param stream
   *     二进制输出流
   * @param i
   *     int的值
   * @throws java.io.IOException
   */
  public static void writeVInt(DataOutput stream, int i) throws IOException {
    writeVLong(stream, i);
  }

  /**
   * vlong方式序列化一个long到二进制输出流中
   *
   * @param stream
   *     二进制输出流
   * @param i
   *     long的值
   * @throws java.io.IOException
   */
  public static void writeVLong(DataOutput stream, long i) throws IOException {
    if (i >= -112 && i <= 127) {
      stream.writeByte((byte) i);
      return;
    }

    int len = -112;
    if (i < 0) {
      i ^= -1L; // take one's complement'
      len = -120;
    }

    long tmp = i;
    while (tmp != 0) {
      tmp = tmp >> 8;
      len--;
    }

    stream.writeByte((byte) len);

    len = (len < -120) ? -(len + 120) : -(len + 112);

    for (int idx = len; idx != 0; idx--) {
      int shiftbits = (idx - 1) * 8;
      long mask = 0xFFL << shiftbits;
      stream.writeByte((byte) ((i & mask) >> shiftbits));
    }
  }

  /**
   * 从二进制流中读取vlong.
   *
   * @param stream
   *     输入二进制流
   * @return vlong值
   * @throws java.io.IOException
   */
  public static long readVLong(DataInput stream) throws IOException {
    byte firstByte = stream.readByte();
    int len = decodeVIntSize(firstByte);
    if (len == 1) {
      return firstByte;
    }
    long i = 0;
    for (int idx = 0; idx < len - 1; idx++) {
      byte b = stream.readByte();
      i = i << 8;
      i = i | (b & 0xFF);
    }
    return (isNegativeVInt(firstByte) ? (i ^ -1L) : i);
  }

  /**
   * 从二进制流中读取vint.
   *
   * @param stream
   *     输入二进制流
   * @return vint值
   * @throws java.io.IOException
   */
  public static int readVInt(DataInput stream) throws IOException {
    return (int) readVLong(stream);
  }

  /**
   * 解析vint、vlong的首字节判断其是否为负值
   *
   * @param value
   *     vint、vlong的首字节
   * @return 是否为负值
   */
  public static boolean isNegativeVInt(byte value) {
    return value < -120 || (value >= -112 && value < 0);
  }

  /**
   * 解析vint、vlong的首字节(表示其本身的长度)
   *
   * @param value
   *     vint、vlong的首字节
   * @return 字节长度(1 - 9)
   */
  public static int decodeVIntSize(byte value) {
    if (value >= -112) {
      return 1;
    } else if (value < -120) {
      return -119 - value;
    }
    return -111 - value;
  }

  /**
   * 获取变长的int长度
   *
   * @param i
   *     int值
   * @return 变长的int长度
   */
  public static int getVIntSize(long i) {
    if (i >= -112 && i <= 127) {
      return 1;
    }

    if (i < 0) {
      i ^= -1L; // take one's complement'
    }
    // find the number of bytes with non-leading zeros
    int dataBits = Long.SIZE - Long.numberOfLeadingZeros(i);
    // find the number of data bytes + length byte
    return (dataBits + 7) / 8 + 1;
  }

  /**
   * 从输入流中读取枚举值
   *
   * @param 
   *     Enum 枚举类型
   * @param in
   *     输入流
   * @param enumType
   *     枚举类
   * @return Enum 读取枚举值
   * @throws IOException
   */
  public static > T readEnum(DataInput in, Class enumType)
      throws IOException {
    return T.valueOf(enumType, Text.readString(in));
  }

  /**
   * 写入枚举到输出流
   *
   * @param out
   *     输出流
   * @param enumVal
   *     枚举值
   * @throws IOException
   */
  public static void writeEnum(DataOutput out, Enum enumVal)
      throws IOException {
    Text.writeString(out, enumVal.name());
  }

  /**
   * 输入流中跳过指定长度的字节
   *
   * @param in
   *     输入流
   * @param len
   *     指定字节长度
   * @throws IOException
   */
  public static void skipFully(DataInput in, int len) throws IOException {
    int total = 0;
    int cur = 0;

    while ((total < len) && ((cur = in.skipBytes(len - total)) > 0)) {
      total += cur;
    }

    if (total < len) {
      throw new IOException("Not able to skip " + len + " bytes, possibly "
                            + "due to end of input.");
    }
  }

  /**
   * 将Writable对象转化成byte数组
   *
   * @param writables
   *     Writable对象集合
   * @return 转化后二进制数据
   */
  public static byte[] toByteArray(Writable... writables) {
    final DataOutputBuffer out = new DataOutputBuffer();
    try {
      for (Writable w : writables) {
        w.write(out);
      }
      out.close();
    } catch (IOException e) {
      throw new RuntimeException("Fail to convert writables to a byte array", e);
    }
    return out.getData();
  }

  /**
   * 通过Writable序列化方式拷贝一个Writable对象.
   *
   * @param orig
   *     拷贝源对象
   * @return 拷贝对象
   * @throws IOException
   */
  public static  T clone(T orig, Configuration conf) {
    if (orig == null) {
      return null;
    }
    DataOutputBuffer out = new DataOutputBuffer();
    DataInputBuffer in = new DataInputBuffer();
    try {
      @SuppressWarnings("unchecked")
      // Unchecked cast from Class to Class
          T newInst = ReflectionUtils.newInstance((Class) orig.getClass(), conf);
      orig.write(out);
      in.reset(out.getData(), out.getLength());
      newInst.readFields(in);
      return newInst;
    } catch (IOException e) {
      throw new RuntimeException("Error writing/reading clone buffer", e);
    } finally {
      try {
        out.close();
        in.close();
      } catch (IOException e) {
        throw new RuntimeException("Error writing/reading clone buffer", e);
      }
    }
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy