Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

OAuth2 single sign-on implementation (BE + FE) #430

Merged
merged 6 commits into from
Jan 23, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion project/Dependencies.scala
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ object Dependencies {

val reflections = "org.reflections" % "reflections" % "0.9.11"
val zip4j = "net.lingala.zip4j" % "zip4j" % "1.3.2"
val elastic4play = "org.cert-bdf" %% "elastic4play" % "1.4.2"
val elastic4play = "org.cert-bdf" %% "elastic4play" % "1.5.0-SNAPSHOT"
val akkaCluster = "com.typesafe.akka" %% "akka-cluster" % "2.5.6"
val akkaClusterTools = "com.typesafe.akka" %% "akka-cluster-tools" % "2.5.6"
}
}
33 changes: 26 additions & 7 deletions thehive-backend/app/controllers/AuthenticationCtrl.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,15 @@ package controllers

import javax.inject.{ Inject, Singleton }

import scala.concurrent.{ ExecutionContext, Future }

import play.api.mvc._

import models.UserStatus
import services.UserSrv

import org.elastic4play.controllers.{ Authenticated, Fields, FieldsBodyParser, Renderer }
import org.elastic4play.database.DBIndex
import org.elastic4play.services.AuthSrv
import org.elastic4play.{ AuthorizationError, Timed }
import org.elastic4play.{ AuthorizationError, OAuth2Redirect, Timed }
import play.api.mvc._
import services.UserSrv

import scala.concurrent.{ ExecutionContext, Future }

@Singleton
class AuthenticationCtrl @Inject() (
Expand Down Expand Up @@ -42,6 +40,27 @@ class AuthenticationCtrl @Inject() (
}
}

@Timed
def ssoLogin: Action[AnyContent] = Action.async { implicit request ⇒
dbIndex.getIndexStatus.flatMap {
case false ⇒ Future.successful(Results.Status(520))
case _ ⇒
(for {
authContext ← authSrv.authenticate()
user ← userSrv.get(authContext.userId)
} yield {
if (user.status() == UserStatus.Ok)
authenticated.setSessingUser(Ok, authContext)
else
throw AuthorizationError("Your account is locked")
}) recover {
// A bit of a hack with the status code, so that Angular doesn't reject the origin
case OAuth2Redirect(redirectUrl, qp) ⇒ Redirect(redirectUrl, qp, status = OK)
case e ⇒ throw e
}
}
}

@Timed
def logout = Action {
Ok.withNewSession
Expand Down
8 changes: 3 additions & 5 deletions thehive-backend/app/controllers/StatusCtrl.scala
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,13 @@ import javax.inject.{ Inject, Singleton }
import scala.collection.immutable
import scala.concurrent.{ ExecutionContext, Future }
import scala.util.Try

import play.api.Configuration
import play.api.libs.json.{ JsObject, JsString, Json }
import play.api.libs.json.{ JsBoolean, JsObject, JsString, Json }
import play.api.libs.json.Json.toJsFieldJsValueWrapper
import play.api.mvc.{ AbstractController, Action, AnyContent, ControllerComponents }

import com.sksamuel.elastic4s.ElasticDsl
import connectors.Connector
import models.HealthStatus

import org.elastic4play.Timed
import org.elastic4play.database.DBIndex
import org.elastic4play.services.AuthSrv
Expand Down Expand Up @@ -51,7 +48,8 @@ class StatusCtrl @Inject() (
case multiAuthSrv: MultiAuthSrv ⇒ multiAuthSrv.authProviders.map { a ⇒ JsString(a.name) }
case _ ⇒ JsString(authSrv.name)
}),
"capabilities" → authSrv.capabilities.map(c ⇒ JsString(c.toString)))))
"capabilities" → authSrv.capabilities.map(c ⇒ JsString(c.toString)),
"ssoAutoLogin" -> JsBoolean(configuration.getOptional[Boolean]("auth.sso.autologin").getOrElse(false)))))
}
}

Expand Down
55 changes: 30 additions & 25 deletions thehive-backend/app/controllers/StreamCtrl.scala
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,10 @@ import play.api.libs.json.Json.toJsFieldJsValueWrapper
import play.api.mvc._
import play.api.{ Configuration, Logger }

import akka.actor.{ ActorSystem, Props }
import akka.pattern.ask
import akka.actor.{ ActorIdentity, ActorSystem, Identify, Props }
import akka.cluster.pubsub.DistributedPubSub
import akka.cluster.pubsub.DistributedPubSubMediator.{ Put, Send }
import akka.pattern.{ AskTimeoutException, ask }
import akka.util.Timeout
import models.Roles
import services.StreamActor
Expand All @@ -28,8 +30,6 @@ import org.elastic4play.Timed
class StreamCtrl(
cacheExpiration: FiniteDuration,
refresh: FiniteDuration,
nextItemMaxWait: FiniteDuration,
globalMaxWait: FiniteDuration,
authenticated: Authenticated,
renderer: Renderer,
eventSrv: EventSrv,
Expand All @@ -52,8 +52,6 @@ class StreamCtrl(
this(
configuration.getMillis("stream.longpolling.cache").millis,
configuration.getMillis("stream.longpolling.refresh").millis,
configuration.getMillis("stream.longpolling.nextItemMaxWait").millis,
configuration.getMillis("stream.longpolling.globalMaxWait").millis,
authenticated,
renderer,
eventSrv,
Expand All @@ -62,38 +60,36 @@ class StreamCtrl(
components,
system,
ec)
private[StreamCtrl] lazy val logger = Logger(getClass)

private val streamLength = 20
private lazy val logger = Logger(getClass)
private val mediator = DistributedPubSub(system).mediator
private val alphanumeric: immutable.IndexedSeq[Char] = ('a' to 'z') ++ ('A' to 'Z') ++ ('0' to '9')
private def generateStreamId() = Seq.fill(streamLength)(alphanumeric(Random.nextInt(alphanumeric.size))).mkString
private def isValidStreamId(streamId: String): Boolean = streamId.length == streamLength && streamId.forall(alphanumeric.contains)

/**
* Create a new stream entry with the event head
*/
@Timed("controllers.StreamCtrl.create")
def create: Action[AnyContent] = authenticated(Roles.read) {
val id = generateStreamId()
system.actorOf(Props(
val streamActor = system.actorOf(Props(
classOf[StreamActor],
cacheExpiration,
refresh,
nextItemMaxWait,
globalMaxWait,
eventSrv,
auxSrv), s"stream-$id")
refresh), s"stream-$id")
logger.debug(s"Register stream actor $streamActor")
mediator ! Put(streamActor)
Ok(id)
}

val alphanumeric: immutable.IndexedSeq[Char] = ('a' to 'z') ++ ('A' to 'Z') ++ ('0' to '9')
private[controllers] def generateStreamId() = Seq.fill(10)(alphanumeric(Random.nextInt(alphanumeric.size))).mkString
private[controllers] def isValidStreamId(streamId: String): Boolean = {
streamId.length == 10 && streamId.forall(alphanumeric.contains)
}

/**
* Get events linked to the identified stream entry
* This call waits up to "refresh", if there is no event, return empty response
*/
@Timed("controllers.StreamCtrl.get")
def get(id: String): Action[AnyContent] = Action.async { implicit request ⇒
implicit val timeout: Timeout = Timeout(refresh + globalMaxWait + 1.second)
implicit val timeout: Timeout = Timeout(refresh + 1.second)

if (!isValidStreamId(id)) {
Future.successful(BadRequest("Invalid stream id"))
Expand All @@ -105,12 +101,21 @@ class StreamCtrl(
case _ ⇒ Future.successful(OK)
}

futureStatus.flatMap { status ⇒
(system.actorSelection(s"/user/stream-$id") ? StreamActor.GetOperations) map {
case StreamMessages(operations) ⇒ renderer.toOutput(status, operations)
case m ⇒ InternalServerError(s"Unexpected message : $m (${m.getClass})")
// Check if stream actor exists
mediator.ask(Send(s"/user/stream-$id", Identify(1), localAffinity = false))(Timeout(2.seconds))
.flatMap {
case ActorIdentity(1, Some(_)) ⇒
futureStatus.flatMap { status ⇒
(mediator ? Send(s"/user/stream-$id", StreamActor.GetOperations, localAffinity = false)) map {
case StreamMessages(operations) ⇒ renderer.toOutput(status, operations)
case m ⇒ InternalServerError(s"Unexpected message : $m (${m.getClass})")
}
}
case _ ⇒ Future.successful(renderer.toOutput(NOT_FOUND, Json.obj("type" → "StreamNotFound", "message" → s"Stream $id doesn't exist")))
}
.recover {
case _: AskTimeoutException ⇒ renderer.toOutput(NOT_FOUND, Json.obj("type" → "StreamNotFound", "message" → s"Stream $id doesn't exist"))
}
}
}
}

Expand Down
12 changes: 11 additions & 1 deletion thehive-backend/app/global/TheHive.scala
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import services._
import org.elastic4play.models.BaseModelDef
import org.elastic4play.services.auth.MultiAuthSrv
import org.elastic4play.services.{ AuthSrv, MigrationOperations }
import services.mappers.{ MultiUserMapperSrv, UserMapper }

class TheHive(
environment: Environment,
Expand All @@ -32,6 +33,7 @@ class TheHive(
val modelBindings = ScalaMultibinder.newSetBinder[BaseModelDef](binder)
val auditedModelBindings = ScalaMultibinder.newSetBinder[AuditedModel](binder)
val authBindings = ScalaMultibinder.newSetBinder[AuthSrv](binder)
val ssoMapperBindings = ScalaMultibinder.newSetBinder[UserMapper](binder)

val reflectionClasses = new Reflections(new ConfigurationBuilder()
.forPackages("org.elastic4play")
Expand Down Expand Up @@ -64,11 +66,19 @@ class TheHive(
authBindings.addBinding.to(authSrvClass)
}

reflectionClasses
.getSubTypesOf(classOf[UserMapper])
.asScala
.filterNot(c ⇒ java.lang.reflect.Modifier.isAbstract(c.getModifiers) || c.isMemberClass)
.filterNot(c ⇒ c == classOf[MultiUserMapperSrv])
.foreach(mapperCls ⇒ ssoMapperBindings.addBinding.to(mapperCls))

bind[MigrationOperations].to[Migration]
bind[AuthSrv].to[TheHiveAuthSrv]
bind[UserMapper].to[MultiUserMapperSrv]

bindActor[AuditActor]("AuditActor")
bindActor[DeadLetterMonitoringActor]("DeadLetterMonitoringActor")
bindActor[LocalStreamActor]("localStreamActor")

if (environment.mode == Mode.Prod)
bind[AssetCtrl].to[AssetCtrlProd]
Expand Down
153 changes: 153 additions & 0 deletions thehive-backend/app/services/OAuth2Srv.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
package services

import javax.inject.{ Inject, Singleton }

import akka.stream.Materializer
import org.elastic4play.services.{ AuthContext, AuthSrv }
import org.elastic4play.{ AuthenticationError, AuthorizationError, OAuth2Redirect }
import play.api.http.Status
import play.api.libs.json.{ JsObject, JsValue }
import play.api.libs.ws.WSClient
import play.api.mvc.RequestHeader
import play.api.{ Configuration, Logger }
import services.mappers.UserMapper

import scala.concurrent.{ ExecutionContext, Future }

case class OAuth2Config(
clientId: Option[String] = None,
clientSecret: String,
redirectUri: String,
responseType: String,
grantType: String,
authorizationUrl: String,
tokenUrl: String,
userUrl: String,
scope: String,
autocreate: Boolean)

object OAuth2Config {
def apply(configuration: Configuration): OAuth2Config = {
(for {
clientId ← configuration.getOptional[String]("auth.oauth2.clientId")
clientSecret ← configuration.getOptional[String]("auth.oauth2.clientSecret")
redirectUri ← configuration.getOptional[String]("auth.oauth2.redirectUri")
responseType ← configuration.getOptional[String]("auth.oauth2.responseType")
grantType ← configuration.getOptional[String]("auth.oauth2.grantType")
authorizationUrl ← configuration.getOptional[String]("auth.oauth2.authorizationUrl")
userUrl ← configuration.getOptional[String]("auth.oauth2.userUrl")
tokenUrl ← configuration.getOptional[String]("auth.oauth2.tokenUrl")
scope ← configuration.getOptional[String]("auth.oauth2.scope")
autocreate ← configuration.getOptional[Boolean]("auth.sso.autocreate").orElse(Some(false))
} yield OAuth2Config(Some(clientId), clientSecret, redirectUri, responseType, grantType, authorizationUrl, tokenUrl, userUrl, scope, autocreate))
.getOrElse(OAuth2Config(tokenUrl = "", clientSecret = "", redirectUri = "", responseType = "", grantType = "", authorizationUrl = "", userUrl = "", scope = "", autocreate = false))
}
}

@Singleton
class OAuth2Srv(
ws: WSClient,
userSrv: UserSrv,
ssoMapper: UserMapper,
oauth2Config: OAuth2Config,
implicit val ec: ExecutionContext,
implicit val mat: Materializer)
extends AuthSrv {

@Inject() def this(
ws: WSClient,
ssoMapper: UserMapper,
userSrv: UserSrv,
configuration: Configuration,
ec: ExecutionContext,
mat: Materializer) = this(
ws,
userSrv,
ssoMapper,
OAuth2Config(configuration),
ec,
mat)

override val name: String = "oauth2"
private val logger = Logger(getClass)

val Oauth2TokenQueryString = "code"

override def authenticate()(implicit request: RequestHeader): Future[AuthContext] = {
oauth2Config.clientId
.fold[Future[AuthContext]](Future.failed(AuthenticationError("OAuth2 not configured properly"))) {
clientId ⇒
request.queryString
.get(Oauth2TokenQueryString)
.flatMap(_.headOption)
.fold(createOauth2Redirect(clientId)) { code ⇒
getAuthTokenAndAuthenticate(clientId, code)
}
}
}

private def getAuthTokenAndAuthenticate(clientId: String, code: String)(implicit request: RequestHeader): Future[AuthContext] = {
logger.debug("Getting user token with the code from the response!")
ws.url(oauth2Config.tokenUrl)
.post(Map(
"code" -> code,
"grant_type" -> oauth2Config.grantType,
"client_secret" -> oauth2Config.clientSecret,
"redirect_uri" -> oauth2Config.redirectUri,
"client_id" -> clientId))
.recoverWith {
case error ⇒
logger.error(s"Token verification failure", error)
Future.failed(AuthenticationError("Token verification failure"))
}
.flatMap { r ⇒
r.status match {
case Status.OK ⇒
val accessToken = (r.json \ "access_token").asOpt[String].getOrElse("")
val authHeader = "Authorization" -> s"bearer $accessToken"
ws.url(oauth2Config.userUrl)
.addHttpHeaders(authHeader)
.get().flatMap { userResponse ⇒
if (userResponse.status != Status.OK) {
Future.failed(AuthenticationError(s"unexpected response from server: ${userResponse.status} ${userResponse.body}"))
}
else {
val response = userResponse.json.asInstanceOf[JsObject]
getOrCreateUser(response, authHeader)
}
}
case _ ⇒
logger.error(s"unexpected response from server: ${r.status} ${r.body}")
Future.failed(AuthenticationError("unexpected response from server"))
}
}
}

private def getOrCreateUser(response: JsValue, authHeader: (String, String))(implicit request: RequestHeader): Future[AuthContext] = {
ssoMapper.getUserFields(response, Some(authHeader)).flatMap {
userFields ⇒
val userId = userFields.getString("login").getOrElse("")
userSrv.get(userId).flatMap(user ⇒ {
userSrv.getFromUser(request, user)
}).recoverWith {
case authErr: AuthorizationError ⇒ Future.failed(authErr)
case _ if oauth2Config.autocreate ⇒
userSrv.inInitAuthContext { implicit authContext ⇒
userSrv.create(userFields).flatMap(user ⇒ {
userSrv.getFromUser(request, user)
})
}
}
}
}

private def createOauth2Redirect(clientId: String): Future[AuthContext] = {
val queryStringParams = Map[String, Seq[String]](
"scope" -> Seq(oauth2Config.scope),
"response_type" -> Seq(oauth2Config.responseType),
"redirect_uri" -> Seq(oauth2Config.redirectUri),
"client_id" -> Seq(clientId))
Future.failed(OAuth2Redirect(oauth2Config.authorizationUrl, queryStringParams))
}
}

Loading