From a3b4fd80158513a1ce2260bb35376db2d2b9dd89 Mon Sep 17 00:00:00 2001
From: Mike Seddon
Date: Fri, 20 Sep 2019 14:25:33 +1000
Subject: [PATCH] Clean and simplify use of API

---
 src/main/scala/ai/tripl/arc/ARC.scala             | 18 ++++++++----------
 src/main/scala/ai/tripl/arc/api/API.scala         |  4 ++--
 .../plugins/lifecycle/DataFramePrinter.scala      |  8 +++-----
 .../arc/plugins/TestLifecyclePlugin.scala         |  6 ++----
 4 files changed, 15 insertions(+), 21 deletions(-)

diff --git a/src/main/scala/ai/tripl/arc/ARC.scala b/src/main/scala/ai/tripl/arc/ARC.scala
index 42a9f209..e1d94cad 100644
--- a/src/main/scala/ai/tripl/arc/ARC.scala
+++ b/src/main/scala/ai/tripl/arc/ARC.scala
@@ -404,17 +404,17 @@ object ARC {
   def run(pipeline: ETLPipeline)
   (implicit spark: SparkSession, logger: ai.tripl.arc.util.log.logger.Logger, arcContext: ARCContext): Option[DataFrame] = {
 
-    def before(index: Int, stages: List[PipelineStage]): Unit = {
+    def before(currentValue: PipelineStage, index: Int, stages: List[PipelineStage]): Unit = {
       for (p <- arcContext.activeLifecyclePlugins) {
         logger.trace().message(s"Executing before() on LifecyclePlugin: ${p.getClass.getName}")
-        p.before(index, stages)
+        p.before(currentValue, index, stages)
       }
     }
 
-    def after(currentValue: Option[DataFrame], index: Int, stages: List[PipelineStage]): Unit = {
+    def after(result: Option[DataFrame], currentValue: PipelineStage, index: Int, stages: List[PipelineStage]): Unit = {
       for (p <- arcContext.activeLifecyclePlugins) {
         logger.trace().message(s"Executing after on LifecyclePlugin: ${stages(index).getClass.getName}")
-        p.after(currentValue, index, stages)
+        p.after(result, currentValue, index, stages)
       }
     }
 
@@ -425,20 +425,18 @@ object ARC {
       case head :: Nil =>
         val stage = head._1
         val index = head._2
-        val pipelineStages = stages.map(_._1)
-        before(index, pipelineStages)
+        before(stage, index, pipeline.stages)
         val result = processStage(stage)
-        after(result, index, pipelineStages)
+        after(result, stage, index, pipeline.stages)
         result
       //currentValue[, index[, array]]
       case head :: tail =>
         val stage = head._1
         val index = head._2
-        val pipelineStages = stages.map(_._1)
-        before(index, pipelineStages)
+        before(stage, index, pipeline.stages)
         val result = processStage(stage)
-        after(result, index, pipelineStages)
+        after(result, stage, index, pipeline.stages)
         runStages(tail)
     }
   }
diff --git a/src/main/scala/ai/tripl/arc/api/API.scala b/src/main/scala/ai/tripl/arc/api/API.scala
index 8b176a58..2c77d4a3 100644
--- a/src/main/scala/ai/tripl/arc/api/API.scala
+++ b/src/main/scala/ai/tripl/arc/api/API.scala
@@ -258,9 +258,9 @@ object API {
 
     def plugin: LifecyclePlugin
 
-    def before(index: Int, stages: List[PipelineStage])(implicit spark: SparkSession, logger: ai.tripl.arc.util.log.logger.Logger, arcContext: ARCContext)
+    def before(stage: PipelineStage, index: Int, stages: List[PipelineStage])(implicit spark: SparkSession, logger: ai.tripl.arc.util.log.logger.Logger, arcContext: ARCContext)
 
-    def after(currentValue: Option[DataFrame], index: Int, stages: List[PipelineStage])(implicit spark: SparkSession, logger: ai.tripl.arc.util.log.logger.Logger, arcContext: ARCContext)
+    def after(result: Option[DataFrame], stage: PipelineStage, index: Int, stages: List[PipelineStage])(implicit spark: SparkSession, logger: ai.tripl.arc.util.log.logger.Logger, arcContext: ARCContext)
 
   }
diff --git a/src/main/scala/ai/tripl/arc/plugins/lifecycle/DataFramePrinter.scala b/src/main/scala/ai/tripl/arc/plugins/lifecycle/DataFramePrinter.scala
index 1139e3c3..9e03f42c 100644
--- a/src/main/scala/ai/tripl/arc/plugins/lifecycle/DataFramePrinter.scala
+++ b/src/main/scala/ai/tripl/arc/plugins/lifecycle/DataFramePrinter.scala
@@ -42,22 +42,20 @@ case class DataFramePrinterInstance(
   truncate: Boolean
 ) extends LifecyclePluginInstance {
 
-  override def before(index: Int, stages: List[PipelineStage])(implicit spark: SparkSession, logger: ai.tripl.arc.util.log.logger.Logger, arcContext: ARCContext) {
-    val stage = stages(index)
+  override def before(stage: PipelineStage, index: Int, stages: List[PipelineStage])(implicit spark: SparkSession, logger: ai.tripl.arc.util.log.logger.Logger, arcContext: ARCContext) {
     logger.trace()
       .field("event", "before")
       .field("stage", stage.name)
       .log()
   }
 
-  override def after(currentValue: Option[DataFrame], index: Int, stages: List[PipelineStage])(implicit spark: SparkSession, logger: ai.tripl.arc.util.log.logger.Logger, arcContext: ARCContext) {
-    val stage = stages(index)
+  override def after(result: Option[DataFrame], stage: PipelineStage, index: Int, stages: List[PipelineStage])(implicit spark: SparkSession, logger: ai.tripl.arc.util.log.logger.Logger, arcContext: ARCContext) {
     logger.trace()
       .field("event", "after")
       .field("stage", stage.name)
       .log()
 
-    currentValue match {
+    result match {
       case Some(df) => df.show(numRows, truncate)
       case None =>
     }
diff --git a/src/test/scala/ai/tripl/arc/plugins/TestLifecyclePlugin.scala b/src/test/scala/ai/tripl/arc/plugins/TestLifecyclePlugin.scala
index d02f4663..9a82b049 100644
--- a/src/test/scala/ai/tripl/arc/plugins/TestLifecyclePlugin.scala
+++ b/src/test/scala/ai/tripl/arc/plugins/TestLifecyclePlugin.scala
@@ -42,16 +42,14 @@ case class TestLifecyclePluginInstance(
   key: String
 ) extends LifecyclePluginInstance {
 
-  override def before(index: Int, stages: List[PipelineStage])(implicit spark: SparkSession, logger: ai.tripl.arc.util.log.logger.Logger, arcContext: ARCContext) {
+  override def before(stage: PipelineStage, index: Int, stages: List[PipelineStage])(implicit spark: SparkSession, logger: ai.tripl.arc.util.log.logger.Logger, arcContext: ARCContext) {
     import spark.implicits._
-    val stage= stages(index)
     val df = Seq((stage.name, "before", this.key)).toDF("stage","when","message")
     df.createOrReplaceTempView("before")
   }
 
-  override def after(currentValue: Option[DataFrame], index: Int, stages: List[PipelineStage])(implicit spark: SparkSession, logger: ai.tripl.arc.util.log.logger.Logger, arcContext: ARCContext) {
+  override def after(currentValue: Option[DataFrame], stage: PipelineStage, index: Int, stages: List[PipelineStage])(implicit spark: SparkSession, logger: ai.tripl.arc.util.log.logger.Logger, arcContext: ARCContext) {
     import spark.implicits._
-    val stage= stages(index)
     val df = Seq((stage.name, "after", this.key, currentValue.get.count)).toDF("stage","when","message","count")
     df.createOrReplaceTempView("after")
   }
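
Note for plugin authors: with this patch the active stage is passed directly into before() and after(), and after() also receives the Option[DataFrame] produced by the stage that just ran, so instances no longer need to look themselves up via stages(index). Below is a minimal sketch of an instance written against the new trait. StageAuditInstance is a hypothetical name used for illustration only; the method signatures, PipelineStage.name, and the logger.trace().field(...).log() chaining all follow the code changed in this patch.

```scala
import org.apache.spark.sql.{DataFrame, SparkSession}

import ai.tripl.arc.api.API._

// Hypothetical example (not part of this patch) showing the new
// LifecyclePluginInstance signatures introduced above.
case class StageAuditInstance(plugin: LifecyclePlugin) extends LifecyclePluginInstance {

  // The current stage arrives as a parameter; no stages(index) lookup needed.
  override def before(stage: PipelineStage, index: Int, stages: List[PipelineStage])(implicit spark: SparkSession, logger: ai.tripl.arc.util.log.logger.Logger, arcContext: ARCContext) {
    logger.trace()
      .field("event", "before")
      .field("stage", stage.name)
      .field("position", s"${index + 1}/${stages.length}")
      .log()
  }

  // result is the Option[DataFrame] returned by the stage that just executed.
  override def after(result: Option[DataFrame], stage: PipelineStage, index: Int, stages: List[PipelineStage])(implicit spark: SparkSession, logger: ai.tripl.arc.util.log.logger.Logger, arcContext: ARCContext) {
    logger.trace()
      .field("event", "after")
      .field("stage", stage.name)
      .field("hasResult", result.isDefined.toString)
      .log()
  }
}
```

Passing the stage and its result explicitly keeps the hooks aligned with the reduce-style callback shape noted in ARC.scala (currentValue[, index[, array]]) and removes the duplicated stages.map(_._1) bookkeeping from runStages.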