Skip to content

[SPARK-19899][ML] Replace featuresCol with itemsCol in ml.fpm.FPGrowth #17321

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 24 additions & 11 deletions mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ import org.apache.hadoop.fs.Path
import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared.{HasFeaturesCol, HasPredictionCol}
import org.apache.spark.ml.param.shared.HasPredictionCol
import org.apache.spark.ml.util._
import org.apache.spark.mllib.fpm.{AssociationRules => MLlibAssociationRules,
FPGrowth => MLlibFPGrowth}
Expand All @@ -37,7 +37,20 @@ import org.apache.spark.sql.types._
/**
* Common params for FPGrowth and FPGrowthModel
*/
private[fpm] trait FPGrowthParams extends Params with HasFeaturesCol with HasPredictionCol {
private[fpm] trait FPGrowthParams extends Params with HasPredictionCol {

/**
* Items column name.
* Default: "items"
* @group param
*/
@Since("2.2.0")
Copy link
Contributor

@MLnick MLnick Mar 17, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We normally don't annotate in traits since later things may be overridden by concrete implementers. Although I do see a few places where it is done...

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

FPGrowthParams has been annotated before so I decided to make it consistent. I can remove annotation here, but then we should probably follow through and remove the rest.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it's OK to have it here. Only FPGrowth and FPGrowthModel will ever inherit from FPGrowthParams.

val itemsCol: Param[String] = new Param[String](this, "itemsCol", "items column name")
setDefault(itemsCol -> "items")

/** @group getParam */
@Since("2.2.0")
def getItemsCol: String = $(itemsCol)

/**
* Minimal support level of the frequent pattern. [0.0, 1.0]. Any pattern that appears
Expand Down Expand Up @@ -91,10 +104,10 @@ private[fpm] trait FPGrowthParams extends Params with HasFeaturesCol with HasPre
*/
@Since("2.2.0")
protected def validateAndTransformSchema(schema: StructType): StructType = {
val inputType = schema($(featuresCol)).dataType
val inputType = schema($(itemsCol)).dataType
require(inputType.isInstanceOf[ArrayType],
s"The input column must be ArrayType, but got $inputType.")
SchemaUtils.appendColumn(schema, $(predictionCol), schema($(featuresCol)).dataType)
SchemaUtils.appendColumn(schema, $(predictionCol), schema($(itemsCol)).dataType)
}
}

Expand Down Expand Up @@ -133,7 +146,7 @@ class FPGrowth @Since("2.2.0") (

/** @group setParam */
@Since("2.2.0")
def setFeaturesCol(value: String): this.type = set(featuresCol, value)
def setItemsCol(value: String): this.type = set(itemsCol, value)

/** @group setParam */
@Since("2.2.0")
Expand All @@ -146,8 +159,8 @@ class FPGrowth @Since("2.2.0") (
}

private def genericFit[T: ClassTag](dataset: Dataset[_]): FPGrowthModel = {
val data = dataset.select($(featuresCol))
val items = data.where(col($(featuresCol)).isNotNull).rdd.map(r => r.getSeq[T](0).toArray)
val data = dataset.select($(itemsCol))
val items = data.where(col($(itemsCol)).isNotNull).rdd.map(r => r.getSeq[T](0).toArray)
val mllibFP = new MLlibFPGrowth().setMinSupport($(minSupport))
if (isSet(numPartitions)) {
mllibFP.setNumPartitions($(numPartitions))
Expand All @@ -156,7 +169,7 @@ class FPGrowth @Since("2.2.0") (
val rows = parentModel.freqItemsets.map(f => Row(f.items, f.freq))

val schema = StructType(Seq(
StructField("items", dataset.schema($(featuresCol)).dataType, nullable = false),
StructField("items", dataset.schema($(itemsCol)).dataType, nullable = false),
StructField("freq", LongType, nullable = false)))
val frequentItems = dataset.sparkSession.createDataFrame(rows, schema)
copyValues(new FPGrowthModel(uid, frequentItems)).setParent(this)
Expand Down Expand Up @@ -198,7 +211,7 @@ class FPGrowthModel private[ml] (

/** @group setParam */
@Since("2.2.0")
def setFeaturesCol(value: String): this.type = set(featuresCol, value)
def setItemsCol(value: String): this.type = set(itemsCol, value)

/** @group setParam */
@Since("2.2.0")
Expand Down Expand Up @@ -235,7 +248,7 @@ class FPGrowthModel private[ml] (
.collect().asInstanceOf[Array[(Seq[Any], Seq[Any])]]
val brRules = dataset.sparkSession.sparkContext.broadcast(rules)

val dt = dataset.schema($(featuresCol)).dataType
val dt = dataset.schema($(itemsCol)).dataType
// For each rule, examine the input items and summarize the consequents
val predictUDF = udf((items: Seq[_]) => {
if (items != null) {
Expand All @@ -249,7 +262,7 @@ class FPGrowthModel private[ml] (
} else {
Seq.empty
}}, dt)
dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol))))
dataset.withColumn($(predictionCol), predictUDF(col($(itemsCol))))
}

@Since("2.2.0")
Expand Down
14 changes: 7 additions & 7 deletions mllib/src/test/scala/org/apache/spark/ml/fpm/FPGrowthSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ class FPGrowthSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul

test("FPGrowth fit and transform with different data types") {
Array(IntegerType, StringType, ShortType, LongType, ByteType).foreach { dt =>
val data = dataset.withColumn("features", col("features").cast(ArrayType(dt)))
val data = dataset.withColumn("items", col("items").cast(ArrayType(dt)))
val model = new FPGrowth().setMinSupport(0.5).fit(data)
val generatedRules = model.setMinConfidence(0.5).associationRules
val expectedRules = spark.createDataFrame(Seq(
Expand All @@ -52,8 +52,8 @@ class FPGrowthSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
(0, Array("1", "2"), Array.emptyIntArray),
(0, Array("1", "2"), Array.emptyIntArray),
(0, Array("1", "3"), Array(2))
)).toDF("id", "features", "prediction")
.withColumn("features", col("features").cast(ArrayType(dt)))
)).toDF("id", "items", "prediction")
.withColumn("items", col("items").cast(ArrayType(dt)))
.withColumn("prediction", col("prediction").cast(ArrayType(dt)))
assert(expectedTransformed.collect().toSet.equals(
transformed.collect().toSet))
Expand All @@ -79,7 +79,7 @@ class FPGrowthSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
(1, Array("1", "2", "3", "5")),
(2, Array("1", "2", "3", "4")),
(3, null.asInstanceOf[Array[String]])
)).toDF("id", "features")
)).toDF("id", "items")
val model = new FPGrowth().setMinSupport(0.7).fit(dataset)
val prediction = model.transform(df)
assert(prediction.select("prediction").where("id=3").first().getSeq[String](0).isEmpty)
Expand Down Expand Up @@ -108,11 +108,11 @@ class FPGrowthSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
val dataset = spark.createDataFrame(Seq(
Array("1", "3"),
Array("2", "3")
).map(Tuple1(_))).toDF("features")
).map(Tuple1(_))).toDF("items")
val model = new FPGrowth().fit(dataset)

val prediction = model.transform(
spark.createDataFrame(Seq(Tuple1(Array("1", "2")))).toDF("features")
spark.createDataFrame(Seq(Tuple1(Array("1", "2")))).toDF("items")
).first().getAs[Seq[String]]("prediction")

assert(prediction === Seq("3"))
Expand All @@ -127,7 +127,7 @@ object FPGrowthSuite {
(0, Array("1", "2")),
(0, Array("1", "2")),
(0, Array("1", "3"))
)).toDF("id", "features")
)).toDF("id", "items")
}

/**
Expand Down