All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.apache.solr.client.solrj.cloud.SocketProxy Maven / Gradle / Ivy

There is a newer version: 9.8.1
Show newest version
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.solr.client.solrj.cloud;

import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.lang.invoke.MethodHandles;
import java.net.InetSocketAddress;
import java.net.ServerSocket;
import java.net.Socket;
import java.net.SocketException;
import java.net.SocketTimeoutException;
import java.net.URI;
import java.util.ArrayList;
import java.util.LinkedList;
import java.util.List;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicReference;

import javax.net.ssl.SSLServerSocketFactory;
import javax.net.ssl.SSLSocketFactory;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * A simple TCP proxy, useful for blocking traffic on a specified port.
 *
 * <p>Kindly borrowed the idea and base implementation from the ActiveMQ project.
 *
 * <p>The listening socket is bound by the constructor. {@link #open(URI)} then starts an
 * {@link Acceptor} thread, which forwards every accepted connection to the target through a
 * {@link Bridge}. Each bridge uses one {@link Bridge.Pump} thread per direction.
 *
 * <p>Traffic can be:
 * <ul>
 *   <li>paused and resumed ({@link #pause()} / {@link #goOn()});
 *   <li>cut on the client side only ({@link #halfClose()});
 *   <li>shut down entirely ({@link #close()}) and restarted on the same port ({@link #reopen()}).
 * </ul>
 */
public class SocketProxy {

  private static final Logger log = LoggerFactory.getLogger(MethodHandles.lookup().lookupClass());

  /** SO_TIMEOUT of the listening socket, so the acceptor loop periodically re-checks its state. */
  public static final int ACCEPT_TIMEOUT_MILLIS = 100;

  /**
   * SO_TIMEOUT applied to the source socket of every pump. A read that times out ends that pump,
   * just like EOF does.
   */
  // should be as large as the HttpShardHandlerFactory socket timeout ... or larger?
  public static final int PUMP_SOCKET_TIMEOUT_MS = 100 * 1000;

  private URI proxyUrl;
  private URI target;

  private Acceptor acceptor;
  private ServerSocket serverSocket;

  /** Counted down by {@link #close()}; replaced with a fresh latch whenever an acceptor starts. */
  private CountDownLatch closed = new CountDownLatch(1);

  /** Live bridges. Every access must hold this list's monitor. */
  public List<Bridge> connections = new LinkedList<Bridge>();

  private final int listenPort;

  private int receiveBufferSize = -1;

  private boolean pauseAtStart = false;

  private int acceptBacklog = 50;

  private final boolean usesSSL;

  /** Creates a plain-TCP proxy listening on an ephemeral port. */
  public SocketProxy() throws Exception {
    this(0, false);
  }

  /** Creates a proxy listening on an ephemeral port, using SSL sockets if {@code useSSL}. */
  public SocketProxy( boolean useSSL) throws Exception {
    this(0, useSSL);
  }

  /**
   * Creates the proxy and binds its listening socket immediately.
   *
   * <p>Nothing is accepted until {@link #open(URI)} is called.
   *
   * @param port port to listen on, or 0 for an ephemeral port (see {@link #getListenPort()})
   * @param useSSL whether to use the JVM-default SSL server/client socket factories
   * @throws Exception if the listening socket cannot be created or bound
   */
  public SocketProxy(int port, boolean useSSL) throws Exception {
    this.usesSSL = useSSL;
    // receiveBufferSize and acceptBacklog still hold their field defaults here: their setters
    // cannot have been called before construction.
    serverSocket = bindServerSocket(port);
    this.listenPort = serverSocket.getLocalPort();
  }

  /**
   * Starts proxying to {@code uri}.
   *
   * <p>Records the target and derives {@link #getUrl()} from it, with the port replaced by this
   * proxy's listen port. Then starts the acceptor thread.
   */
  public void open(URI uri) throws Exception {
    target = uri;
    proxyUrl = urlFromSocket(target, serverSocket);
    doOpen();
  }

  @Override
  public String toString() {
    return "SocketProxy: port=" + listenPort + "; target=" + target;
  }

  /**
   * Sets the receive buffer size used for subsequently accepted/outbound sockets and re-bound
   * listening sockets. A value of 0 or less leaves the OS default.
   */
  public void setReceiveBufferSize(int receiveBufferSize) {
    this.receiveBufferSize = receiveBufferSize;
  }

  /** Replaces the target URI; only acceptors started afterwards (open/reopen) use it. */
  public void setTarget(URI tcpBrokerUri) {
    target = tcpBrokerUri;
  }

  /** Starts a new acceptor thread on the current server socket and resets the close latch. */
  private void doOpen() throws Exception {
    acceptor = new Acceptor(serverSocket, target);
    if (pauseAtStart) {
      acceptor.pause();
    }
    new Thread(null, acceptor, "SocketProxy-Acceptor-" + serverSocket.getLocalPort()).start();
    closed = new CountDownLatch(1);
  }

  /** Returns the local port the proxy listens on (resolved if 0 was requested). */
  public int getListenPort() {
    return listenPort;
  }

  private ServerSocket createServerSocket(boolean useSSL) throws Exception {
    if (useSSL) {
      return SSLServerSocketFactory.getDefault().createServerSocket();
    }
    return new ServerSocket();
  }

  private Socket createSocket(boolean useSSL) throws Exception {
    if (useSSL) {
      return SSLSocketFactory.getDefault().createSocket();
    }
    return new Socket();
  }

  /**
   * Creates a listening socket and binds it to {@code port}.
   *
   * <p>The socket uses SO_REUSEADDR, the configured receive buffer size and the configured
   * accept backlog. It is closed again if any step fails.
   */
  private ServerSocket bindServerSocket(int port) throws Exception {
    ServerSocket socket = createServerSocket(usesSSL);
    try {
      socket.setReuseAddress(true);
      if (receiveBufferSize > 0) {
        socket.setReceiveBufferSize(receiveBufferSize);
      }
      socket.bind(new InetSocketAddress(port), acceptBacklog);
      return socket;
    } catch (Exception e) {
      try {
        socket.close();
      } catch (IOException closeFailure) {
        e.addSuppressed(closeFailure);
      }
      throw e;
    }
  }

  /** Closes {@code socket}, logging (not propagating) any failure. */
  private static void closeQuietly(Socket socket) {
    try {
      socket.close();
    } catch (IOException e) {
      log.debug("exception on close of socket: " + socket, e);
    }
  }

  /** Returns the proxy's URL, or null until {@link #open(URI)} has been called. */
  public URI getUrl() {
    return proxyUrl;
  }

  /**
   * Closes all proxy connections and the acceptor, then releases {@link #waitUntilClosed(long)}.
   */
  public void close() {
    List<Bridge> snapshot;
    synchronized (this.connections) {
      snapshot = new ArrayList<Bridge>(this.connections);
    }
    log.warn("Closing {} connections to: {}, target: {}", snapshot.size(), getUrl(), target);
    for (Bridge con : snapshot) {
      closeConnection(con);
    }
    if (acceptor != null) { // null when open() was never called
      acceptor.close();
    }
    closed.countDown();
  }

  /**
   * Closes the client-facing (receive) socket of every proxy connection, leaving the acceptor
   * open.
   */
  public void halfClose() {
    List<Bridge> snapshot;
    synchronized (this.connections) {
      snapshot = new ArrayList<Bridge>(this.connections);
    }
    log.info("halfClose, numConnections={}", snapshot.size());
    for (Bridge con : snapshot) {
      halfCloseConnection(con);
    }
  }

  /**
   * Waits until {@link #close()} has been called.
   *
   * @return true if closed within the timeout, false if the timeout elapsed first
   */
  public boolean waitUntilClosed(long timeoutSeconds)
      throws InterruptedException {
    return closed.await(timeoutSeconds, TimeUnit.SECONDS);
  }

  /**
   * Restarts the acceptor on the original port after a {@link #close()}.
   *
   * <p>Re-bind failures are logged and otherwise ignored.
   *
   * @throws IllegalStateException if {@link #open(URI)} has never been called
   */
  public void reopen() {
    // Checked outside the try below so the precondition is not swallowed by the catch.
    if (proxyUrl == null) {
      throw new IllegalStateException("Can not call reopen before open(URI uri).");
    }
    log.info("Re-opening connectivity to " + getUrl());
    try {
      serverSocket = bindServerSocket(proxyUrl.getPort());
      doOpen();
    } catch (Exception e) {
      log.warn("exception on reopen url:" + getUrl(), e);
    }
  }

  /**
   * Pauses accepting new connections and data transfer through existing proxy connections.
   *
   * <p>All sockets remain open.
   */
  public void pause() {
    synchronized (connections) {
      log.info("pause, numConnections={}", connections.size());
      if (acceptor != null) {
        acceptor.pause();
      }
      for (Bridge con : connections) {
        con.pause();
      }
    }
  }

  /** Continues after {@link #pause()}. */
  public void goOn() {
    synchronized (connections) {
      log.info("goOn, numConnections={}", connections.size());
      for (Bridge con : connections) {
        con.goOn();
      }
    }
    if (acceptor != null) {
      acceptor.goOn();
    }
  }

  private void closeConnection(Bridge c) {
    try {
      c.close();
    } catch (Exception e) {
      log.debug("exception on close of: " + c, e);
    }
  }

  private void halfCloseConnection(Bridge c) {
    try {
      c.halfClose();
    } catch (Exception e) {
      log.debug("exception on half close of: " + c, e);
    }
  }

  public boolean isPauseAtStart() {
    return pauseAtStart;
  }

  /** If true, acceptors started by open/reopen begin paused until {@link #goOn()}. */
  public void setPauseAtStart(boolean pauseAtStart) {
    this.pauseAtStart = pauseAtStart;
  }

  public int getAcceptBacklog() {
    return acceptBacklog;
  }

  /**
   * Sets the accept backlog used when the listening socket is next bound.
   *
   * <p>That bind happens in {@link #reopen()}; the constructor has already bound the initial
   * socket.
   */
  public void setAcceptBacklog(int acceptBacklog) {
    this.acceptBacklog = acceptBacklog;
  }

  /** Returns {@code uri} with its port replaced by the local port of {@code serverSocket}. */
  private URI urlFromSocket(URI uri, ServerSocket serverSocket)
      throws Exception {
    int port = serverSocket.getLocalPort();

    return new URI(uri.getScheme(), uri.getUserInfo(), uri.getHost(),
        port, uri.getPath(), uri.getQuery(), uri.getFragment());
  }

  /**
   * One proxied connection: the accepted client socket paired with a socket connected to the
   * target.
   *
   * <p>Data is pumped in each direction by its own thread.
   */
  public class Bridge {

    private final Socket receiveSocket;
    private final Socket sendSocket;
    private Pump requestThread;
    private Pump responseThread;

    /**
     * Connects to {@code target} and starts pumping data in both directions.
     *
     * @param socket the accepted client socket
     * @param target host and port to connect to
     * @throws Exception if the outbound socket cannot be set up or connected. The outbound socket
     *     is closed in that case; {@code socket} is left to the caller.
     */
    public Bridge(Socket socket, URI target) throws Exception {
      receiveSocket = socket;
      sendSocket = createSocket(usesSSL);
      try {
        if (receiveBufferSize > 0) {
          sendSocket.setReceiveBufferSize(receiveBufferSize);
        }
        sendSocket.connect(new InetSocketAddress(target.getHost(), target.getPort()));
      } catch (Exception e) {
        closeQuietly(sendSocket);
        throw e;
      }
      linkWithThreads(receiveSocket, sendSocket);
      log.info("proxy connection {}, receiveBufferSize={}",
          sendSocket, sendSocket.getReceiveBufferSize());
    }

    public void goOn() {
      responseThread.goOn();
      requestThread.goOn();
    }

    public void pause() {
      requestThread.pause();
      responseThread.pause();
    }

    /** Deregisters this bridge and closes both of its sockets. */
    public void close() throws Exception {
      synchronized (connections) {
        connections.remove(this);
      }
      receiveSocket.close();
      sendSocket.close();
    }

    /** Closes only the client-facing socket; the bridge stays registered. */
    public void halfClose() throws Exception {
      receiveSocket.close();
    }

    private void linkWithThreads(Socket source, Socket dest) {
      requestThread = new Pump("Request", source, dest);
      requestThread.start();
      responseThread = new Pump("Response", dest, source);
      responseThread.start();
    }

    /** Copies bytes from one socket to another until EOF, read timeout, or error. */
    public class Pump extends Thread {

      protected Socket src;
      private final Socket destination;
      private final AtomicReference<CountDownLatch> pause = new AtomicReference<CountDownLatch>();

      public Pump(String kind, Socket source, Socket dest) {
        super("SocketProxy-" + kind + "-" + source.getPort() + ":" + dest.getPort());
        src = source;
        destination = dest;
        pause.set(new CountDownLatch(0));
      }

      public void pause() {
        pause.set(new CountDownLatch(1));
      }

      public void goOn() {
        pause.get().countDown();
      }

      @Override
      public void run() {
        byte[] buf = new byte[1024];

        try {
          src.setSoTimeout(PUMP_SOCKET_TIMEOUT_MS);
        } catch (SocketException e) {
          // getMessage() may be null, so compare from the constant side.
          if ("Socket is closed".equals(e.getMessage())) {
            log.warn("Failed to set socket timeout on " + src + " due to: " + e);
            return;
          }
          log.error("Failed to set socket timeout on " + src + " due to: " + e);
          throw new RuntimeException(e);
        }

        InputStream in = null;
        OutputStream out = null;
        try {
          in = src.getInputStream();
          out = destination.getOutputStream();
          while (true) {
            int len = -1;
            try {
              len = in.read(buf);
            } catch (SocketTimeoutException ste) {
              // len stays -1, so an idle read timeout ends this pump just like EOF.
              log.warn(ste + " when reading from " + src);
            }

            if (len == -1) {
              log.debug("read eof from:" + src);
              break;
            }
            // While paused, the chunk already read is held here until goOn().
            pause.get().await();
            if (len > 0) {
              out.write(buf, 0, len);
            }
          }
        } catch (Exception e) {
          log.debug("read/write failed, reason: " + e.getLocalizedMessage());
          try {
            if (!receiveSocket.isClosed()) {
              // for halfClose, on read/write failure if we close the
              // remote end will see a close at the same time.
              Bridge.this.close();
            }
          } catch (Exception ignored) {
            // Best effort: this bridge is being torn down anyway.
            log.debug("exception closing bridge after pump failure: " + ignored);
          }
        } finally {
          if (in != null) {
            try {
              in.close();
            } catch (Exception exc) {
              log.debug(exc + " when closing InputStream on socket: " + src);
            }
          }
          if (out != null) {
            try {
              out.close();
            } catch (Exception exc) {
              log.debug(exc + " when closing OutputStream on socket: " + destination);
            }
          }
        }
      }
    }
  }

  /**
   * Accepts client connections on the proxy's server socket and wraps each one in a
   * {@link Bridge}.
   */
  public class Acceptor implements Runnable {

    private final ServerSocket socket;
    private final URI target;
    private final AtomicReference<CountDownLatch> pause = new AtomicReference<CountDownLatch>();

    public Acceptor(ServerSocket serverSocket, URI uri) {
      socket = serverSocket;
      target = uri;
      pause.set(new CountDownLatch(0));
      try {
        socket.setSoTimeout(ACCEPT_TIMEOUT_MILLIS);
      } catch (SocketException e) {
        log.warn("Failed to set accept timeout on " + socket, e);
      }
    }

    public void pause() {
      pause.set(new CountDownLatch(1));
    }

    public void goOn() {
      pause.get().countDown();
    }

    /**
     * Accepts connections until the server socket is closed.
     *
     * <p>Any exception other than an accept timeout ends the loop.
     */
    @Override
    public void run() {
      try {
        while (!socket.isClosed()) {
          pause.get().await();
          try {
            Socket source = socket.accept();
            try {
              pause.get().await();
              if (receiveBufferSize > 0) {
                source.setReceiveBufferSize(receiveBufferSize);
              }
              log.info("accepted {}, receiveBufferSize:{}",
                  source, source.getReceiveBufferSize());
              synchronized (connections) {
                connections.add(new Bridge(source, target));
              }
            } catch (Exception failure) {
              // No bridge took ownership of the accepted socket; don't leak it.
              closeQuietly(source);
              throw failure;
            }
          } catch (SocketTimeoutException expected) {
            // accept() timed out; loop to re-check the pause latch and closed state.
          }
        }
      } catch (Exception e) {
        log.debug("acceptor: finished for reason: " + e.getLocalizedMessage());
      }
    }

    /** Closes the server socket, releases the close latch and wakes a paused acceptor loop. */
    public void close() {
      try {
        socket.close();
        closed.countDown();
        goOn();
      } catch (IOException e) {
        log.debug("exception on close of acceptor socket: " + socket, e);
      }
    }
  }

}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy