Skip to content

Commit

Permalink
TheHive-Project#864 Add authentication method in authContext
Browse files Browse the repository at this point in the history
  • Loading branch information
To-om committed Feb 5, 2019
1 parent 0bcaf1c commit 2233d5c
Show file tree
Hide file tree
Showing 5 changed files with 13 additions and 13 deletions.
2 changes: 1 addition & 1 deletion project/Dependencies.scala
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ object Dependencies {

val reflections = "org.reflections" % "reflections" % "0.9.11"
val zip4j = "net.lingala.zip4j" % "zip4j" % "1.3.2"
val elastic4play = "org.thehive-project" %% "elastic4play" % "1.7.1"
val elastic4play = "org.thehive-project" %% "elastic4play" % "1.7.3-SNAPSHOT"
val akkaCluster = "com.typesafe.akka" %% "akka-cluster" % "2.5.11"
val akkaClusterTools = "com.typesafe.akka" %% "akka-cluster-tools" % "2.5.11"
}
Expand Down
2 changes: 1 addition & 1 deletion thehive-backend/app/services/KeyAuthSrv.scala
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ class KeyAuthSrv @Inject() (
.filter(_.key().contains(key))
.runWith(Sink.headOption)
.flatMap {
case Some(user) userSrv.getFromUser(request, user)
case Some(user) userSrv.getFromUser(request, user, name)
case None Future.failed(AuthenticationError("Authentication failure"))
}
}
Expand Down
2 changes: 1 addition & 1 deletion thehive-backend/app/services/LocalAuthSrv.scala
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ class LocalAuthSrv @Inject() (

override def authenticate(username: String, password: String)(implicit request: RequestHeader): Future[AuthContext] = {
userSrv.get(username).flatMap { user
if (doAuthenticate(user, password)) userSrv.getFromUser(request, user)
if (doAuthenticate(user, password)) userSrv.getFromUser(request, user, name)
else Future.failed(AuthenticationError("Authentication failure"))
}
}
Expand Down
4 changes: 2 additions & 2 deletions thehive-backend/app/services/OAuth2Srv.scala
Original file line number Diff line number Diff line change
Expand Up @@ -132,13 +132,13 @@ class OAuth2Srv(
userFields
val userId = userFields.getString("login").getOrElse("")
userSrv.get(userId).flatMap(user {
userSrv.getFromUser(request, user)
userSrv.getFromUser(request, user, name)
}).recoverWith {
case authErr: AuthorizationError Future.failed(authErr)
case _ if cfg.autocreate
userSrv.inInitAuthContext { implicit authContext
userSrv.create(userFields).flatMap(user {
userSrv.getFromUser(request, user)
userSrv.getFromUser(request, user, name)
})
}
}
Expand Down
16 changes: 8 additions & 8 deletions thehive-backend/app/services/UserSrv.scala
Original file line number Diff line number Diff line change
Expand Up @@ -29,16 +29,16 @@ class UserSrv @Inject() (
dbIndex: DBIndex,
implicit val ec: ExecutionContext) extends org.elastic4play.services.UserSrv {

private case class AuthContextImpl(userId: String, userName: String, requestId: String, roles: Seq[Role]) extends AuthContext
private case class AuthContextImpl(userId: String, userName: String, requestId: String, roles: Seq[Role], authMethod: String) extends AuthContext

override def getFromId(request: RequestHeader, userId: String): Future[AuthContext] = {
override def getFromId(request: RequestHeader, userId: String, authMethod: String): Future[AuthContext] = {
getSrv[UserModel, User](userModel, userId)
.flatMap { user getFromUser(request, user) }
.flatMap { user getFromUser(request, user, authMethod) }
}

override def getFromUser(request: RequestHeader, user: org.elastic4play.services.User): Future[AuthContext] = {
override def getFromUser(request: RequestHeader, user: org.elastic4play.services.User, authMethod: String): Future[AuthContext] = {
user match {
case u: User if u.status() == UserStatus.Ok Future.successful(AuthContextImpl(user.id, user.getUserName, Instance.getRequestId(request), user.getRoles))
case u: User if u.status() == UserStatus.Ok Future.successful(AuthContextImpl(user.id, user.getUserName, Instance.getRequestId(request), user.getRoles, authMethod))
case _ Future.failed(AuthorizationError("Your account is locked"))
}

Expand All @@ -47,19 +47,19 @@ class UserSrv @Inject() (
override def getInitialUser(request: RequestHeader): Future[AuthContext] =
dbIndex.getSize(userModel.modelName).map {
case size if size > 0 throw AuthenticationError(s"Use of initial user is forbidden because users exist in database")
case _ AuthContextImpl("init", "", Instance.getRequestId(request), Seq(Roles.admin, Roles.read, Roles.alert))
case _ AuthContextImpl("init", "", Instance.getRequestId(request), Seq(Roles.admin, Roles.read, Roles.alert), "init")
}

override def inInitAuthContext[A](block: AuthContext Future[A]): Future[A] = {
val authContext = AuthContextImpl("init", "", Instance.getInternalId, Seq(Roles.admin, Roles.read, Roles.alert))
val authContext = AuthContextImpl("init", "", Instance.getInternalId, Seq(Roles.admin, Roles.read, Roles.alert), "init")
eventSrv.publish(InternalRequestProcessStart(authContext.requestId))
block(authContext).andThen {
case _ eventSrv.publish(InternalRequestProcessEnd(authContext.requestId))
}
}

def extraAuthContext[A](block: AuthContext Future[A])(implicit authContext: AuthContext): Future[A] = {
val ac = AuthContextImpl(authContext.userId, authContext.userName, Instance.getInternalId, authContext.roles)
val ac = AuthContextImpl(authContext.userId, authContext.userName, Instance.getInternalId, authContext.roles, "init")
eventSrv.publish(InternalRequestProcessStart(ac.requestId))
block(ac).andThen {
case _ eventSrv.publish(InternalRequestProcessEnd(ac.requestId))
Expand Down

0 comments on commit 2233d5c

Please sign in to comment.