All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.apache.solr.client.solrj.cloud.SocketProxy Maven / Gradle / Ivy

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.solr.client.solrj.cloud;

import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.lang.invoke.MethodHandles;
import java.net.InetSocketAddress;
import java.net.ServerSocket;
import java.net.Socket;
import java.net.SocketException;
import java.net.SocketTimeoutException;
import java.net.URI;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicReference;
import javax.net.ssl.SSLServerSocketFactory;
import javax.net.ssl.SSLSocketFactory;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * A simple TCP proxy that listens on a local port and forwards every accepted connection to a
 * target host/port. It is intended for simulating network trouble: traffic can be paused and
 * resumed, connections can be closed or half-closed, and the listener can be closed and later
 * re-opened on the same port.
 *
 * <p>Kindly borrowed the idea and base implementation from the ActiveMQ project; useful for
 * blocking traffic on a specified port.
 *
 * <p>Each accepted connection is represented by a {@link Bridge}, which runs one {@link
 * Bridge.Pump} thread per direction. New connections are accepted by an {@link Acceptor} running on
 * its own thread. Access to {@link #connections} must be synchronized on the list itself.
 */
public class SocketProxy {

  private static final Logger log = LoggerFactory.getLogger(MethodHandles.lookup().lookupClass());

  /**
   * Timeout of each blocking {@code accept()} call, so the acceptor loop regularly re-checks whether
   * the server socket was closed or the proxy paused.
   */
  public static final int ACCEPT_TIMEOUT_MILLIS = 100;

  // should be as large as the HttpShardHandlerFactory socket timeout ... or larger?
  public static final int PUMP_SOCKET_TIMEOUT_MS = 100 * 1000;

  private URI proxyUrl;
  private URI target;

  private Acceptor acceptor;
  private ServerSocket serverSocket;

  private CountDownLatch closed = new CountDownLatch(1);

  /** Currently open bridged connections; guard every access with {@code synchronized}. */
  public List<Bridge> connections = new ArrayList<>();

  private final int listenPort;

  private int receiveBufferSize = -1;

  private boolean pauseAtStart = false;

  private int acceptBacklog = 50;

  private boolean usesSSL;

  /** Creates a plain (non-SSL) proxy listening on an ephemeral port. */
  public SocketProxy() throws Exception {
    this(0, false);
  }

  /**
   * Creates a proxy listening on an ephemeral port.
   *
   * @param useSSL whether to use the default SSL socket factories for both sides
   */
  public SocketProxy(boolean useSSL) throws Exception {
    this(0, useSSL);
  }

  /**
   * Creates the proxy and binds its server socket immediately; call {@link #open(URI)} to start
   * accepting connections.
   *
   * @param port local port to listen on, or 0 for an ephemeral port
   * @param useSSL whether to use the default SSL socket factories for both sides
   */
  public SocketProxy(int port, boolean useSSL) throws Exception {
    this.usesSSL = useSSL;
    serverSocket = createServerSocket(useSSL);
    serverSocket.setReuseAddress(true);
    // receiveBufferSize still holds its initial value here (the setter can only run after
    // construction), so this only matters for reopen(), which repeats the same setup.
    if (receiveBufferSize > 0) {
      serverSocket.setReceiveBufferSize(receiveBufferSize);
    }
    serverSocket.bind(new InetSocketAddress(port), acceptBacklog);
    this.listenPort = serverSocket.getLocalPort();
  }

  /**
   * Starts proxying connections to {@code uri}.
   *
   * @param uri target to forward to; its host and port are used for outgoing connections
   */
  public void open(URI uri) throws Exception {
    target = uri;
    proxyUrl = urlFromSocket(target, serverSocket);
    doOpen();
  }

  @Override
  public String toString() {
    return "SocketProxy: port=" + listenPort + "; target=" + target;
  }

  /**
   * Sets the receive buffer size for sockets created from now on: accepted sockets, outgoing
   * sockets, and the server socket created by {@link #reopen()}.
   */
  public void setReceiveBufferSize(int receiveBufferSize) {
    this.receiveBufferSize = receiveBufferSize;
  }

  /**
   * Changes the target. A running acceptor keeps the target it was started with, so this affects
   * connections accepted after the next {@link #reopen()}.
   */
  public void setTarget(URI tcpBrokerUri) {
    target = tcpBrokerUri;
  }

  /** Starts a new acceptor thread on the current server socket and resets the closed latch. */
  private void doOpen() throws Exception {
    acceptor = new Acceptor(serverSocket, target);
    if (pauseAtStart) {
      acceptor.pause();
    }
    new Thread(null, acceptor, "SocketProxy-Acceptor-" + serverSocket.getLocalPort()).start();
    closed = new CountDownLatch(1);
  }

  /** Returns the local port the proxy listens on. */
  public int getListenPort() {
    return listenPort;
  }

  private ServerSocket createServerSocket(boolean useSSL) throws Exception {
    if (useSSL) {
      return SSLServerSocketFactory.getDefault().createServerSocket();
    }
    return new ServerSocket();
  }

  private Socket createSocket(boolean useSSL) throws Exception {
    if (useSSL) {
      return SSLSocketFactory.getDefault().createSocket();
    }
    return new Socket();
  }

  /**
   * Returns the target URI with its port replaced by the proxy's listen port, or {@code null}
   * before {@link #open(URI)} has been called.
   */
  public URI getUrl() {
    return proxyUrl;
  }

  /*
   * close all proxy connections and acceptor
   */
  public void close() {
    List<Bridge> snapshot;
    synchronized (connections) {
      snapshot = new ArrayList<>(connections);
    }
    log.warn("Closing {} connections to: {}, target: {}", snapshot.size(), getUrl(), target);
    for (Bridge con : snapshot) {
      closeConnection(con);
    }
    // acceptor is null when open(URI) was never called
    if (acceptor != null) {
      acceptor.close();
    }
    closed.countDown();
  }

  /*
   * close all proxy receive connections, leaving acceptor open
   */
  public void halfClose() {
    List<Bridge> snapshot;
    synchronized (connections) {
      snapshot = new ArrayList<>(connections);
    }
    if (log.isInfoEnabled()) {
      log.info("halfClose, numConnections= {}", snapshot.size());
    }
    for (Bridge con : snapshot) {
      halfCloseConnection(con);
    }
  }

  /**
   * Waits until {@link #close()} (or the acceptor's close) has been called.
   *
   * @return {@code true} if closed within the timeout
   */
  public boolean waitUntilClosed(long timeoutSeconds) throws InterruptedException {
    return closed.await(timeoutSeconds, TimeUnit.SECONDS);
  }

  /**
   * Called after a close to restart the acceptor on the same port. Failures to re-bind are logged
   * and leave the proxy closed.
   *
   * @throws IllegalStateException if {@link #open(URI)} was never called
   */
  public void reopen() {
    if (proxyUrl == null) {
      throw new IllegalStateException("Can not call reopen before open(URI uri).");
    }
    if (log.isInfoEnabled()) {
      log.info("Re-opening connectivity to {}", getUrl());
    }
    ServerSocket newServerSocket = null;
    try {
      newServerSocket = createServerSocket(usesSSL);
      newServerSocket.setReuseAddress(true);
      if (receiveBufferSize > 0) {
        newServerSocket.setReceiveBufferSize(receiveBufferSize);
      }
      // use the same backlog as the constructor's initial bind
      newServerSocket.bind(new InetSocketAddress(proxyUrl.getPort()), acceptBacklog);
      serverSocket = newServerSocket;
      doOpen();
    } catch (Exception e) {
      log.warn("exception on reopen url:{} ", getUrl(), e);
      // no acceptor thread was started, so nothing else owns this socket
      if (newServerSocket != null) {
        try {
          newServerSocket.close();
        } catch (IOException closeFailure) {
          log.debug("exception on close of: {}", newServerSocket, closeFailure);
        }
      }
    }
  }

  /*
   * pause accepting new connections and data transfer through existing proxy
   * connections. All sockets remain open
   */
  public void pause() {
    synchronized (connections) {
      if (log.isInfoEnabled()) {
        log.info("pause, numConnections={}", connections.size());
      }
      acceptor.pause();
      for (Bridge con : connections) {
        con.pause();
      }
    }
  }

  /*
   * continue after pause
   */
  public void goOn() {
    synchronized (connections) {
      if (log.isInfoEnabled()) {
        log.info("goOn, numConnections={}", connections.size());
      }
      for (Bridge con : connections) {
        con.goOn();
      }
    }
    acceptor.goOn();
  }

  private void closeConnection(Bridge c) {
    try {
      c.close();
    } catch (Exception e) {
      log.debug("exception on close of: {}", c, e);
    }
  }

  private void halfCloseConnection(Bridge c) {
    try {
      c.halfClose();
    } catch (Exception e) {
      log.debug("exception on half close of: {}", c, e);
    }
  }

  public boolean isPauseAtStart() {
    return pauseAtStart;
  }

  /** When set, the acceptor started by {@link #open(URI)} / {@link #reopen()} begins paused. */
  public void setPauseAtStart(boolean pauseAtStart) {
    this.pauseAtStart = pauseAtStart;
  }

  public int getAcceptBacklog() {
    return acceptBacklog;
  }

  /** Sets the listen backlog; takes effect the next time the server socket is bound by reopen. */
  public void setAcceptBacklog(int acceptBacklog) {
    this.acceptBacklog = acceptBacklog;
  }

  /** Builds a copy of {@code uri} whose port is the local port of {@code serverSocket}. */
  private URI urlFromSocket(URI uri, ServerSocket serverSocket) throws Exception {
    int port = serverSocket.getLocalPort();

    return new URI(
        uri.getScheme(),
        uri.getUserInfo(),
        uri.getHost(),
        port,
        uri.getPath(),
        uri.getQuery(),
        uri.getFragment());
  }

  /**
   * One proxied connection: an accepted client socket paired with an outgoing socket to the target,
   * with a {@link Pump} thread copying bytes in each direction.
   */
  public class Bridge {

    private Socket receiveSocket;
    private Socket sendSocket;
    private Pump requestThread;
    private Pump responseThread;

    /**
     * Connects to {@code target} and starts both pump threads.
     *
     * @param socket the accepted client socket
     * @param target the target whose host and port are connected to
     * @throws Exception if the outgoing connection cannot be established; the outgoing socket is
     *     closed in that case
     */
    public Bridge(Socket socket, URI target) throws Exception {
      receiveSocket = socket;
      sendSocket = createSocket(usesSSL);
      try {
        if (receiveBufferSize > 0) {
          sendSocket.setReceiveBufferSize(receiveBufferSize);
        }
        sendSocket.connect(new InetSocketAddress(target.getHost(), target.getPort()));
      } catch (Exception e) {
        try {
          sendSocket.close();
        } catch (IOException suppressed) {
          e.addSuppressed(suppressed);
        }
        throw e;
      }
      linkWithThreads(receiveSocket, sendSocket);
      if (log.isInfoEnabled()) {
        log.info(
            "proxy connection {}, receiveBufferSize={}",
            sendSocket,
            sendSocket.getReceiveBufferSize());
      }
    }

    public void goOn() {
      responseThread.goOn();
      requestThread.goOn();
    }

    public void pause() {
      requestThread.pause();
      responseThread.pause();
    }

    /** Removes this bridge from the proxy's connections and closes both sockets. */
    public void close() throws Exception {
      synchronized (connections) {
        connections.remove(this);
      }
      try {
        receiveSocket.close();
      } finally {
        // close the outgoing side even if closing the client side failed
        sendSocket.close();
      }
    }

    /** Closes only the client-side socket; the bridge stays in the proxy's connections. */
    public void halfClose() throws Exception {
      receiveSocket.close();
    }

    private void linkWithThreads(Socket source, Socket dest) {
      requestThread = new Pump("Request", source, dest);
      requestThread.start();
      responseThread = new Pump("Response", dest, source);
      responseThread.start();
    }

    /**
     * Copies bytes from one socket to another until EOF, a read timeout, or an I/O error. While
     * paused, it blocks before each write until {@link #goOn()} is called.
     */
    public class Pump extends Thread {

      protected Socket src;
      private Socket destination;
      private final AtomicReference<CountDownLatch> pause =
          new AtomicReference<>(new CountDownLatch(0));

      public Pump(String kind, Socket source, Socket dest) {
        super("SocketProxy-" + kind + "-" + source.getPort() + ":" + dest.getPort());
        src = source;
        destination = dest;
      }

      public void pause() {
        pause.set(new CountDownLatch(1));
      }

      public void goOn() {
        pause.get().countDown();
      }

      @Override
      public void run() {
        byte[] buf = new byte[1024];

        try {
          src.setSoTimeout(PUMP_SOCKET_TIMEOUT_MS);
        } catch (SocketException e) {
          // the socket may already have been closed, e.g. by Bridge.close(); nothing to pump then
          if ("Socket is closed".equals(e.getMessage())) {
            log.warn("Failed to set socket timeout on {} due to: ", src, e);
            return;
          }
          log.error("Failed to set socket timeout on {} due to ", src, e);
          throw new RuntimeException(e);
        }

        InputStream in = null;
        OutputStream out = null;
        try {
          in = src.getInputStream();
          out = destination.getOutputStream();
          while (true) {
            int len = -1;
            try {
              len = in.read(buf);
            } catch (SocketTimeoutException ste) {
              // len stays -1, so a read timeout ends this pump just like EOF
              log.warn("Error when reading from {}", src, ste);
            }

            if (len == -1) {
              log.debug("read eof from: {}", src);
              break;
            }
            pause.get().await();
            if (len > 0) {
              out.write(buf, 0, len);
            }
          }
        } catch (Exception e) {
          if (e instanceof InterruptedException) {
            Thread.currentThread().interrupt();
          }
          if (log.isDebugEnabled()) {
            log.debug("read/write failed, reason: {}", e.getLocalizedMessage());
          }
          try {
            if (!receiveSocket.isClosed()) {
              // for halfClose, on read/write failure if we close the
              // remote end will see a close at the same time.
              close();
            }
          } catch (Exception ignored) {
            // best effort: the bridge is being torn down anyway
          }
        } finally {
          if (in != null) {
            try {
              in.close();
            } catch (Exception exc) {
              log.debug("Error when closing InputStream on socket: {}", src, exc);
            }
          }
          if (out != null) {
            try {
              out.close();
            } catch (Exception exc) {
              log.debug("Error when closing OutputStream on socket: {}", destination, exc);
            }
          }
        }
      }
    }
  }

  /**
   * Accepts client connections on the server socket and wraps each one in a {@link Bridge} until the
   * server socket is closed.
   */
  public class Acceptor implements Runnable {

    private ServerSocket socket;
    private URI target;
    private final AtomicReference<CountDownLatch> pause =
        new AtomicReference<>(new CountDownLatch(0));

    public Acceptor(ServerSocket serverSocket, URI uri) throws SocketException {
      socket = serverSocket;
      target = uri;
      socket.setSoTimeout(ACCEPT_TIMEOUT_MILLIS);
    }

    public void pause() {
      pause.set(new CountDownLatch(1));
    }

    public void goOn() {
      pause.get().countDown();
    }

    @Override
    public void run() {
      try {
        while (!socket.isClosed()) {
          pause.get().await();
          Socket source;
          try {
            source = socket.accept();
          } catch (SocketTimeoutException expected) {
            // accept() times out periodically so the loop re-checks the closed/paused state
            continue;
          }
          // a connection accepted just as the proxy was paused waits here until goOn()
          pause.get().await();
          bridge(source);
        }
      } catch (Exception e) {
        if (e instanceof InterruptedException) {
          Thread.currentThread().interrupt();
        }
        if (log.isDebugEnabled()) {
          log.debug("acceptor: finished for reason: {}", e.getLocalizedMessage());
        }
      }
    }

    /**
     * Wraps an accepted socket in a Bridge. If that fails (e.g. the target is unreachable), the
     * accepted socket is closed and the acceptor keeps running.
     */
    private void bridge(Socket source) {
      try {
        if (receiveBufferSize > 0) {
          source.setReceiveBufferSize(receiveBufferSize);
        }
        if (log.isInfoEnabled()) {
          log.info("accepted {}, receiveBufferSize: {}", source, source.getReceiveBufferSize());
        }
        // Bridge is constructed while holding the lock so a pump that fails immediately and calls
        // Bridge.close() cannot remove it before it has been added.
        synchronized (connections) {
          connections.add(new Bridge(source, target));
        }
      } catch (Exception e) {
        log.warn("Failed to proxy accepted connection {} to target {}", source, target, e);
        try {
          source.close();
        } catch (IOException closeFailure) {
          log.debug("exception on close of: {}", source, closeFailure);
        }
      }
    }

    /** Closes the server socket, signals the proxy's closed latch and releases a paused loop. */
    public void close() {
      try {
        socket.close();
      } catch (IOException e) {
        log.debug("exception on close of: {}", socket, e);
      } finally {
        closed.countDown();
        // release the acceptor thread if it is paused so it can observe the closed socket
        goOn();
      }
    }
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy