All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.apache.thrift.TNonblockingMultiFetchClient Maven / Gradle / Ivy

There is a newer version: 1.0.0-beta2
Show newest version
/**
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.thrift;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.IOException;
import java.net.InetSocketAddress;
import java.nio.ByteBuffer;
import java.nio.channels.SelectionKey;
import java.nio.channels.Selector;
import java.nio.channels.SocketChannel;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.FutureTask;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;


/**
 * This class uses a single thread to set up non-blocking sockets to a set
 * of remote servers (hostname and port pairs), and sends the same request to
 * all of these servers. It then fetches the responses from the servers.
 *
 * Parameters:
 *   int maxRecvBufBytesPerServer - an upper limit on the receive buffer size
 * per server (in bytes). If a response from a server exceeds this limit, the
 * client will not allocate memory for it or read its data.
 *
 *   int fetchTimeoutSeconds - time limit for fetching responses from all
 * servers (in seconds). After the timeout, the fetch job is stopped and the
 * responses available so far are returned.
 *
 *   ByteBuffer requestBuf - request message that is sent to all servers.
 *
 * Output:
 *   Responses are stored in an array of ByteBuffers. Index of elements in
 * this array corresponds to index of servers in the server list. Content in
 * a ByteBuffer may be in one of the following forms:
 *   1. First 4 bytes form an integer indicating length of following data,
 * then followed by the data.
 *   2. First 4 bytes form an integer indicating length of following data,
 * then followed by nothing - this happens when the response data size
 * exceeds maxRecvBufBytesPerServer, and the client will not read any
 * response data.
 *   3. No data in the ByteBuffer - this happens when the server does not
 * return any response within fetchTimeoutSeconds.
 *
 *   In some special cases (no servers are given, fetchTimeoutSeconds less
 * than or equal to 0, requestBuf is null), the return is null.
 *
 * Note:
 *   It assumes all remote servers are TNonblockingServers and use
 * TFramedTransport.
 *
 */
public class TNonblockingMultiFetchClient {

  private static final Logger LOGGER = LoggerFactory.getLogger(
    TNonblockingMultiFetchClient.class.getName()
  );

  // if the size of the response msg (frame plus its 4-byte size header)
  // exceeds this limit (in bytes), we will not read the msg
  private int maxRecvBufBytesPerServer;

  // time limit for fetching data from all servers (in seconds)
  private int fetchTimeoutSeconds;

  // request that will be sent to every server; only duplicates of it are
  // handed out or written from, so its position/limit stay untouched
  private ByteBuffer requestBuf;

  // a list of remote servers
  private List<InetSocketAddress> servers;

  // store fetch results
  private TNonblockingMultiFetchStats stats;
  // assigned by MultiFetch.run() on the worker thread and returned by fetch()
  // on the calling thread, so the reference is volatile
  private volatile ByteBuffer[] recvBuf;

  /**
   * @param maxRecvBufBytesPerServer upper limit (in bytes, including the
   *   4-byte frame size header) for the response buffer of one server
   * @param fetchTimeoutSeconds time limit (in seconds) for one fetch() call
   * @param requestBuf request sent to every server; fetch() sends from
   *   duplicates, so this buffer's position and limit are not modified
   * @param servers remote servers to query
   */
  public TNonblockingMultiFetchClient(int maxRecvBufBytesPerServer,
    int fetchTimeoutSeconds, ByteBuffer requestBuf,
    List<InetSocketAddress> servers) {
    this.maxRecvBufBytesPerServer = maxRecvBufBytesPerServer;
    this.fetchTimeoutSeconds = fetchTimeoutSeconds;
    this.requestBuf = requestBuf;
    this.servers = servers;

    stats = new TNonblockingMultiFetchStats();
    recvBuf = null;
  }

  public synchronized int getMaxRecvBufBytesPerServer() {
    return maxRecvBufBytesPerServer;
  }

  public synchronized int getFetchTimeoutSeconds() {
    return fetchTimeoutSeconds;
  }

  /**
   * Returns a fresh duplicate of requestBuf, or null if no request was given.
   *
   * Each call returns a new duplicate with its own position, limit and mark,
   * so callers cannot move the position of requestBuf (or of each other's
   * copies). Note that a duplicate shares its content with requestBuf.
   */
  public synchronized ByteBuffer getRequestBuf() {
    return requestBuf == null ? null : requestBuf.duplicate();
  }

  /**
   * Returns an unmodifiable view of the server list, or null if none was given.
   */
  public synchronized List<InetSocketAddress> getServerList() {
    if (servers == null) {
      return null;
    }
    return Collections.unmodifiableList(servers);
  }

  public synchronized TNonblockingMultiFetchStats getFetchStats() {
    return stats;
  }

  /**
   * Main entry function for fetching from servers.
   *
   * Sends the request to every server and collects the responses. Waits at
   * most about fetchTimeoutSeconds, and returns earlier once every server has
   * either delivered a complete response or failed.
   *
   * @return one buffer per server, in server-list order, in the forms
   *   described in the class comment. The buffers are returned as filled
   *   (not flipped). Returns null if there are no servers, requestBuf is
   *   null, or fetchTimeoutSeconds is not positive. It may also return null
   *   if the fetch worker never got to run.
   */
  public synchronized ByteBuffer[] fetch() {
    // clear previous results
    recvBuf = null;
    stats.clear();

    if (servers == null || servers.isEmpty() ||
        requestBuf == null || fetchTimeoutSeconds <= 0) {
      return recvBuf;
    }

    ExecutorService executor = Executors.newSingleThreadExecutor();
    MultiFetch multiFetch = new MultiFetch();
    FutureTask<Void> task = new FutureTask<Void>(multiFetch, null);
    try {
      executor.execute(task);
      task.get(fetchTimeoutSeconds, TimeUnit.SECONDS);
    } catch (InterruptedException ie) {
      // attempt to cancel execution of the task.
      task.cancel(true);
      LOGGER.error("interrupted during fetch: " + ie.toString());
      // restore the caller's interrupt status instead of swallowing it
      Thread.currentThread().interrupt();
    } catch (ExecutionException ee) {
      // attempt to cancel execution of the task.
      task.cancel(true);
      LOGGER.error("exception during fetch: " + ee.toString());
    } catch (TimeoutException te) {
      // attempt to cancel execution of the task.
      task.cancel(true);
      LOGGER.error("timeout for fetch: " + te.toString());
    } finally {
      // always release the worker thread, sockets and selector
      executor.shutdownNow();
      multiFetch.close();
    }
    return recvBuf;
  }

  /**
   * Private class that does the real fetch job.
   * Users are not allowed to use this class directly, because its run()
   * function only stops when every server is done or the thread is
   * interrupted.
   */
  private class MultiFetch implements Runnable {
    // opened by run() on the worker thread, closed by close() on the thread
    // that called fetch(); volatile so close() sees the opened instance
    private volatile Selector selector;

    // finished[i] is set once server i needs no further I/O (complete
    // response, or any connect/write/read/framing failure); numFinished
    // counts the set entries and drives loop termination in run()
    private boolean[] finished;
    private int numFinished;

    /**
     * Main entry function for fetching.
     *
     * Server responses are stored in TNonblockingMultiFetchClient.recvBuf,
     * and fetch statistics are in TNonblockingMultiFetchClient.stats.
     *
     * The sanity check for parameters has been done in
     * TNonblockingMultiFetchClient before this function is called.
     */
    public void run() {
      long t1 = System.currentTimeMillis();

      int numTotalServers = servers.size();
      stats.setNumTotalServers(numTotalServers);

      // buffers for receiving responses from servers; published right away so
      // that fetch() can return whatever was read if it times out
      ByteBuffer[] buffers = new ByteBuffer[numTotalServers];
      recvBuf = buffers;
      // buffers for sending the request
      ByteBuffer[] sendBuf = new ByteBuffer[numTotalServers];
      long[] numBytesRead = new long[numTotalServers];
      int[] frameSize = new int[numTotalServers];
      boolean[] hasReadFrameSize = new boolean[numTotalServers];
      finished = new boolean[numTotalServers];
      numFinished = 0;

      try {
        selector = Selector.open();
      } catch (IOException e) {
        LOGGER.error("selector opens error: " + e.toString());
        return;
      }

      for (int i = 0; i < numTotalServers; i++) {
        // create buffer to send request to server.
        sendBuf[i] = requestBuf.duplicate();
        // create buffer to read response's frame size from server
        buffers[i] = ByteBuffer.allocate(4);
        stats.incTotalRecvBufBytes(4);

        InetSocketAddress server = servers.get(i);
        SocketChannel s = null;
        SelectionKey key = null;
        try {
          s = SocketChannel.open();
          s.configureBlocking(false);
          // non-blocking connect: true only if the connection was
          // established immediately
          boolean connected = s.connect(server);
          // OP_CONNECT is only wanted while the connection is pending; left
          // set on a connected channel it can make select() return at once
          // without any ready key on some platforms
          int ops = connected
              ? (SelectionKey.OP_READ | SelectionKey.OP_WRITE)
              : s.validOps();
          key = s.register(selector, ops);
          // attach index of the key
          key.attach(i);
        } catch (Exception e) {
          stats.incNumConnectErrorServers();
          String err = String.format("set up socket to server %s error: %s",
            server, e);
          LOGGER.error(err);
          // free resource
          if (s != null) {
            try {
              s.close();
            } catch (Exception ignored) {
              // best-effort cleanup of a socket that never got set up
            }
          }
          if (key != null) {
            key.cancel();
          }
          markFinished(i);
        }
      }

      // wait for events until every server has completed or failed
      while (numFinished < numTotalServers) {
        // if the thread is interrupted (e.g., task is cancelled)
        if (Thread.currentThread().isInterrupted()) {
          return;
        }

        try {
          selector.select();
        } catch (Exception e) {
          if (!selector.isOpen()) {
            // closed underneath us (e.g. by close() after a timeout);
            // retrying would fail forever
            return;
          }
          LOGGER.error("selector selects error: " + e.toString());
          continue;
        }

        Iterator<SelectionKey> it = selector.selectedKeys().iterator();
        while (it.hasNext()) {
          SelectionKey selKey = it.next();
          it.remove();

          // get previously attached index
          int index = (Integer) selKey.attachment();
          SocketChannel sChannel = (SocketChannel) selKey.channel();

          if (selKey.isValid() && selKey.isConnectable()) {
            // if this socket throws an exception (e.g., connection refused),
            // print error msg and skip it.
            try {
              if (sChannel.finishConnect()) {
                // connected: stop watching for connect readiness
                selKey.interestOps(
                  SelectionKey.OP_READ | SelectionKey.OP_WRITE);
              }
            } catch (Exception e) {
              stats.incNumConnectErrorServers();
              String err = String.format("socket %d connects to server %s " +
                "error: %s",
                index, servers.get(index), e);
              LOGGER.error(err);
              finishServer(selKey, index);
            }
          }

          if (selKey.isValid() && selKey.isWritable()) {
            // if this socket throws an exception, print error msg and
            // skip it.
            try {
              if (sendBuf[index].hasRemaining()) {
                sChannel.write(sendBuf[index]);
              }
              if (!sendBuf[index].hasRemaining()) {
                // request fully sent; a connected socket is almost always
                // writable, so keeping OP_WRITE would make select() spin
                selKey.interestOps(
                  selKey.interestOps() & ~SelectionKey.OP_WRITE);
              }
            } catch (Exception e) {
              String err = String.format("socket %d writes to server %s " +
                "error: %s",
                index, servers.get(index), e);
              LOGGER.error(err);
              finishServer(selKey, index);
            }
          }

          if (selKey.isValid() && selKey.isReadable()) {
            // if this socket throws an exception, print error msg and
            // skip it.
            try {
              int bytesRead = sChannel.read(buffers[index]);

              if (bytesRead < 0) {
                // end-of-stream before the full frame arrived; without
                // closing, the key would stay readable and spin the loop
                String err = String.format("socket %d: server %s closed " +
                  "the connection before the response was complete",
                  index, servers.get(index));
                LOGGER.error(err);
                finishServer(selKey, index);
                continue;
              }

              if (bytesRead > 0) {
                numBytesRead[index] += bytesRead;

                if (!hasReadFrameSize[index] &&
                    buffers[index].remaining() == 0) {
                  // if the frame size has been read completely, then prepare
                  // to read the actual frame.
                  frameSize[index] = buffers[index].getInt(0);

                  if (frameSize[index] <= 0) {
                    stats.incNumInvalidFrameSize();
                    String err = String.format("Read an invalid frame size %d"
                      + " from %s. Does the server use TFramedTransport? ",
                      frameSize[index], servers.get(index));
                    LOGGER.error(err);
                    finishServer(selKey, index);
                    continue;
                  }

                  // long arithmetic: a frame size close to Integer.MAX_VALUE
                  // would overflow an int sum and slip past the limit check
                  long responseBytes = frameSize[index] + 4L;
                  int clampedResponseBytes =
                    (int) Math.min(responseBytes, Integer.MAX_VALUE);
                  if (clampedResponseBytes > stats.getMaxResponseBytes()) {
                    stats.setMaxResponseBytes(clampedResponseBytes);
                  }

                  if (responseBytes > maxRecvBufBytesPerServer) {
                    stats.incNumOverflowedRecvBuf();
                    String err = String.format("Read frame size %d from %s,"
                      + " total buffer size would exceed limit %d",
                      frameSize[index], servers.get(index),
                      maxRecvBufBytesPerServer);
                    LOGGER.error(err);
                    finishServer(selKey, index);
                    continue;
                  }

                  // reallocate buffer for actual frame data; the cast is safe
                  // because responseBytes <= maxRecvBufBytesPerServer here
                  buffers[index] = ByteBuffer.allocate((int) responseBytes);
                  buffers[index].putInt(frameSize[index]);

                  stats.incTotalRecvBufBytes(frameSize[index]);
                  hasReadFrameSize[index] = true;
                }

                if (hasReadFrameSize[index] &&
                    numBytesRead[index] >= frameSize[index] + 4L) {
                  // has read all data
                  finishServer(selKey, index);
                  stats.incNumReadCompletedServers();
                  long t2 = System.currentTimeMillis();
                  stats.setReadTime(t2 - t1);
                }
              }
            } catch (Exception e) {
              String err = String.format("socket %d reads from server %s " +
                "error: %s",
                index, servers.get(index), e);
              LOGGER.error(err);
              finishServer(selKey, index);
            }
          }
        }
      }
    }

    /**
     * Cancels the key, closes its channel and marks server index as
     * finished, so that run() no longer waits on it.
     */
    private void finishServer(SelectionKey key, int index) {
      key.cancel();
      try {
        key.channel().close();
      } catch (IOException e) {
        LOGGER.error("free resource error: " + e.toString());
      }
      markFinished(index);
    }

    /**
     * Records that server index needs no further I/O; idempotent.
     */
    private void markFinished(int index) {
      if (!finished[index]) {
        finished[index] = true;
        numFinished++;
      }
    }

    /**
     * Disposes any resource allocated: closes every channel still registered
     * with the selector, then the selector itself. Safe to call when run()
     * never opened a selector.
     */
    public void close() {
      Selector sel = selector;
      if (sel == null || !sel.isOpen()) {
        return;
      }
      // close each channel independently so one failure does not leak the rest
      for (SelectionKey selKey : sel.keys()) {
        try {
          selKey.channel().close();
        } catch (IOException e) {
          LOGGER.error("free resource error: " + e.toString());
        }
      }
      try {
        sel.close();
      } catch (IOException e) {
        LOGGER.error("free resource error: " + e.toString());
      }
    }
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy