Skip to content

Commit 32d4887

Browse files
committed
Merge remote-tracking branch 'origin/master'
2 parents c92420e + cbdb973 commit 32d4887

File tree

3 files changed

+448
-2
lines changed

3 files changed

+448
-2
lines changed

README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,11 @@ How to get a stratified sample so the test and train datasets are sampled across
1717
### [Decision Tree with Categorical Feature in the DataSet](https://github.com/aosama/MachineLearningSamples/blob/master/src/main/scala/org/ibrahim/ezmachinelearning/DTShapeTypeWithCategoricalFeaturesExample.scala)
1818
How to index and encode categorical features.
1919

20-
### [Decision Tree Multiple Categorical and Continuous Features in the DataSet](https://github.com/aosama/MachineLearningSamples/blob/master/src/main/scala/org/ibrahim/ezmachinelearning/DTCensusIncomeExample.scala)
20+
### [Predicting Income Based on Census Data Using Decision Tree](https://github.com/aosama/MachineLearningSamples/blob/master/src/main/scala/org/ibrahim/ezmachinelearning/DTCensusIncomeExample.scala)
2121
How to handle multiple categorical and continuous features on a real-life data set.
2222
Uses the Census Income data set.
2323

24-
### [Random Forest Multiple Categorical and Continuous Features in the DataSet](https://github.com/aosama/MachineLearningSamples/blob/master/src/main/scala/org/ibrahim/ezmachinelearning/RFCensusIncomeExample.scala)
24+
### [Predicting Income Based on Census Data Using Random Decision Forest](https://github.com/aosama/MachineLearningSamples/blob/master/src/main/scala/org/ibrahim/ezmachinelearning/RFCensusIncomeExample.scala)
2525
How to handle multiple categorical and continuous features on a real-life data set.
2626
Uses the Census Income data set.
2727

Lines changed: 255 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,255 @@
1+
// Databricks notebook source
import org.apache.spark.sql.{DataFrame, functions}

/**
 * Cleans a raw census-income dataframe read from the UCI CSV files.
 *
 * @param df raw dataframe with positional string columns _c0.._cN
 * @param fields human-readable names for the feature columns, in file order
 * @param continuousFieldIndexes positions in `fields` holding numeric features
 * @return dataframe with named columns, continuous features cast to double,
 *         and a cleaned string column "label"
 */
def formatData(df: DataFrame, fields: Seq[String], continuousFieldIndexes: Seq[Int]): DataFrame = {
  // Values in the raw files carry a leading space (e.g. ", Private"); trim everywhere.
  val trimmed = df.columns.foldLeft(df) { (data, colName) =>
    data.withColumn(colName, functions.ltrim(functions.col(colName)))
  }

  // Rename the positional columns to the field names; the column directly
  // after the last feature (_c14 for this data set) is the income label.
  val named = fields.indices
    .foldLeft(trimmed)((data, i) => data.withColumnRenamed("_c" + i, fields(i)))
    .withColumnRenamed("_c" + fields.length, "label")

  // CSV values are read as strings; cast the continuous features to double.
  val typed = continuousFieldIndexes.foldLeft(named) { (data, i) =>
    data.withColumn(fields(i), functions.col(fields(i)).cast("double"))
  }

  // adult.test labels end with '.' ("<=50K.") while adult.data labels do not;
  // strip it so train and test share one label vocabulary.
  typed.withColumn("label", functions.regexp_replace(functions.col("label"), "\\.", ""))
}
28+
/** Prints the index-to-value mapping of each categorical feature (up to 100 rows per feature). */
def showCategories(df: DataFrame, fields: Seq[String], categoricalFieldIndexes: Seq[Int]): Unit =
  categoricalFieldIndexes.map(fields).foreach { colName =>
    val indexedName = colName + "Indexed"
    df.select(indexedName, colName).distinct().sort(indexedName).show(100)
  }
34+
35+
// COMMAND ----------

// Column names for the 14 census features, in the order they appear in the
// UCI "Adult" data files (the 15th column, _c14, is the income label).
val fields = Seq(
  "age",
  "workclass",
  "fnlwgt",
  "education",
  "education-num",
  "marital-status",
  "occupation",
  "relationship",
  "race",
  "sex",
  "capital-gain",
  "capital-loss",
  "hours-per-week",
  "native-country"
)

// Positions in `fields` of the string-valued (categorical) features...
val categoricalFieldIndexes = Seq(1, 3, 5, 6, 7, 8, 9, 13)
// ...and of the numeric (continuous) features.
val continuousFieldIndexes = Seq(0, 2, 4, 10, 11, 12)
56+
57+
// COMMAND ----------

// Create dataframe to hold census income training data
// Data retrieved from http://archive.ics.uci.edu/ml/datasets/Census+Income
val trainingUrl = "https://raw.githubusercontent.com/aosama/MachineLearningSamples/master/src/main/resources/adult.data"

// Read the whole file, closing the underlying stream even if reading fails
// (Source.fromURL holds an open connection; the original never closed it).
val trainingContent = {
  val src = scala.io.Source.fromURL(trainingUrl)
  try src.mkString finally src.close()
}

// Drop blank lines before parallelizing.
val trainingList = trainingContent.split("\n").filter(_ != "")

val trainingDs = sc.parallelize(trainingList).toDS()
var trainingData = spark.read.csv(trainingDs).cache
69+
// COMMAND ----------

// Create dataframe to hold census income test data
// Data retrieved from http://archive.ics.uci.edu/ml/datasets/Census+Income
val testUrl = "https://raw.githubusercontent.com/aosama/MachineLearningSamples/master/src/main/resources/adult.test"

// Read the whole file, closing the underlying stream even if reading fails
// (Source.fromURL holds an open connection; the original never closed it).
// NOTE(review): the published adult.test begins with a "|1x2 CROSS-VALIDATOR"
// comment line that is not filtered here — verify downstream handling.
val testContent = {
  val src = scala.io.Source.fromURL(testUrl)
  try src.mkString finally src.close()
}

// Drop blank lines before parallelizing.
val testList = testContent.split("\n").filter(_ != "")

val testDs = sc.parallelize(testList).toDS()
var testData = spark.read.csv(testDs).cache
80+
81+
// COMMAND ----------

// Apply the shared cleanup (trim, rename, cast, label fix) to both splits.
trainingData = formatData(trainingData, fields, continuousFieldIndexes)
testData = formatData(testData, fields, continuousFieldIndexes)
86+
87+
// COMMAND ----------

import org.apache.spark.ml.feature.{IndexToString, StringIndexer, VectorAssembler}

// One StringIndexer per categorical feature: maps each distinct string value
// to a numeric index, written to a new "<name>Indexed" column.
val categoricalIndexerArray = categoricalFieldIndexes.map { i =>
  new StringIndexer()
    .setInputCol(fields(i))
    .setOutputCol(fields(i) + "Indexed")
}

// Indexer for the income label, fitted now so its mapping is available
// to the IndexToString converter below.
val labelIndexer = new StringIndexer()
  .setInputCol("label")
  .setOutputCol("indexedLabel")
  .fit(trainingData)

// The feature vector is the indexed categorical columns followed by the
// continuous columns, in that order.
val featureColumns =
  categoricalFieldIndexes.map(i => fields(i) + "Indexed") ++
    continuousFieldIndexes.map(i => fields(i))
val vectorAssembler = new VectorAssembler()
  .setInputCols(featureColumns.toArray)
  .setOutputCol("features")

// Maps numeric predictions back to the original label strings.
val labelConverter = new IndexToString()
  .setInputCol("prediction")
  .setOutputCol("predictedLabel")
  .setLabels(labelIndexer.labels)
114+
115+
// COMMAND ----------

import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.classification.DecisionTreeClassifier

// Create decision tree
val dt = new DecisionTreeClassifier()
  .setLabelCol("indexedLabel")
  .setFeaturesCol("features")
  .setMaxBins(50) // Since feature "native-country" contains 42 distinct values, need to increase max bins.
  .setMaxDepth(6)

// Array of stages to run in pipeline: label indexer, categorical indexers,
// assembler, tree, label converter. NOTE: the tree is the second-to-last
// stage; later code extracts it via model.stages(stageArray.length - 2).
val indexerArray = Array(labelIndexer) ++ categoricalIndexerArray
val stageArray = indexerArray ++ Array(vectorAssembler, dt, labelConverter)

val pipeline = new Pipeline()
  .setStages(stageArray)

// Train the model
val model = pipeline.fit(trainingData)

// Test the model (adds prediction/probability columns to the test rows)
val predictions = model.transform(testData)
139+
140+
// COMMAND ----------

// Show every test row with its actual and predicted label alongside the raw features.
display(predictions.select("label", Seq("predictedLabel" ,"indexedLabel", "prediction") ++ fields:_*))
143+
144+
// COMMAND ----------

// Misclassified test rows: indexed label and numeric prediction disagree.
val reviewColumns = Seq("predictedLabel", "indexedLabel", "prediction") ++ fields
val wrongPredictions = predictions
  .select("label", reviewColumns: _*)
  .where("indexedLabel != prediction")
display(wrongPredictions)
150+
151+
// COMMAND ----------

// Show the label and all the categorical features mapped to indexes.
// Run only the indexer stages so the raw-value-to-index mapping is visible.
val indexerPipeline = new Pipeline().setStages(indexerArray)
val indexedData = indexerPipeline.fit(trainingData).transform(trainingData)

indexedData.select("indexedLabel", "label").distinct().sort("indexedLabel").show()
showCategories(indexedData, fields, categoricalFieldIndexes)
160+
161+
// COMMAND ----------

import org.apache.spark.ml.classification.DecisionTreeClassificationModel
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
import org.apache.spark.mllib.evaluation.MulticlassMetrics

// Accuracy over the test predictions (fraction of rows where
// prediction matches indexedLabel).
val evaluator = new MulticlassClassificationEvaluator()
  .setLabelCol("indexedLabel")
  .setPredictionCol("prediction")
  .setMetricName("accuracy")

val accuracy = evaluator.evaluate(predictions)
println(s"Test error = ${1.0 - accuracy}\n")

// Confusion matrix via the RDD-based mllib API, which needs
// (label, prediction) pairs as doubles.
val metrics = new MulticlassMetrics(
  predictions.select("indexedLabel", "prediction")
    .rdd.map(x => (x.getDouble(0), x.getDouble(1)))
)

println(s"Confusion matrix:\n ${metrics.confusionMatrix}\n")

// The tree is the second-to-last pipeline stage (just before the label converter).
val treeModel = model.stages(stageArray.length - 2).asInstanceOf[DecisionTreeClassificationModel]

// Print out the tree with actual column names for features.
var treeModelString = treeModel.toDebugString

// The assembler put categorical features first, then continuous, so vector
// slot i corresponds to fields(featureFieldIndexes(i)). The trailing space
// in "feature N " prevents "feature 1" from matching inside "feature 11".
val featureFieldIndexes = categoricalFieldIndexes ++ continuousFieldIndexes
for (i <- featureFieldIndexes.indices)
  treeModelString = treeModelString
    .replace("feature " + i + " ", fields(featureFieldIndexes(i)) + " ")

println(s"Learned classification tree model:\n $treeModelString")
193+
194+
// COMMAND ----------

// Legend: which source field each feature-vector slot corresponds to.
featureFieldIndexes.zipWithIndex.foreach { case (fieldIndex, slot) =>
  println(s"feature $slot -> ${fields(fieldIndex)}")
}
198+
199+
// COMMAND ----------

// Visualize the fitted decision tree (Databricks renders tree models).
display(treeModel)

// COMMAND ----------

// Spot-check: all test rows with age 25.
display(testData.filter('age === 25))

// COMMAND ----------

// Confirm column names and types after formatting.
testData.printSchema
211+
// COMMAND ----------

import org.apache.spark.ml.linalg.Vector

// UDF: extract element i of an ML vector column as a plain double.
val vectorElem = udf { (v: Vector, i: Int) => v(i) }

// Flatten the rawPrediction and probability vectors into scalar columns
// so they can be sorted and inspected in the results table.
val predictionsExpanded = Seq(
  ("rawPrediction0", "rawPrediction", 0),
  ("rawPrediction1", "rawPrediction", 1),
  ("score0", "probability", 0),
  ("score1", "probability", 1)
).foldLeft(predictions) { case (df, (target, source, index)) =>
  df.withColumn(target, vectorElem(functions.col(source), functions.lit(index)))
}
219+
220+
// COMMAND ----------

// Predictions with per-class scores, youngest rows first.
display(predictionsExpanded.orderBy($"age".asc))
223+
224+
// COMMAND ----------

// A single hand-built census record for an ad-hoc prediction.
// Column names reuse `fields`, which lists them in the same order.
val record = Seq(
  (50, "Private", 220931, "Bachelors", 13, "Married-civ-spouse", "Prof-specialty",
    "Not-in-family", "White", "Male", 10, 0, 43, "United-States")
).toDF(fields: _*)
240+
241+
// COMMAND ----------

// Score the single record and flatten its vector outputs, mirroring
// the expansion applied to the full predictions dataframe.
val singlePrediction = Seq(
  ("rawPrediction0", "rawPrediction", 0),
  ("rawPrediction1", "rawPrediction", 1),
  ("score0", "probability", 0),
  ("score1", "probability", 1)
).foldLeft(model.transform(record)) { case (df, (target, source, index)) =>
  df.withColumn(target, vectorElem(functions.col(source), functions.lit(index)))
}

// COMMAND ----------

display(singlePrediction)
252+
253+
// COMMAND ----------

// Age distribution of the training data (row count per age, ascending).
display(trainingData.groupBy('age).count.orderBy('age.asc))

0 commit comments

Comments
 (0)