All Downloads are FREE. Search and download functionalities are using the official Maven repository.

fr.iscpif.gridscale.dirac.DIRACJobService.scala Maven / Gradle / Ivy

The newest version!
/*
 * Copyright (C) 05/06/13 Romain Reuillon
 *
 * This program is free software: you can redistribute it and/or modify
 * it under the terms of the GNU Affero General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU Affero General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program.  If not, see <http://www.gnu.org/licenses/>.
 */

package fr.iscpif.gridscale.dirac

import java.io.{ BufferedOutputStream, FileOutputStream, BufferedInputStream, InputStream }
import java.net.{ URI, URL }
import fr.iscpif.gridscale.cache.SingleValueCache
import fr.iscpif.gridscale.tools.DefaultTimeout
import org.apache.commons.compress.archivers.tar.TarArchiveInputStream
import org.apache.http.client.protocol.HttpClientContext
import org.apache.http.client.utils.URIBuilder
import org.apache.http.config.{ RegistryBuilder }
import org.apache.http.conn.socket.ConnectionSocketFactory
import org.apache.http.entity.mime.MultipartEntityBuilder
import org.apache.http.{ HttpHost, HttpRequest }
import org.apache.http.client.config.RequestConfig
import org.apache.http.client.methods._
import org.apache.http.impl.client.{ HttpClients }
import org.apache.http.impl.conn.PoolingHttpClientConnectionManager
import spray.json._
import DefaultJsonProtocol._
import scala.sys.process.BasicIO
import scala.io.Source
import fr.iscpif.gridscale.jobservice._
import concurrent.duration._

/**
 * Job service backed by the REST interface of a DIRAC server.
 *
 * Implementors provide the base `service` URL and the DIRAC `group`; the other
 * settings have overridable defaults. An access token is requested from
 * `auth2Auth`, cached in `tokenCache`, and passed as the `access_token` query
 * parameter of every call made on the `jobs` resource.
 */
trait DIRACJobService extends JobService with DefaultTimeout {

  type A = P12HTTPSAuthentication
  type J = String
  type D = DIRACJobDescription

  /** Access token as returned by the token endpoint; `tokenCache` reads `expires_in` as seconds. */
  case class Token(token: String, expires_in: Long)

  private implicit def strToURL(s: String): URL = new URL(s)

  /** Base URL of the service; the token and job URLs are derived from it. */
  def service: String

  /** Value sent as the `group` parameter when requesting a token. */
  def group: String

  /** Value sent as the `setup` parameter when requesting a token. */
  def setup = "Dirac-Production"

  /** URL of the token endpoint. */
  def auth2Auth = service + "/oauth2/token"

  /** URL of the jobs resource. */
  def jobs = service + "/jobs"

  /** Margin subtracted from the token lifetime so the cached token is renewed before it expires. */
  def tokenExpirationMargin = 10 -> MINUTES

  /** Upper bound used both for the whole connection pool and for each route. */
  def maxConnections = 20

  /** Connection pool whose "https" scheme uses the socket factory of `credential`. */
  @transient lazy val pool = {
    val registry = RegistryBuilder.create[ConnectionSocketFactory]().register("https", credential.factory).build()
    val pool = new PoolingHttpClientConnectionManager(registry)
    pool.setMaxTotal(maxConnections)
    pool.setDefaultMaxPerRoute(maxConnections)
    pool
  }

  @transient lazy val httpContext = HttpClientContext.create

  /** Request configuration applying `timeout` to both the socket and the connection phase. */
  def requestConfig = {
    val millis = timeout.toMillis.toInt
    RequestConfig.custom()
      .setSocketTimeout(millis)
      .setConnectTimeout(millis)
      .build()
  }

  /** Target host (host, port and scheme) extracted from `service`. */
  def httpHost: HttpHost = {
    val uri = new URI(service)
    new HttpHost(uri.getHost, uri.getPort, uri.getScheme)
  }

  /**
   * Executes `request` against `httpHost` and hands the response body to `f`.
   *
   * Both the body stream and the response are closed once `f` has returned or
   * thrown. When the response carries no entity, `f` receives an empty stream.
   *
   * @param request the HTTP request to execute; `requestConfig` is set on it first
   * @param f       consumer of the response body
   * @return the value computed by `f`
   */
  def requestContent[T](request: HttpRequestBase with HttpRequest)(f: InputStream ⇒ T): T = {
    // The client is deliberately not closed here: it is built on the shared `pool`.
    val client = HttpClients.custom()
      .setConnectionManager(pool)
      .build()

    request.setConfig(requestConfig)

    // Runs `body` on `resource` and closes it afterwards, whatever the outcome.
    def withResource[C <: java.io.Closeable, R](resource: C)(body: C ⇒ R): R =
      try body(resource)
      finally resource.close()

    withResource(client.execute(httpHost, request, httpContext)) { response ⇒
      // getEntity returns null for responses without a body.
      val content: InputStream =
        Option(response.getEntity)
          .map(_.getContent)
          .getOrElse(new java.io.ByteArrayInputStream(Array.empty[Byte]))
      withResource(content)(f)
    }
  }

  /**
   * Executes `request` and hands the whole response body, decoded as UTF-8, to `f`.
   *
   * @param request the HTTP request to execute
   * @param f       consumer of the response text
   * @return the value computed by `f`
   */
  def request[T](request: HttpRequestBase with HttpRequest)(f: String ⇒ T): T =
    requestContent(request) { is ⇒
      // Decode explicitly rather than with the platform default charset.
      f(Source.fromInputStream(is, "UTF-8").mkString)
    }

  /** Cache holding the current token; its lifetime is `expires_in` seconds minus `tokenExpirationMargin`. */
  @transient lazy val tokenCache =
    new SingleValueCache[Token] {
      def compute() = token
      def expiresIn(t: Token) = (t.expires_in, SECONDS) - tokenExpirationMargin
    }

  /**
   * Requests a new token from `auth2Auth` with the `client_credentials` grant
   * type, passing `group` and `setup`, and reads the "token" and "expires_in"
   * fields of the JSON answer.
   */
  def token = {
    val uri = new URIBuilder(auth2Auth)
      .setParameter("grant_type", "client_credentials")
      .setParameter("group", group)
      .setParameter("setup", setup)
      .build()

    val get = new HttpGet(uri)

    request(get) { r ⇒
      val fields = r.trim.parseJson.asJsObject.getFields("token", "expires_in")
      Token(fields(0).convertTo[String], fields(1).convertTo[Long])
    }
  }

  /**
   * Submits a job: the JSON manifest goes in the `manifest` query parameter and
   * every input sandbox file is attached as a multipart body part named after
   * the file.
   *
   * @param jobDescription description of the job to submit
   * @return the first element of the "jids" array of the answer, rendered as a string
   */
  def submit(jobDescription: D): String = {
    def files = {
      val builder = MultipartEntityBuilder.create()
      jobDescription.inputSandbox.foreach { f ⇒ builder.addBinaryBody(f.getName, f) }
      builder.build
    }

    val uri = new URIBuilder(jobs)
      .setParameter("access_token", tokenCache().token)
      .setParameter("manifest", jobDescription.toJSON)
      .build

    val post = new HttpPost(uri)
    post.setEntity(files)
    request(post) { r ⇒
      r.parseJson.asJsObject.getFields("jids").head.toJson.convertTo[JsArray].elements.head.toString
    }
  }

  /**
   * Queries the "status" field of a job and maps it to a job state.
   *
   * A status that is not listed below raises a `MatchError`.
   *
   * @param jobId identifier of the job
   */
  def state(jobId: J) = {
    val uri =
      new URIBuilder(jobs + "/" + jobId)
        .setParameter("access_token", tokenCache().token)
        .build

    val get = new HttpGet(uri)

    request(get) { r ⇒
      r.parseJson.asJsObject.getFields("status").head.toJson.convertTo[String] match {
        case "Received" | "Checking" | "Staging" | "Waiting" | "Matched" ⇒ Submitted
        case "Running" | "Completed" | "Stalled" ⇒ Running
        case "Done" ⇒ Done
        case "Killed" | "Deleted" | "Failed" ⇒ Failed
      }
    }
  }

  /**
   * Downloads the output sandbox of a job as a tar stream and writes each entry
   * whose name is a key of `desc.outputSandbox` to the associated destination.
   * Entries with any other name are skipped.
   *
   * @param desc  job description whose `outputSandbox` maps entry names to destinations
   * @param jobId identifier of the job
   */
  def downloadOutputSandbox(desc: D, jobId: J)(implicit credential: A) = {
    val outputSandboxMap = desc.outputSandbox.toMap

    val uri =
      new URIBuilder(jobs + "/" + jobId + "/outputsandbox")
        .setParameter("access_token", tokenCache().token)
        .build

    val get = new HttpGet(uri)

    requestContent(get) { str ⇒
      val tar = new TarArchiveInputStream(new BufferedInputStream(str))

      Iterator.continually(tar.getNextEntry)
        .takeWhile(_ != null)
        .filter { entry ⇒ outputSandboxMap.contains(entry.getName) }
        .foreach { entry ⇒
          val os = new BufferedOutputStream(new FileOutputStream(outputSandboxMap(entry.getName)))
          // The tar stream is positioned on the current entry, so this copies that entry only.
          try BasicIO.transferFully(tar, os)
          finally os.close()
        }
    }

  }

  /**
   * Sends a DELETE request for the job and returns the response body.
   *
   * @param jobId identifier of the job
   */
  def cancel(jobId: J) = {
    val uri =
      new URIBuilder(jobs + "/" + jobId)
        .setParameter("access_token", tokenCache().token)
        .build

    val delete = new HttpDelete(uri)

    request(delete) { identity }
  }

  /** Does nothing. */
  def purge(job: J) = {}
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy