All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.microsoft.ml.spark.io.http.PortForwarding.scala Maven / Gradle / Ivy

The newest version!
// Copyright (C) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License. See LICENSE in project root for information.

package com.microsoft.ml.spark.io.http

import java.io.File
import java.net.URI

import com.jcraft.jsch.{JSch, Session}
import org.apache.commons.io.IOUtils

object PortForwarding {

  lazy val Jsch = new JSch()

  def forwardPortToRemote(username: String,
                          sshHost: String,
                          sshPort: Int,
                          bindAddress: String,
                          remotePortStart: Int,
                          localHost: String,
                          localPort: Int,
                          keyDir: Option[String],
                          keySas: Option[String],
                          maxRetries: Int,
                          timeout: Int
                         ): (Session, Int) = {
    keyDir.foreach(kd =>
      new File(kd).listFiles().foreach(f =>
        try {
          Jsch.addIdentity(f.getAbsolutePath)
        } catch {
          case e: com.jcraft.jsch.JSchException =>
          case e: Exception => throw e
        }
      )
    )

    keySas.foreach { ks =>
      val privateKeyBytes = IOUtils.toByteArray(new URI(ks))
      Jsch.addIdentity("forwardingKey", privateKeyBytes, null, null) //scalastyle:ignore null
    }

    val session = Jsch.getSession(username, sshHost, sshPort)
    session.setConfig("StrictHostKeyChecking", "no")
    session.setTimeout(timeout)
    session.connect()
    var attempt = 0
    var foundPort: Option[Int] = None
    while (foundPort.isEmpty && attempt <= maxRetries) {
      try {
        session.setPortForwardingR(
          bindAddress, remotePortStart + attempt, localHost, localPort)
        foundPort = Some(remotePortStart + attempt)
      } catch {
        case e: Exception =>
          println(s"failed to forward port. Attempt: $attempt")
          attempt += 1
      }
    }
    if (foundPort.isEmpty) {
      throw new RuntimeException(s"Could not find open port between " +
        s"$remotePortStart and ${remotePortStart + maxRetries}")
    }
    println(s"forwarding to ${foundPort.get}")
    (session, foundPort.get)
  }

  def forwardPortToRemote(options: Map[String, String]): (Session, Int) = {
    forwardPortToRemote(
      options("forwarding.username"),
      options("forwarding.sshhost"),
      options.getOrElse("forwarding.sshport", "22").toInt,
      options.getOrElse("forwarding.bindaddress", "*"),
      options.get("forwarding.remoteportstart")
        .orElse(options.get("forwarding.localport")).get.toInt,
      options.getOrElse("forwarding.localhost", "0.0.0.0"),
      options("forwarding.localport").toInt,
      options.get("forwarding.keydir"),
      options.get("forwarding.keysas"),
      options.getOrElse("forwarding.maxretires", "50").toInt,
      options.getOrElse("forwarding.timeout", "20000").toInt
    )
  }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy