Skip to content

Commit

Permalink
Clean and simplify use of API
Browse files (browse the repository at this point in the history)
  • Loading branch information
seddonm1 committed Sep 20, 2019
1 parent 573f7b3 commit a3b4fd8
Show file tree
Hide file tree
Showing 4 changed files with 15 additions and 21 deletions.
18 changes: 8 additions & 10 deletions src/main/scala/ai/tripl/arc/ARC.scala
Original file line number Diff line number Diff line change
Expand Up @@ -404,17 +404,17 @@ object ARC {
def run(pipeline: ETLPipeline)
(implicit spark: SparkSession, logger: ai.tripl.arc.util.log.logger.Logger, arcContext: ARCContext): Option[DataFrame] = {

def before(index: Int, stages: List[PipelineStage]): Unit = {
def before(currentValue: PipelineStage, index: Int, stages: List[PipelineStage]): Unit = {
for (p <- arcContext.activeLifecyclePlugins) {
logger.trace().message(s"Executing before() on LifecyclePlugin: ${p.getClass.getName}")
p.before(index, stages)
p.before(currentValue, index, stages)
}
}

def after(currentValue: Option[DataFrame], index: Int, stages: List[PipelineStage]): Unit = {
def after(result: Option[DataFrame], currentValue: PipelineStage, index: Int, stages: List[PipelineStage]): Unit = {
for (p <- arcContext.activeLifecyclePlugins) {
logger.trace().message(s"Executing after on LifecyclePlugin: ${stages(index).getClass.getName}")
p.after(currentValue, index, stages)
p.after(result, currentValue, index, stages)
}
}

Expand All @@ -425,20 +425,18 @@ object ARC {
case head :: Nil =>
val stage = head._1
val index = head._2
val pipelineStages = stages.map(_._1)
before(index, pipelineStages)
before(stage, index, pipeline.stages)
val result = processStage(stage)
after(result, index, pipelineStages)
after(result, stage, index, pipeline.stages)
result

//currentValue[, index[, array]]
case head :: tail =>
val stage = head._1
val index = head._2
val pipelineStages = stages.map(_._1)
before(index, pipelineStages)
before(stage, index, pipeline.stages)
val result = processStage(stage)
after(result, index, pipelineStages)
after(result, stage, index, pipeline.stages)
runStages(tail)
}
}
Expand Down
4 changes: 2 additions & 2 deletions src/main/scala/ai/tripl/arc/api/API.scala
Original file line number Diff line number Diff line change
Expand Up @@ -258,9 +258,9 @@ object API {

def plugin: LifecyclePlugin

def before(index: Int, stages: List[PipelineStage])(implicit spark: SparkSession, logger: ai.tripl.arc.util.log.logger.Logger, arcContext: ARCContext)
def before(stage: PipelineStage, index: Int, stages: List[PipelineStage])(implicit spark: SparkSession, logger: ai.tripl.arc.util.log.logger.Logger, arcContext: ARCContext)

def after(currentValue: Option[DataFrame], index: Int, stages: List[PipelineStage])(implicit spark: SparkSession, logger: ai.tripl.arc.util.log.logger.Logger, arcContext: ARCContext)
def after(result: Option[DataFrame], stage: PipelineStage, index: Int, stages: List[PipelineStage])(implicit spark: SparkSession, logger: ai.tripl.arc.util.log.logger.Logger, arcContext: ARCContext)

}

Expand Down
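8 changes: 3 additions & 5 deletions (file path not captured in this page extract; the hunk below modifies DataFramePrinterInstance)
Original file line number Diff line number Diff line change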
Expand Up @@ -42,22 +42,20 @@ case class DataFramePrinterInstance(
truncate: Boolean
) extends LifecyclePluginInstance {

override def before(index: Int, stages: List[PipelineStage])(implicit spark: SparkSession, logger: ai.tripl.arc.util.log.logger.Logger, arcContext: ARCContext) {
val stage = stages(index)
override def before(stage: PipelineStage, index: Int, stages: List[PipelineStage])(implicit spark: SparkSession, logger: ai.tripl.arc.util.log.logger.Logger, arcContext: ARCContext) {
logger.trace()
.field("event", "before")
.field("stage", stage.name)
.log()
}

override def after(currentValue: Option[DataFrame], index: Int, stages: List[PipelineStage])(implicit spark: SparkSession, logger: ai.tripl.arc.util.log.logger.Logger, arcContext: ARCContext) {
val stage = stages(index)
override def after(result: Option[DataFrame], stage: PipelineStage, index: Int, stages: List[PipelineStage])(implicit spark: SparkSession, logger: ai.tripl.arc.util.log.logger.Logger, arcContext: ARCContext) {
logger.trace()
.field("event", "after")
.field("stage", stage.name)
.log()

currentValue match {
result match {
case Some(df) => df.show(numRows, truncate)
case None =>
}
Expand Down
6 changes: 2 additions & 4 deletions src/test/scala/ai/tripl/arc/plugins/TestLifecyclePlugin.scala
Original file line number Diff line number Diff line change
Expand Up @@ -42,16 +42,14 @@ case class TestLifecyclePluginInstance(
key: String
) extends LifecyclePluginInstance {

override def before(index: Int, stages: List[PipelineStage])(implicit spark: SparkSession, logger: ai.tripl.arc.util.log.logger.Logger, arcContext: ARCContext) {
override def before(stage: PipelineStage, index: Int, stages: List[PipelineStage])(implicit spark: SparkSession, logger: ai.tripl.arc.util.log.logger.Logger, arcContext: ARCContext) {
import spark.implicits._
val stage= stages(index)
val df = Seq((stage.name, "before", this.key)).toDF("stage","when","message")
df.createOrReplaceTempView("before")
}

override def after(currentValue: Option[DataFrame], index: Int, stages: List[PipelineStage])(implicit spark: SparkSession, logger: ai.tripl.arc.util.log.logger.Logger, arcContext: ARCContext) {
override def after(currentValue: Option[DataFrame], stage: PipelineStage, index: Int, stages: List[PipelineStage])(implicit spark: SparkSession, logger: ai.tripl.arc.util.log.logger.Logger, arcContext: ARCContext) {
import spark.implicits._
val stage= stages(index)
val df = Seq((stage.name, "after", this.key, currentValue.get.count)).toDF("stage","when","message","count")
df.createOrReplaceTempView("after")
}
Expand Down

0 comments on commit a3b4fd8

Please sign in to comment.