Skip to content

Commit

Permalink
follow comment
Browse files Browse the repository at this point in the history
  • Loading branch information
AngersZhuuuu committed Oct 27, 2023
1 parent 2c7ae5d commit f6b6dcc
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ class RuleAuthorization(spark: SparkSession) extends Rule[LogicalPlan] {

object RuleAuthorization {

val KYUUBI_AUTHZ_TAG = TreeNodeTag[Boolean]("__KYUUBI_AUTHZ_TAG")
val KYUUBI_AUTHZ_TAG = TreeNodeTag[Unit]("__KYUUBI_AUTHZ_TAG")

private def checkPrivileges(spark: SparkSession, plan: LogicalPlan): LogicalPlan = {
val auditHandler = new SparkRangerAuditHandler
Expand Down Expand Up @@ -101,16 +101,16 @@ object RuleAuthorization {
plan match {
case _: PermanentViewMarker =>
plan.transformUp { case p =>
p.setTagValue(KYUUBI_AUTHZ_TAG, true)
p.setTagValue(KYUUBI_AUTHZ_TAG, ())
p
}
case _ =>
plan.setTagValue(KYUUBI_AUTHZ_TAG, true)
plan.setTagValue(KYUUBI_AUTHZ_TAG, ())
}
plan
}

private def isAuthChecked(plan: LogicalPlan): Boolean = {
plan.find(_.getTagValue(KYUUBI_AUTHZ_TAG).contains(true)).nonEmpty
plan.find(_.getTagValue(KYUUBI_AUTHZ_TAG).nonEmpty).nonEmpty
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -113,12 +113,12 @@ abstract class RangerSparkExtensionSuite extends AnyFunSuite
if (i == 1) {
assert(logicalPlan.getTagValue(KYUUBI_AUTHZ_TAG).isEmpty)
} else {
assert(logicalPlan.getTagValue(KYUUBI_AUTHZ_TAG).getOrElse(false))
assert(logicalPlan.getTagValue(KYUUBI_AUTHZ_TAG).nonEmpty)
}
rule.apply(logicalPlan)
}

assert(logicalPlan.getTagValue(KYUUBI_AUTHZ_TAG).getOrElse(false))
assert(logicalPlan.getTagValue(KYUUBI_AUTHZ_TAG).nonEmpty)
}

test("[KYUUBI #3226]: Another session should also check even if the plan is cached.") {
Expand All @@ -140,15 +140,15 @@ abstract class RangerSparkExtensionSuite extends AnyFunSuite
// session1: first query, should auth once.[LogicalRelation]
val df = sql(select)
val plan1 = df.queryExecution.optimizedPlan
assert(plan1.getTagValue(KYUUBI_AUTHZ_TAG).getOrElse(false))
assert(plan1.getTagValue(KYUUBI_AUTHZ_TAG).nonEmpty)

// cache
df.cache()

// session1: second query, should auth once.[InMemoryRelation]
// (don't need to check in again, but it's okay to check in once)
val plan2 = sql(select).queryExecution.optimizedPlan
assert(plan1 != plan2 && plan2.getTagValue(KYUUBI_AUTHZ_TAG).getOrElse(false))
assert(plan1 != plan2 && plan2.getTagValue(KYUUBI_AUTHZ_TAG).nonEmpty)

// session2: should auth once.
val otherSessionDf = spark.newSession().sql(select)
Expand All @@ -159,7 +159,7 @@ abstract class RangerSparkExtensionSuite extends AnyFunSuite
// make sure it use cache.
assert(plan3.isInstanceOf[InMemoryRelation])
// auth once only.
assert(plan3.getTagValue(KYUUBI_AUTHZ_TAG).getOrElse(false))
assert(plan3.getTagValue(KYUUBI_AUTHZ_TAG).nonEmpty)
})
}
}
Expand Down

0 comments on commit f6b6dcc

Please sign in to comment.