All downloads are free. Search and download functionality uses the official Maven repository.

com.gu.pandomainauth.service.OAuth.scala Maven / Gradle / Ivy

The newest version!
package com.gu.pandomainauth.service

import com.gu.pandomainauth.model.{AuthenticatedUser, OAuthSettings, User}
import play.api.libs.json.JsValue
import play.api.libs.ws.{WSClient, WSResponse}
import play.api.mvc.Results.Redirect
import play.api.mvc.{RequestHeader, Result}

import java.math.BigInteger
import java.security.SecureRandom
import java.util.concurrent.atomic.AtomicReference
import scala.concurrent.{ExecutionContext, Future, Promise}
import scala.language.postfixOps
import scala.util.Try


/**
 * Failure raised while talking to the OAuth provider.
 *
 * @param message   human-readable description; also exposed via `getMessage`
 * @param throwable optional underlying cause; defaults to `null` (no cause), following the
 *                  `java.lang.Exception(String, Throwable)` convention
 */
class OAuthException(
  val message: String,
  val throwable: Throwable = null
) extends Exception(message, throwable)

/**
 * OAuth 2.0 / OpenID Connect client.
 *
 * Provider endpoint URLs (authorization, token, userinfo) come from the discovery document
 * fetched from `config.discoveryDocumentUrl`.
 *
 * @param config      client id/secret, discovery-document URL and optional organisation domain
 * @param system      name of the calling system; recorded on every [[AuthenticatedUser]] built here
 * @param redirectUrl callback URL sent to the provider as `redirect_uri`
 */
class OAuth(config: OAuthSettings, system: String, redirectUrl: String)(implicit context: ExecutionContext, ws: WSClient) {

  /**
   * The current discovery-document fetch. The first fetch starts when the instance is
   * constructed; a fetch that has already failed is replaced on the next access.
   */
  private val discoveryDocumentHolder: AtomicReference[Future[DiscoveryDocument]] =
    new AtomicReference[Future[DiscoveryDocument]](fetchDiscoveryDocument())

  /** Issues an HTTP GET for the discovery document and parses its JSON body. */
  private def fetchDiscoveryDocument(): Future[DiscoveryDocument] =
    ws.url(config.discoveryDocumentUrl).get().map(response => DiscoveryDocument.fromJson(response.json))

  /**
   * The cached discovery document, starting a new fetch if the cached one has failed.
   *
   * `AtomicReference.updateAndGet` is deliberately not used: it may re-apply its update function
   * under contention and requires that function to be side-effect free, whereas starting a fetch
   * issues an HTTP request. Instead an uncompleted Promise is swapped in with `compareAndSet`,
   * so only the thread that wins the swap performs the fetch.
   */
  private def discoveryDocument: Future[DiscoveryDocument] = {
    val current = discoveryDocumentHolder.get()
    if (!current.value.exists(_.isFailure)) {
      current // still in flight, or already succeeded
    } else {
      val retry = Promise[DiscoveryDocument]()
      if (discoveryDocumentHolder.compareAndSet(current, retry.future)) {
        // Future.unit.flatMap turns a synchronous throw from fetchDiscoveryDocument() into a failed
        // future, so the promise is always completed and a later call can retry again.
        retry.completeWith(Future.unit.flatMap(_ => fetchDiscoveryDocument()))
        retry.future
      } else {
        // Another caller already replaced the failed future; share its attempt.
        discoveryDocumentHolder.get()
      }
    }
  }

  val random: SecureRandom = new SecureRandom()

  /** Random anti-forgery (`state`) token: 130 bits drawn from [[random]], rendered in base 32. */
  def generateAntiForgeryToken(): String = new BigInteger(130, random).toString(32)

  /**
   * Passes the JSON body of a successful provider response to `block`.
   *
   * @param r     response received from the provider
   * @param block handler applied to the parsed JSON body when the status is below 400
   * @return whatever `block` returns
   * @throws OAuthException if the status is 400 or above. The message uses the provider's `error`
   *                        object when the body is JSON containing one; otherwise it reports the
   *                        status and raw body. This includes bodies that are not JSON at all.
   */
  def oAuthResponse[T](r: WSResponse)(block: JsValue => T): T = {
    r.status match {
      case errorCode if errorCode >= 400 =>
        // try to get error if we received an error doc (Google does this). Error bodies are not
        // guaranteed to be JSON, so a parse failure must not mask the OAuth error.
        val providerError = Try(r.json).toOption.flatMap(json => (json \ "error").asOpt[Error])
        val failure = providerError match {
          case Some(e) =>
            new OAuthException(s"Error when calling OAuth provider: ${e.message}")
          case None =>
            new OAuthException(s"Unknown error when calling OAuth provider [status=$errorCode, body=${r.body}]")
        }
        throw failure
      case _ => block(r.json)
    }
  }

  /**
   * Builds a redirect to the provider's authorization endpoint (authorization-code flow,
   * scopes `openid email profile`).
   *
   * @param antiForgeryToken value sent as `state`; must be checked again on the callback
   * @param email            optional `login_hint` for the provider
   * @return the redirect, once the discovery document is available; `hd` is added when
   *         `config.organizationDomain` is set
   */
  def redirectToOAuthProvider(antiForgeryToken: String, email: Option[String] = None)
                      (implicit context: ExecutionContext): Future[Result] = {
    val queryString: Map[String, Seq[String]] = Map(
      "client_id" -> Seq(config.clientId),
      "response_type" -> Seq("code"),
      "scope" -> Seq("openid email profile"),
      "redirect_uri" -> Seq(redirectUrl),
      "state" -> Seq(antiForgeryToken)
    ) ++ email.map("login_hint" -> Seq(_)) ++ config.organizationDomain.map("hd" -> Seq(_))

    discoveryDocument.map(dd => Redirect(s"${dd.authorization_endpoint}", queryString))
  }

  /**
   * Handles the provider's callback. It exchanges the `code` query parameter at the token
   * endpoint, fetches the userinfo endpoint with the returned access token, and builds the
   * authenticated user.
   *
   * The user's email is taken from the ID token's claims when present, otherwise from userinfo.
   *
   * @param expectedAntiForgeryToken the `state` value issued when redirecting to the provider
   * @throws IllegalArgumentException synchronously, if the request's `state` does not contain
   *                                  `expectedAntiForgeryToken`
   * @return a future that fails with [[OAuthException]] if the callback has no authorization code
   *         or the provider answers with an error status
   */
  def validatedUserIdentity(expectedAntiForgeryToken: String)
                           (implicit request: RequestHeader, context: ExecutionContext, ws: WSClient): Future[AuthenticatedUser] = {
    if (!request.queryString.getOrElse("state", Nil).contains(expectedAntiForgeryToken)) {
      throw new IllegalArgumentException("The anti forgery token did not match")
    } else {
      discoveryDocument.flatMap { dd =>
        val code = request.queryString.get("code").filter(_.nonEmpty).getOrElse {
          // Per RFC 6749 §4.1.2.1 the provider may redirect back with an `error` parameter and no
          // code (e.g. consent declined); surface that instead of a bare "key not found".
          val reason = request.queryString.get("error").flatMap(_.headOption).getOrElse("none given")
          throw new OAuthException(s"OAuth callback did not include an authorization code [error=$reason]")
        }
        ws.url(dd.token_endpoint).post {
          Map(
            "code" -> code,
            "client_id" -> Seq(config.clientId),
            "client_secret" -> Seq(config.clientSecret),
            "redirect_uri" -> Seq(redirectUrl),
            "grant_type" -> Seq("authorization_code")
          )
        }.flatMap { tokenResponse =>
          oAuthResponse(tokenResponse) { tokenJson =>
            val token = Token.fromJson(tokenJson)
            val jwt = token.jwt
            ws.url(dd.userinfo_endpoint)
              .withHttpHeaders("Authorization" -> s"Bearer ${token.access_token}")
              .get().map { userInfoResponse =>
                oAuthResponse(userInfoResponse) { userInfoJson =>
                  val userInfo = UserInfo.fromJson(userInfoJson)
                  AuthenticatedUser(
                    user = User(
                      userInfo.given_name,
                      userInfo.family_name,
                      jwt.claims.email.getOrElse(userInfo.email),
                      userInfo.picture
                    ),
                    authenticatingSystem = system,
                    authenticatedIn = Set(system),
                    // JWT `exp` is in epoch seconds (RFC 7519); presumably AuthenticatedUser
                    // expects milliseconds — confirm against its definition.
                    jwt.claims.exp * 1000,
                    false
                  )
                }
              }
          }
        }
      }
    }
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy