package com.aliyun.odps.tunnel.impl;

import static com.aliyun.odps.tunnel.HttpHeaders.HEADER_ODPS_ROUTED_SERVER;

import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.math.BigDecimal;
import java.net.URI;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;

import com.aliyun.odps.Column;
import com.aliyun.odps.OdpsType;
import com.aliyun.odps.commons.transport.HttpStatus;
import com.aliyun.odps.commons.transport.Request;
import com.aliyun.odps.data.Record;
import com.aliyun.odps.tunnel.HttpHeaders;
import com.aliyun.odps.tunnel.TunnelException;
import com.aliyun.odps.tunnel.TunnelTableSchema;
import com.aliyun.odps.tunnel.hasher.DecimalHashObject;
import com.aliyun.odps.tunnel.hasher.TypeHasher;
import com.aliyun.odps.tunnel.io.Checksum;
import com.aliyun.odps.tunnel.io.CompressOption;
import com.aliyun.odps.tunnel.io.ProtobufRecordPack;
import com.aliyun.odps.tunnel.streams.UpsertStream;
import com.aliyun.odps.type.DecimalTypeInfo;
import com.aliyun.odps.type.TypeInfo;
import com.aliyun.odps.utils.FixedNettyChannelPool;

import io.netty.bootstrap.Bootstrap;
import io.netty.buffer.ByteBufInputStream;
import io.netty.buffer.Unpooled;
import io.netty.channel.Channel;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.handler.codec.http.DefaultFullHttpRequest;
import io.netty.handler.codec.http.FullHttpResponse;
import io.netty.handler.codec.http.HttpHeaderNames;
import io.netty.handler.codec.http.HttpMethod;
import io.netty.handler.codec.http.HttpRequest;
import io.netty.handler.codec.http.HttpResponseStatus;
import io.netty.handler.codec.http.HttpVersion;
import io.netty.handler.timeout.ReadTimeoutException;
import io.netty.handler.timeout.ReadTimeoutHandler;

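/**
 * Netty-based implementation of {@link UpsertStream}. Records are hashed on their
 * primary key columns into per-bucket {@link ProtobufRecordPack} buffers, which are
 * flushed to the tunnel endpoint as HTTP PUT requests once the per-slot or total
 * buffer thresholds are exceeded, or when {@link #flush()} / {@link #close()} is called.
 *
 * <p>A minimal usage sketch (assumes {@code session} is an already-initialized
 * {@code UpsertSessionImpl} and {@code record} a matching {@code Record}):
 * <pre>{@code
 * UpsertStream stream = new UpsertStreamImpl.Builder()
 *     .setSession(session)
 *     .build();
 * stream.upsert(record);   // buffer an upsert
 * stream.flush();          // force all buckets out
 * stream.close();          // flushes, then marks the stream CLOSED
 * }</pre>
 */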
public class UpsertStreamImpl implements UpsertStream {
  // required
  private long maxBufferSize;
  private long slotBufferSize;
  private final CompressOption compressOption;
  private final URI endpoint;
  private final UpsertSessionImpl session;

  // meta
  private Map<Integer, Slot> buckets;
  private List<Integer> hashKeys = new ArrayList<>();
  private TunnelTableSchema schema;

  // buffer
  private final Map<Integer, ProtobufRecordPack> bucketBuffer = new HashMap<>();
  private long totalBufferSize = 0;

  // netty
  private final Bootstrap bootstrap;
  private CountDownLatch latch;
  private FixedNettyChannelPool channelPool;
  private long connectTimeout;
  private long readTimeout;

  // status
  private Status status = Status.NORMAL;

  private Listener listener = null;

  private enum Operation {
    UPSERT,
    DELETE
  }

  private enum Status {
    NORMAL,
    ERROR,
    CLOSED
  }

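  /**
   * Builder for {@link UpsertStreamImpl}. Defaults: 64 MiB total buffer,
   * 1 MiB per-slot buffer, ODPS LZ4 frame compression, and no listener.
   */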
  public static class Builder implements UpsertStream.Builder {
    private UpsertSessionImpl session;
    private long maxBufferSize = 64 * 1024 * 1024;
    private long slotBufferSize = 1024 * 1024;
    private CompressOption compressOption =
        new CompressOption(CompressOption.CompressAlgorithm.ODPS_LZ4_FRAME, -1, 0);
    private Listener listener = null;

    public Builder setSession(UpsertSessionImpl session) {
      this.session = session;
      return this;
    }

    public UpsertSessionImpl getSession() {
      return session;
    }

    public long getMaxBufferSize() {
      return maxBufferSize;
    }

    public Builder setMaxBufferSize(long maxBufferSize) {
      this.maxBufferSize = maxBufferSize;
      return this;
    }

    public long getSlotBufferSize() {
      return slotBufferSize;
    }

    public Builder setSlotBufferSize(long slotBufferSize) {
      this.slotBufferSize = slotBufferSize;
      return this;
    }

    @Override
    public CompressOption getCompressOption() {
      return compressOption;
    }

    @Override
    public Builder setCompressOption(CompressOption compressOption) {
      this.compressOption = compressOption;
      return this;
    }

    @Override
    public Listener getListener() {
      return listener;
    }

    @Override
    public Builder setListener(Listener listener) {
      this.listener = listener;
      return this;
    }

    @Override
    public UpsertStream build() throws IOException, TunnelException {
      return new UpsertStreamImpl(this);
    }
  }

  public UpsertStreamImpl(Builder builder) throws IOException, TunnelException {
    this.compressOption = builder.getCompressOption();
    this.slotBufferSize = builder.getSlotBufferSize();
    this.maxBufferSize = builder.getMaxBufferSize();
    this.session = builder.session;
    this.endpoint = session.getEndpoint();
    this.buckets = session.getBuckets();
    this.schema = session.getRecordSchema();
    this.hashKeys = session.getHashKeys();
    this.bootstrap = session.getBootstrap();
    this.channelPool = session.getChannelPool();
    this.connectTimeout = session.getConnectTimeout();
    this.readTimeout = session.getReadTimeout();
    this.listener = builder.getListener();

    newBucketBuffer();
  }

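  /** Allocates an empty ProtobufRecordPack buffer for every bucket in the session's slot map. */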
  private void newBucketBuffer() throws IOException {
    for (Integer slot : this.buckets.keySet()) {
      this.bucketBuffer.put(slot, new ProtobufRecordPack(this.schema, new Checksum(), 0, compressOption));
    }
  }

  @Override
  public void upsert(Record record) throws IOException, TunnelException {
    write(record, UpsertStreamImpl.Operation.UPSERT, null);
  }

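  /**
   * Upserts a record, writing only the columns named in {@code upsertCols}.
   * Partial update must be enabled on the table, and every column name is
   * validated against the session schema before the record is buffered.
   */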
  @Override
  public void upsert(Record record, List<String> upsertCols) throws IOException, TunnelException {
    if (upsertCols != null && !upsertCols.isEmpty() && !session.supportPartialUpdate()) {
      throw new TunnelException(
          "Table " + session.tableName
          + " does not support partial update; consider setting table property 'acid.partial.fields.update.enable=true'");
    }
    if (upsertCols != null && !upsertCols.isEmpty()) {
      Set<String> columnSet =
          schema.getColumns().stream().map(Column::getName).collect(Collectors.toSet());
      upsertCols.forEach((col) -> {
        if (!columnSet.contains(col)) {
          throw new IllegalArgumentException("Invalid column name: " + col);
        }
      });
    }
    write(record, UpsertStreamImpl.Operation.UPSERT, upsertCols);
  }

  @Override
  public void delete(Record record) throws IOException, TunnelException {
    write(record, UpsertStreamImpl.Operation.DELETE, null);
  }

  @Override
  public void flush() throws IOException, TunnelException {
    flush(true);
  }

  @Override
  public void close() throws IOException {
    if (status == Status.NORMAL) {
      try {
        flush();
        status = Status.CLOSED;
      } catch (TunnelException e) {
        throw new IOException(e);
      }
    }
  }

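  /** Discards all buffered records and returns the stream to NORMAL so it can be reused after an error. */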
  @Override
  public void reset() throws IOException {
    if (!bucketBuffer.isEmpty()) {
      for (ProtobufRecordPack pack : bucketBuffer.values()) {
        pack.reset();
      }
    }

    totalBufferSize = 0;
    status = Status.NORMAL;
  }

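  /**
   * Core write path: hashes the record's primary key columns to choose a bucket,
   * tags the record with its operation marker ('U' for upsert, 'D' for delete),
   * and appends it to that bucket's buffer. A partial flush is triggered when one
   * bucket exceeds slotBufferSize; a full flush when the total buffered bytes
   * exceed maxBufferSize.
   */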
  private void write(Record record, UpsertStreamImpl.Operation op, List<String> valueColumns)
      throws TunnelException, IOException {
    checkStatus();

    List<Integer> hashValues = new ArrayList<>();
    for (int key : hashKeys) {
      Object value = record.get(key);
      if (value == null) {
        throw new TunnelException(
            "UpsertRecord must have a primary key value; consider providing a value for column '"
            + schema.getColumn(key).getName() + "'");
      }
      TypeInfo typeInfo = schema.getColumn(key).getTypeInfo();

      // A Java BigDecimal's precision and scale may differ from the column's typeInfo,
      // so wrap it in a DecimalHashObject that hashes with the declared precision and scale.
      if (typeInfo.getOdpsType() == OdpsType.DECIMAL) {
        DecimalTypeInfo decimalTypeInfo = (DecimalTypeInfo) typeInfo;
        value =
            new DecimalHashObject((BigDecimal) value, decimalTypeInfo.getPrecision(),
                                  decimalTypeInfo.getScale());
      }
      hashValues.add(TypeHasher.hash(typeInfo.getOdpsType(), value, session.getHasher()));
    }

    int bucket = TypeHasher.CombineHashVal(hashValues) % buckets.size();

    if (!bucketBuffer.containsKey(bucket)) {
      throw new TunnelException(
          "Tunnel internal error! No bucket buffer exists for hash key " + bucket);
    }

    ProtobufRecordPack pack = bucketBuffer.get(bucket);
    UpsertRecord r = (UpsertRecord) record;
    r.setOperation(op == UpsertStreamImpl.Operation.UPSERT ? (byte)'U' : (byte)'D');
    ArrayList<Integer> valueCols = new ArrayList<>();
    if (valueColumns != null) {
      for (String validColumnName : valueColumns) {
        valueCols.add(this.schema.getColumnId(validColumnName));
      }
    }
    r.setValueCols(valueCols);
    long bytes = pack.getTotalBytes();
    pack.append(r.getRecord());
    bytes = pack.getTotalBytes() - bytes;
    totalBufferSize += bytes;
    if (pack.getTotalBytes() > slotBufferSize) {
      flush(false);
    } else if (totalBufferSize > maxBufferSize) {
      flush(true);
    }
  }
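
  /**
   * Flushes buffered buckets. When {@code flushAll} is set, every non-empty bucket
   * is sent; otherwise only buckets over slotBufferSize are. Each bucket is written
   * as an HTTP PUT over a pooled Netty channel and a latch awaits all responses.
   * Failed buckets are retried until the listener's onFlushFail vetoes the retry;
   * without a listener the first failure is rethrown immediately.
   */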
  private void flush(boolean flushAll) throws TunnelException, IOException {
    List<FlushResultHandler> handlers = new ArrayList<>();
    boolean success;
    int retry = 0;

    // update slot map
    Map<Integer, Slot> bucketMap = session.getBuckets();
    if (bucketMap.size() != buckets.size()) {
      throw new TunnelException("session slot map has changed");
    } else {
      buckets = bucketMap;
    }

    do {
      success = true;
      handlers.clear();
      Channel channel = null;
      try {
        checkStatus();
        latch = new CountDownLatch(bucketBuffer.size());
        for (Map.Entry<Integer, ProtobufRecordPack> entry : bucketBuffer.entrySet()) {
          ProtobufRecordPack pack = entry.getValue();
          if (pack.getSize() > 0) {
            if (pack.getTotalBytes() > slotBufferSize || flushAll) {
              int bucketId = entry.getKey();
              long bytes = pack.getTotalBytes();
              pack.checkTransConsistency(false);
              pack.complete();
              bytes = pack.getTotalBytes() - bytes;
              if (!flushAll) {
                totalBufferSize += bytes;
              }
              Request request = session.buildRequest("PUT", bucketId, buckets.get(bucketId), pack.getTotalBytes(), pack.getSize(), compressOption);
              channel = channelPool.acquire();
              FlushResultHandler handler = new FlushResultHandler(pack, latch, listener, retry, bucketId);
              channel.pipeline().addLast(handler);
              handlers.add(handler);
              ChannelFuture channelFuture =
                  channel.writeAndFlush(buildFullHttpRequest(request, pack.getProtobufStream()));
              channelFuture.addListener((ChannelFutureListener) future -> {
                if (!future.isSuccess()) {
                  latch.countDown();
                  channelPool.release(future.channel());
                  handler.setException(
                      new TunnelException("Connect : " + future.cause().getMessage(),
                                          future.cause()));
                  future.channel().close();
                } else {
                  future.channel().pipeline().addFirst(new ReadTimeoutHandler(readTimeout, TimeUnit.MILLISECONDS));
                }
              });
            } else {
              latch.countDown();
            }
          } else {
            latch.countDown();
          }
        }
        latch.await();
      } catch (InterruptedException e) {
        throw new TunnelException("flush interrupted", e);
      }

      for (FlushResultHandler handler : handlers) {
        if (handler.getException() != null) {
          success = false;
          if (listener != null) {
            if (!listener.onFlushFail(handler.getException(), retry)) {
              status = Status.ERROR;
              TunnelException e = new TunnelException(handler.getException().getErrorMsg(), handler.getException());
              e.setRequestId(handler.getException().getRequestId());
              e.setErrorCode(handler.getException().getErrorCode());
              throw e;
            }
          } else {
            TunnelException e = new TunnelException(handler.getException().getErrorMsg(), handler.getException());
            e.setRequestId(handler.getException().getRequestId());
            e.setErrorCode(handler.getException().getErrorCode());
            throw e;
          }
        } else {
          if (!flushAll) {
            totalBufferSize -= handler.getFlushResult().flushSize;
          }
        }
      }
      ++retry;
    } while (!success);
    if (flushAll) {
      totalBufferSize = 0;
    }
  }

  private void checkStatus() throws TunnelException {
    if (Status.CLOSED == status) {
      throw new TunnelException("Stream is closed!");
    } else if (Status.ERROR == status) {
      throw new TunnelException("Stream has error!");
    }
  }

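  /**
   * Converts an SDK-level {@link Request} into a Netty full HTTP PUT request,
   * rewriting the URI relative to the tunnel endpoint and copying all headers.
   */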
  private HttpRequest buildFullHttpRequest(Request request, ByteArrayOutputStream content) {
    String uri = request.getURI().toString().replace(endpoint.toString(), "");
    HttpRequest req = new DefaultFullHttpRequest(
            HttpVersion.HTTP_1_1, HttpMethod.PUT, uri, Unpooled.wrappedBuffer(content.toByteArray()));
    request.getHeaders().forEach((key, value) -> req.headers().set(key, value));
    req.headers().set(HttpHeaderNames.HOST, request.getURI().getHost());
    return req;
  }

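  /**
   * Per-request inbound handler that records flush metrics, surfaces server-side
   * errors as {@link TunnelException}s, and on HTTP 308 (slot reassignment)
   * refreshes the session's slot map before the next retry.
   */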
  private class FlushResultHandler extends ChannelInboundHandlerAdapter {

    private UpsertStream.FlushResult flushResult = new UpsertStream.FlushResult();
    private ProtobufRecordPack pack;
    private TunnelException exception = null;
    CountDownLatch latch;
    long start;
    Listener listener;
    int retry;
    int bucketId;

    public UpsertStream.FlushResult getFlushResult() {
      return flushResult;
    }

    public TunnelException getException() {
      return exception;
    }

    public void setException(TunnelException exception) {
      this.exception = exception;
    }

    FlushResultHandler(ProtobufRecordPack pack, CountDownLatch latch, Listener listener, int retry, int bucketId) {
      this.flushResult.recordCount = pack.getSize();
      this.pack = pack;
      this.flushResult.flushSize = pack.getTotalBytes();
      this.latch = latch;
      this.start = System.currentTimeMillis();
      this.listener = listener;
      this.retry = retry;
      this.bucketId = bucketId;
    }

    @Override
    public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
      FullHttpResponse response = null;
      try {
        response = (FullHttpResponse) msg;
        this.flushResult.traceId = response.headers().get(HttpHeaders.HEADER_ODPS_REQUEST_ID);
        if (response.status().equals(HttpResponseStatus.OK)) {
          this.flushResult.flushTime = System.currentTimeMillis() - start;
          pack.reset();
          if (listener != null) {
            try {
              listener.onFlush(flushResult);
            } catch (Exception ignore) {
            }
          }
        } else {
          try (ByteBufInputStream contentStream = new ByteBufInputStream(response.content())) {
            exception = new TunnelException(this.flushResult.traceId, contentStream, response.status().code());

            // 308 means should update slot map and retry
            if (response.status().code() == HttpStatus.SLOT_REASSIGNMENT) {
              if (response.headers().contains(HEADER_ODPS_ROUTED_SERVER)) {
                String newSlotServer = response.headers().get(HEADER_ODPS_ROUTED_SERVER);
                session.updateBuckets(bucketId, newSlotServer);
              } else {
                session.updateBuckets(bucketId, null);
              }
              buckets = session.getBuckets();
            }
          }
        }
      } catch (Exception e) {
        exception = new TunnelException(e.getMessage(), e);
      } finally {
        latch.countDown();
        if (response != null) {
          response.release();
        }
        channelPool.release(ctx.channel());
        ctx.close();
      }
    }

    @Override
    public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
      if (cause instanceof ReadTimeoutException) {
        exception = new TunnelException("Flush time out, cannot get response from server");
      } else {
        exception = new TunnelException(cause.getMessage(), cause);
      }
      latch.countDown();
      channelPool.release(ctx.channel());
      ctx.close();
    }
  }
}