All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.apache.spark.beast.CRSServer.scala Maven / Gradle / Ivy

There is a newer version: 0.10.1-RC2
Show newest version
/*
 * Copyright 2020 University of California, Riverside
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.spark.beast

import edu.ucr.cs.bdlab.beast.util.OperationParam

import javax.servlet.http
import javax.servlet.http.{HttpServletRequest, HttpServletResponse}
import org.apache.hadoop.conf.Configuration
import org.apache.http.HttpEntity
import org.apache.http.client.HttpClient
import org.apache.http.client.methods.HttpPost
import org.apache.http.entity.ByteArrayEntity
import org.apache.http.impl.client.HttpClients
import org.apache.spark.api.java.JavaSparkContext
import org.apache.spark.{SparkConf, SparkContext, TaskContext}
import org.apache.spark.internal.Logging
import org.apache.spark.scheduler.{SparkListener, SparkListenerApplicationEnd}
import org.eclipse.jetty.server.{Request, Server, ServerConnector}
import org.eclipse.jetty.servlet.{ServletHandler, ServletHolder}
import org.eclipse.jetty.util.thread.QueuedThreadPool
import org.geotools.referencing.CRS
import org.opengis.referencing.crs.CoordinateReferenceSystem

import java.io.{ByteArrayInputStream, ByteArrayOutputStream, IOException, ObjectInputStream, ObjectOutputStream}
import java.net.URL
import java.util.regex.Pattern

/**
 * A server that handles non-standard coordinate reference systems (CRS) between executor
 * nodes and the driver. It has three main methods.
 * - `setPort(port: Int)` - Sets the port where the server is listening (or should listen)
 * - `startServer` - Starts the server. Should run on the driver node before the executors start.
 * - `stopServer` - Stops the server
 * - `crsToSRID(crs: CoordinateReferenceSystem): Int` - puts the given CRS into the cache
 *     and returns a unique ID for it.
 * - `sridToCRS(srid: Int): CoordinateReferenceSystem` - Returns the CRS with the given SRID if it exists.
 */
object CRSServer extends Logging {

  @OperationParam(description = "The port on which the CRSServer is listening", defaultValue = "21345")
  val CRSServerPort = "crsserver.port"

  @OperationParam(description = "The host on which the CRSServer is running", defaultValue = "127.0.0.1")
  val CRSServerHost = "crsserver.host"

  val DefaultPort: Int = 21345

  lazy val serverPort: Int = TaskContext.get().getLocalProperty(CRSServerPort).toInt

  lazy val isServerRunning: Boolean = TaskContext.get() != null && TaskContext.get().getLocalProperty(CRSServerPort) != null

  lazy val serverAddress: String = TaskContext.get().getLocalProperty(CRSServerHost)

  /**The running Jetty server instance*/
  var server: Server = _

  /**
   * A cache (at both the client and server) to speed up [[crsToSRID()]] calls.
   * It contains pairs of (CRS, SRID) whether the CRS is retrieved from the server or from [[CRS.lookupEpsgCode()]]
   */
  private[beast] val crsCache: scala.collection.mutable.Map[CoordinateReferenceSystem, Int] =
    new scala.collection.mutable.HashMap[CoordinateReferenceSystem, Int]()

  /**
   * The reverse cache maps each SRID to a CRS to speed up [[sridToCRS()]] calls.
   */
  private[beast] val sridCache: scala.collection.mutable.Map[Int, CoordinateReferenceSystem] =
    new scala.collection.mutable.HashMap[Int, CoordinateReferenceSystem]()

  /**A pattern for HTTP requests*/
  val pattern: Pattern = Pattern.compile("(?i)/?crs(/?(-?\\d+))?/?")

  /**
   * Starts the server and returns the port on which it is listening
   * @return the port on which the server is started
   */
  def startServer(defaultPort: Int = DefaultPort): Int = {
    val numThreads: Int = Runtime.getRuntime.availableProcessors() * 2
    val threadPool = new QueuedThreadPool(numThreads * 2, numThreads)
    server = new Server(threadPool)
    val connector = new ServerConnector(server)
    server.addConnector(connector)
    var started: Boolean = false
    var attempt: Int = 0
    // Try several ports starting with the default port until an open port is found
    do {
      try {
        connector.setPort(defaultPort + attempt)
        val servletHandler = new ServletHandler
        server.setHandler(servletHandler)
        //server.setThreadPool(threadPool)
        server.start()
        // Wait until the server is running
        while (!server.isRunning)
          Thread.sleep(1000)
        // If no exception was thrown, then store this port and use it
        val port: Int = defaultPort + attempt
        logInfo(s"Successfully started CRSServer on port $port")

        // Register the servlets
        servletHandler.addServletWithMapping(new ServletHolder("insert", insertServlet), "/crs")
        servletHandler.addServletWithMapping(new ServletHolder("retrieve", retrieveServlet), "/crs/*")
        started = true
        return port
      } catch {
        case e: IOException =>
          logWarning(s"Could not start CRSServer on port ${defaultPort + attempt}", e)
          attempt += 1
      }
    } while (!started && attempt < 10)
    logWarning("Could not start CRSServer on all attempted ports. Failing ...")
    -1
  }

  /**
   * Start the CRS server and store the information in the configuration of the given [[SparkContext]]
   * @param sc the spark context to store the server information in
   * @return `true` if the server starts correctly. `false` otherwise.
   */
  def startServer(sc: SparkContext): Boolean = {
    // Start the server and get the port
    val crsServerPort = CRSServer.startServer(sc.getConf.getInt(CRSServer.CRSServerPort, DefaultPort))
    // Port -1 indicates that the server did not start correctly
    if (crsServerPort == -1)
      return false
    // Close the CRSServer on application exit
    stopOnExit(sc)
    // Set the CRSServer information as a local property to be accessed by the workers
    sc.setLocalProperty(CRSServer.CRSServerPort, crsServerPort.toString)
    sc.setLocalProperty(CRSServer.CRSServerHost, java.net.InetAddress.getLocalHost.getHostAddress)
    true
  }

  def startServer(jsc: JavaSparkContext): Boolean = startServer(jsc.sc)

  /**
   * Register a listener in the given [[SparkContext]] to stop the CRSServer on application exit.
   * @param sc the [[SparkContext]] to register the closing listener to
   */
  private def stopOnExit(sc: SparkContext): Unit = {
    // Close the CRSServer upon application end
    sc.addSparkListener(new SparkListener() {
      override def onApplicationEnd(applicationEnd: SparkListenerApplicationEnd): Unit =
        CRSServer.stopServer()
    })
  }

  def stopServer(wait: Boolean = false): Unit = {
    if (server == null)
      return
    server.stop()
    if (wait)
      server.join()
  }

  /**
   * Inserts a CRS given in the HTTP request into the cache. If the CRS is already in the cache, its key (SRID) is
   * returned in the response. If it does not exist, it is inserted in the cache with a new ID and the ID is returned.
   */
  val insertServlet: javax.servlet.Servlet = new javax.servlet.http.HttpServlet() {
    override def doPost(request: HttpServletRequest, response: HttpServletResponse): Unit = {
      logInfo("Handling insert")
      // First, parse the CRS from the request
      val length = request.getContentLength
      val crsArray = new Array[Byte](length)
      val inputStream = request.getInputStream
      var offset = 0
      while (offset < length)
        offset += inputStream.read(crsArray, offset, length - offset)
      inputStream.close()
      try {
        val in = new ObjectInputStream(new ByteArrayInputStream(crsArray))
        val crs = in.readObject().asInstanceOf[CoordinateReferenceSystem]
        logDebug(s"Handling insert of crs '$crs'")
        // Use the local cache to retrieve or assigned an SRID for this CRS
        // This whole method is synchronized to avoid multiple inserts of CRS into the cache
        var srid = 0
        crsCache.synchronized {
          val sridO = crsCache.get(crs)
          if (sridO.isDefined) srid = sridO.get else {
            // Not found, try to get an EPSG code using full search
            // Note: We do not try the quick search since it must have been tried at the client already
            val epsgCode = CRS.lookupEpsgCode(crs, true)
            if (epsgCode != null) srid = epsgCode else {
              // No EPSG code, assign it a custom negative SRID
              srid = if (crsCache.isEmpty) -1 else (crsCache.values.min min 0) - 1
            }
            assert(srid != 0, "A valid CRS should never get assigned SRID=0")
            logDebug(s"New SRID $srid assigned to CRS '${crs.toWKT}'")
            // Update the two caches
            crsCache.put(crs, srid)
            sridCache.put(srid, crs)
          }
        }
        // Return the SRID in the response
        request.asInstanceOf[Request].setHandled(true)
        response.setStatus(http.HttpServletResponse.SC_OK)
        response.setContentType("text/plain")
        val writer = response.getWriter
        writer.print(srid)
        writer.close()
      } catch {
        case e: Exception =>
          e.printStackTrace()
          response.setStatus(http.HttpServletResponse.SC_INTERNAL_SERVER_ERROR)
          response.setContentType("text/plain")
          val writer = response.getWriter
          e.printStackTrace(writer)
          writer.close()
        }
      }
    }

  /**
   * Handles the service of retrieving a CRS for an assigned SRID.
   * This method is expected to be called only for a negative SRID that corresponds to a custom CRS.
   * Therefore, it is expected that we will always find it in the cache since it must have been previously
   * assigned at the server.
   * If found, the CRS is serialized using Java format and returned.
   * In case it is not found, a 404 response is returned.
   */
  val retrieveServlet: javax.servlet.Servlet = new javax.servlet.http.HttpServlet() {
    override def doGet(request: HttpServletRequest, response: HttpServletResponse): Unit = {
      val path = request.getPathInfo
      val srid: Int = path.substring(1).toInt
      logInfo(s"Handling retrieve of SRID: $srid")
      val crsO = sridCache.get(srid)
      if (crsO.isEmpty) {
        // Not found
        logInfo(s"SRID: $srid not found!")
        request.asInstanceOf[Request].setHandled(true)
        response.setStatus(http.HttpServletResponse.SC_NOT_FOUND)
      } else {
        // Found! Serialize to the output in a binary format
        request.asInstanceOf[Request].setHandled(true)
        response.setContentType("application/octet-stream")
        response.setStatus(http.HttpServletResponse.SC_OK)
        val responseOut = new ObjectOutputStream(response.getOutputStream)
        responseOut.writeObject(crsO.get)
        responseOut.close()
      }

    }
  }

  /**
   * Get an integer SRID that corresponds to the given CRS according to the following logic.
   *  1. If crs is null, return 0
   *  2. Search the local cache as the fastest method of known CRS.
   *  3. If not found in cache, look up the the EPSG database to find an SRID, cache, and return it.
   *  4a. If the server is running, contact the server to get the SRID
   *  4b. If the server is not running, assign a custom negative SRID and cache it
   * @param crs the CRS to find an SRID for
   * @return a unique SRID that identifies the given CRS
   */
  def crsToSRID(crs: CoordinateReferenceSystem) : Int = {
    var assignedSRID: Int = 0
    try {
      // 1. If crs is null, return 0
      if (crs == null)
        return 0
      // 2. Search the local cache as the fastest method of known CRS.
      val sridO = crsCache.get(crs)
      if (sridO.isDefined) {
        assignedSRID = sridO.get
      } else {
        // CRS not found in cache
        // 3. Try to search for an EPSG code to avoid assigning a custom ID
        val epsgCode: Integer = CRS.lookupEpsgCode(crs, false)
        // If not found even with a full scan, must assign a custom negative ID
        if (epsgCode != null) {
          assignedSRID = epsgCode
        } else if (isServerRunning) {
          // 4a. Not found, contact the server to retrieve or assign a new SRID
          val url: String = s"http://$serverAddress:$serverPort/crs"
          logInfo(s"Calling POST '$url'")
          val httpClient: HttpClient = HttpClients.createDefault()
          val httpPost: HttpPost = new HttpPost(url)
          // Serialize the CRS to send to the server
          // Note: For some rare cases, when we serialized the CRS in WKT format, the server couldn't parse it back
          val baos = new ByteArrayOutputStream()
          val out = new ObjectOutputStream(baos)
          out.writeObject(crs)
          out.close()
          httpPost.setEntity(new ByteArrayEntity(baos.toByteArray))
          val response = httpClient.execute(httpPost)
          val responseEntity: HttpEntity = response.getEntity

          if (responseEntity != null && responseEntity.getContentLength > 0) {
            val buffer = new Array[Byte](responseEntity.getContentLength.toInt)
            val inputStream = responseEntity.getContent
            var offset = 0
            while (offset < buffer.length) {
              offset += inputStream.read(buffer, offset, buffer.length - offset)
            }
            inputStream.close()
            assignedSRID = new String(buffer).toInt
          } else {
            logError(s"Did not receive a proper response for $url for CRS '$crs''")
          }
        } else {
          // 4b. Server not running, assign a custom negative SRID locally
          crsCache.synchronized {
            assignedSRID = if (crsCache.isEmpty) -1 else ((crsCache.values.min min 0) - 1)
            // Insert in both caches
            crsCache.put(crs, assignedSRID)
            sridCache.put(assignedSRID, crs)
          }
        }
        if (!sridCache.contains(assignedSRID)) {
          // Handle all cases except 4b which already updates the cache
          crsCache.synchronized {
            if (assignedSRID != 0) {
              // Insert in both caches
              crsCache.put(crs, assignedSRID)
              sridCache.put(assignedSRID, crs)
            }
          }
        }
      }
    } catch {
      case e: IOException =>
        throw new RuntimeException("s\"Did not receive a response for CRS '$crs'", e)
        //logError(s"Did not receive a response for CRS '$crs'. Returning 0 to signal error")
    }
    assignedSRID
  }

