All Downloads are FREE. Search and download functionality uses the official Maven repository.

com.impossibl.postgres.protocol.v30.BindExecCommandImpl Maven / Gradle / Ivy

There is a newer version: 0.8.9
Show newest version
/**
 * Copyright (c) 2013, impossibl.com
 * All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 *  * Redistributions of source code must retain the above copyright notice,
 *    this list of conditions and the following disclaimer.
 *  * Redistributions in binary form must reproduce the above copyright
 *    notice, this list of conditions and the following disclaimer in the
 *    documentation and/or other materials provided with the distribution.
 *  * Neither the name of impossibl.com nor the names of its contributors may
 *    be used to endorse or promote products derived from this software
 *    without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
 * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
 * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
 * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
 * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
 * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
 * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
 * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
 * POSSIBILITY OF SUCH DAMAGE.
 */
package com.impossibl.postgres.protocol.v30;

import com.impossibl.postgres.mapper.Mapper;
import com.impossibl.postgres.mapper.PropertySetter;
import com.impossibl.postgres.protocol.BindExecCommand;
import com.impossibl.postgres.protocol.Notice;
import com.impossibl.postgres.protocol.ResultField;
import com.impossibl.postgres.protocol.ResultField.Format;
import com.impossibl.postgres.protocol.TransactionStatus;
import com.impossibl.postgres.system.Context;
import com.impossibl.postgres.system.SettingsContext;
import com.impossibl.postgres.types.Type;
import com.impossibl.postgres.utils.StreamingByteBuf;
import com.impossibl.postgres.utils.guava.ByteStreams;

import static com.impossibl.postgres.protocol.ServerObjectType.Portal;
import static com.impossibl.postgres.system.Settings.FIELD_VARYING_LENGTH_MAX;
import static com.impossibl.postgres.system.Settings.PARAMETER_STREAM_THRESHOLD;
import static com.impossibl.postgres.system.Settings.PARAMETER_STREAM_THRESHOLD_DEFAULT;
import static com.impossibl.postgres.utils.Factory.createInstance;

import java.io.IOException;
import java.io.InputStream;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

import static java.util.Arrays.asList;

import io.netty.buffer.ByteBuf;

/**
 * Command that binds parameter values to a named statement (creating a
 * portal), executes that portal and gathers the rows it returns into a single
 * {@code ResultBatch}.
 *
 * <p>An instance can be executed repeatedly: when the previous execution left
 * the command in {@code Status.Suspended}, {@link #execute(ProtocolImpl)}
 * skips the Bind step and only issues another Execute against the same portal.
 *
 * <p>The fields {@code error} and {@code exception} and the helpers
 * {@code setException}, {@code addNotice}, {@code enableCancelTimer} and
 * {@code waitFor} used below are not declared in this class; they are
 * presumably supplied by {@code CommandImpl}.
 */
public class BindExecCommandImpl extends CommandImpl implements BindExecCommand {

  /** Initial capacity of the buffer that carries the Bind/Describe/Execute messages. */
  private static final int DEFAULT_MESSAGE_SIZE = 8192;

  /** Size handed to the {@link StreamingByteBuf} used when the Bind message is streamed. */
  private static final int STREAM_MESSAGE_SIZE = 32 * 1024;

  /**
   * Protocol listener that records the server's replies to this command in the
   * enclosing instance (status, result fields, rows, error/exception).
   *
   * <p>Callbacks that end or suspend the exchange are {@code synchronized} and
   * call {@code notifyAll()} on the listener.
   * NOTE(review): presumably {@code waitFor(listener)} waits on this monitor - confirm.
   */
  class BindExecCommandListener extends BaseProtocolListener {

    /** Context passed to the field decoders in {@link #rowData(ByteBuf)}. */
    Context context;

    public BindExecCommandListener(Context context) {
      this.context = context;
    }

    /**
     * @return {@code true} once a status has been recorded or an error or
     *         exception has been reported
     */
    @Override
    public boolean isComplete() {
      return status != null || error != null || exception != null;
    }

    @Override
    public void bindComplete() {
      // Nothing to record for a successful bind.
    }

    /**
     * Installs the result fields described by the server and rebuilds
     * everything derived from them (formats, row list, property setters).
     */
    @Override
    public void rowDescription(List<ResultField> newResultFields) {
      resultFields = newResultFields;
      resultFieldFormats = getResultFieldFormats(newResultFields);
      resultBatch.fields = newResultFields;
      // A row list is only allocated when there is at least one field.
      resultBatch.results = !resultFields.isEmpty() ? new ArrayList<>() : null;
      resultSetters = Mapper.buildMapping(rowType, newResultFields);
    }

    /** Marks the command completed with no fields and no row list. */
    @Override
    public void noData() {
      resultBatch.fields = Collections.emptyList();
      resultBatch.results = null;
      status = Status.Completed;
    }

    /**
     * Decodes one row from {@code buffer} into a new instance of the row type
     * and appends it to the current batch. The buffer is always released,
     * even when decoding fails.
     */
    @Override
    public void rowData(ByteBuf buffer) throws IOException {
      try {
        int itemCount = buffer.readShort();

        Object rowInstance = createInstance(rowType, itemCount);

        for (int c = 0; c < itemCount; ++c) {

          ResultField field = resultBatch.fields.get(c);

          Type fieldType = field.typeRef.get();

          Type.Codec.Decoder decoder = fieldType.getCodec(field.format).decoder;

          Object fieldVal = decoder.decode(fieldType, field.typeLength, field.typeModifier, buffer, context);

          resultSetters.get(c).set(rowInstance, fieldVal);
        }

        @SuppressWarnings("unchecked")
        List<Object> res = (List<Object>) resultBatch.results;
        res.add(rowInstance);
      }
      finally {
        buffer.release();
      }
    }

    /** Marks the command completed with no fields and no row list. */
    @Override
    public void emptyQuery() {
      resultBatch.fields = Collections.emptyList();
      resultBatch.results = null;
      status = Status.Completed;
    }

    /** Records that the portal was suspended and wakes any waiter. */
    @Override
    public synchronized void portalSuspended() {
      status = Status.Suspended;
      notifyAll();
    }

    /**
     * Records the completion details in the current batch. A waiter is only
     * woken here when a row limit is in effect.
     * NOTE(review): with no row limit a Sync is always sent by execute(), so
     * the wake-up presumably comes from ready() - confirm.
     */
    @Override
    public synchronized void commandComplete(String command, Long rowsAffected, Long oid) {
      status = Status.Completed;
      resultBatch.command = command;
      resultBatch.rowsAffected = rowsAffected;
      resultBatch.insertedOid = oid;

      if (maxRows > 0) {
        notifyAll();
      }
    }

    /** Stores the server error on the enclosing command and wakes any waiter. */
    @Override
    public synchronized void error(Notice error) {
      BindExecCommandImpl.this.error = error;
      notifyAll();
    }

    /** Stores the failure on the enclosing command and wakes any waiter. */
    @Override
    public synchronized void exception(Throwable cause) {
      setException(cause);
      notifyAll();
    }

    @Override
    public void notice(Notice notice) {
      addNotice(notice);
    }

    /** Wakes any waiter; the transaction status itself is not recorded here. */
    @Override
    public synchronized void ready(TransactionStatus txStatus) {
      notifyAll();
    }

  }


  private String statementName;
  private String portalName;
  private List<Type> parameterTypes;
  private List<Object> parameterValues;
  /** Known result fields; {@code null} until supplied or described by the server. */
  private List<ResultField> resultFields;
  private Class<?> rowType;
  private List<PropertySetter> resultSetters;
  /** Row limit passed to Execute; values {@code <= 0} are treated as "no limit" by this class. */
  private int maxRows;
  private int maxFieldLength;
  /** Outcome of the most recent execution; {@code null} while none has been recorded. */
  private Status status;
  private SettingsContext parsingContext;
  private ResultBatch resultBatch;
  private List<Format> resultFieldFormats;
  private long queryTimeout;


  /**
   * Creates a command for the given portal and statement.
   *
   * @param portalName      name of the portal to bind and execute
   * @param statementName   name of the statement to bind
   * @param parameterTypes  types of the bound parameters
   * @param parameterValues values of the bound parameters
   * @param resultFields    result fields if already known, or {@code null} to
   *                        have {@link #execute(ProtocolImpl)} request a
   *                        description of the portal
   * @param rowType         class instantiated for every returned row
   */
  public BindExecCommandImpl(String portalName, String statementName, List<Type> parameterTypes, List<Object> parameterValues, List<ResultField> resultFields, Class<?> rowType) {

    this.statementName = statementName;
    this.portalName = portalName;
    this.parameterTypes = parameterTypes;
    this.parameterValues = parameterValues;
    this.resultFields = resultFields;
    this.rowType = rowType;
    this.maxRows = 0;
    this.maxFieldLength = Integer.MAX_VALUE;

    if (resultFields != null) {
      this.resultSetters = Mapper.buildMapping(rowType, resultFields);
      this.resultFieldFormats = getResultFieldFormats(resultFields);
    }
    else {
      this.resultSetters = Collections.emptyList();
      this.resultFieldFormats = Collections.emptyList();
    }

  }

  /**
   * Clears the status and starts a fresh result batch seeded with the
   * currently known result fields. A row list is only allocated when at least
   * one field is known.
   */
  public void reset() {
    status = null;
    resultBatch = new ResultBatch();
    resultBatch.fields = resultFields;
    resultBatch.results = (resultFields != null && !resultFields.isEmpty()) ? new ArrayList<>() : null;
  }

  @Override
  public long getQueryTimeout() {
    return queryTimeout;
  }

  @Override
  public void setQueryTimeout(long queryTimeout) {
    this.queryTimeout = queryTimeout;
  }

  @Override
  public String getStatementName() {
    return statementName;
  }

  @Override
  public String getPortalName() {
    return portalName;
  }

  @Override
  public Status getStatus() {
    return status;
  }

  @Override
  public List<Type> getParameterTypes() {
    return parameterTypes;
  }

  @Override
  public void setParameterTypes(List<Type> parameterTypes) {
    this.parameterTypes = parameterTypes;
  }

  @Override
  public List<Object> getParameterValues() {
    return parameterValues;
  }

  @Override
  public void setParameterValues(List<Object> parameterValues) {
    this.parameterValues = parameterValues;
  }

  @Override
  public int getMaxRows() {
    return maxRows;
  }

  @Override
  public void setMaxRows(int maxRows) {
    this.maxRows = maxRows;
  }

  @Override
  public int getMaxFieldLength() {
    return maxFieldLength;
  }

  @Override
  public void setMaxFieldLength(int maxFieldLength) {
    this.maxFieldLength = maxFieldLength;
  }

  /**
   * @return a fixed-size list holding the single batch of the most recent
   *         execution
   */
  @Override
  public List<ResultBatch> getResultBatches() {
    return asList(resultBatch);
  }

  /**
   * Sends Bind (unless resuming a suspended portal), an optional Describe,
   * Execute and a trailing Flush or Sync, then waits on the listener.
   *
   * @param protocol protocol instance used to write and send the messages
   * @throws IOException if writing any of the messages fails
   */
  @Override
  public void execute(ProtocolImpl protocol) throws IOException {

    // Setup context for parsing fields with customized parameters
    //
    parsingContext = new SettingsContext(protocol.getContext());
    parsingContext.setSetting(FIELD_VARYING_LENGTH_MAX, maxFieldLength);

    BindExecCommandListener listener = new BindExecCommandListener(parsingContext);

    protocol.setListener(listener);

    ByteBuf msg = protocol.channel.alloc().buffer(DEFAULT_MESSAGE_SIZE);

    // Until the buffer is handed to protocol.send() it is owned here, so it
    // must be released if composing the messages fails.
    boolean composed = false;
    try {

      // The status examined here is the one left by the previous execution;
      // it must be read before reset() clears it.
      if (status != Status.Suspended) {

        if (shouldStreamBind(parsingContext, parameterValues)) {

          // Large stream parameters: write Bind through its own streaming
          // buffer instead of accumulating it in msg.
          StreamingByteBuf bindMsg = new StreamingByteBuf(protocol.channel, STREAM_MESSAGE_SIZE);

          protocol.writeBind(bindMsg, portalName, statementName, parameterTypes, parameterValues, resultFieldFormats, true);

          bindMsg.flush();

        }
        else {

          protocol.writeBind(msg, portalName, statementName, parameterTypes, parameterValues, resultFieldFormats, true);

        }

      }

      reset();

      // Result fields unknown: ask the server to describe the portal.
      if (resultFields == null) {

        protocol.writeDescribe(msg, Portal, portalName);

      }

      protocol.writeExecute(msg, portalName, maxRows);

      // NOTE(review): Flush is used with a row limit outside a transaction,
      // presumably so the portal survives for a later resume - confirm.
      if (maxRows > 0 && protocol.getTransactionStatus() == TransactionStatus.Idle) {
        protocol.writeFlush(msg);
      }
      else {
        protocol.writeSync(msg);
      }

      composed = true;
    }
    finally {
      if (!composed) {
        msg.release();
      }
    }

    protocol.send(msg);

    enableCancelTimer(protocol, queryTimeout);

    waitFor(listener);

  }

  /**
   * Decides whether the Bind message should be written through a streaming
   * buffer.
   *
   * @param context         source of the {@code PARAMETER_STREAM_THRESHOLD} setting
   * @param parameterValues values about to be bound
   * @return {@code false} if any value is an {@link InputStream} other than a
   *         {@code ByteStreams.LimitedInputStream}; otherwise whether the sum
   *         of the limited streams' limits exceeds the threshold
   */
  static boolean shouldStreamBind(Context context, List<Object> parameterValues) {

    int streamThreshold = context.getSetting(PARAMETER_STREAM_THRESHOLD, PARAMETER_STREAM_THRESHOLD_DEFAULT);
    int streamTotal = 0;

    for (Object parameterValue : parameterValues) {

      if (parameterValue instanceof ByteStreams.LimitedInputStream) {
        streamTotal += ((ByteStreams.LimitedInputStream) parameterValue).limit();
      }
      else if (parameterValue instanceof InputStream) {
        return false;
      }
    }

    return streamTotal > streamThreshold;
  }

  /**
   * Assigns each field the result format reported by its type and returns
   * those formats in field order.
   *
   * <p>Note that this updates {@code format} on every element of
   * {@code resultFields} as a side effect.
   *
   * @param resultFields fields whose formats are to be chosen
   * @return the chosen formats, one per field
   */
  static List<Format> getResultFieldFormats(List<ResultField> resultFields) {

    List<Format> formats = new ArrayList<>(resultFields.size());

    for (ResultField resultField : resultFields) {
      resultField.format = resultField.typeRef.get().getResultFormat();
      formats.add(resultField.format);
    }

    return formats;
  }

}