  /**
   * Convert the given SRID to CRS according to the following logic.
   *  1. If the SRID is zero, it indicates an invalid SRID and `null` is returned.
   *  2. It searches the local cache and retrieves the SRID.
   *  3a. If SRID is positive, use it as an EPSG, retrieve the CRS, cache and return it.
   *  3b. If SRID is negative, contact the server, retrieve the CRS, cache and return it.
   * @param srid the SRID that needs to be converted to a CRS
   * @return the converted CRS.
   */
  def sridToCRS(srid: Int): CoordinateReferenceSystem = {
    try {
      // 1. Zero always indicates an invalid or unset SRID that always return null
      if (srid == 0)
        return null
      logDebug(s"Finding CRS for SRID:$srid")
      var crs: CoordinateReferenceSystem = null
      // 2. Try to the local cache since it is faster than all other methods
      val crsO = sridCache.get(srid)
      if (crsO.isDefined) {
        crs = crsO.get
      } else if (srid > 0) {
        // 3a. Not found and positive, try it as an EPSG code.
        crs = CRS.decode(s"EPSG:$srid", true)
      } else {
        // 3b. Not found and negative, contact the server
        // Request from the server
        val path: String = s"http://$serverAddress:$serverPort/crs/$srid"
        val url = new URL(path)
        logInfo(s"Calling GET '$url'")
        val inputStream = url.openStream()
        val rawCRS = new ByteArrayOutputStream()
        val buffer = new Array[Byte](4096)
        var done: Boolean = false
        while (!done) {
          val bytesRead = inputStream.read(buffer)
          if (bytesRead > 0) {
            rawCRS.write(buffer, 0, bytesRead)
          } else {
            done = true
          }
        }
        inputStream.close()
        rawCRS.close()
        val parsedBack = new ObjectInputStream(new ByteArrayInputStream(rawCRS.toByteArray))
        crs = parsedBack.readObject().asInstanceOf[CoordinateReferenceSystem]
      }
      if (crs != null) {
        // Add to local caches
        crsCache.put(crs, srid)
        sridCache.put(srid, crs)
      }
      crs
    } catch {
      case e: Exception => throw new RuntimeException(s"Error retrieving the CRS of SRID $srid", e)
    }
  }

  /**
   * Implicit conversion from Hadoop Configuration to Spark Configuration to allow the methods above to accept
   * Hadoop Configuration
   * @param hadoopConf
   * @return
   */
  implicit def sparkConfFromHadoopConf(hadoopConf: Configuration): SparkConf = {
    val sparkConf = new SparkConf()
    val hadoopConfIterator = hadoopConf.iterator()
    while (hadoopConfIterator.hasNext) {
      val hadoopConfEntry = hadoopConfIterator.next()
      sparkConf.set(hadoopConfEntry.getKey, hadoopConfEntry.getValue, true)
    }
    sparkConf
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy