This repository was archived by the owner on May 12, 2021. It is now read-only.

Commit 4addcd5

Initial commit for BIDMach, VW & SPPMI [rolled up to reduce repo size]; excluding libs
1 parent 36fed5b commit 4addcd5

15 files changed: 16,664 additions and 112 deletions

build.sbt

Lines changed: 14 additions & 3 deletions

@@ -3,9 +3,20 @@ name := "org.template.textclassification"
 
 organization := "io.prediction"
 
+scalaVersion := "2.10.5"
+
 libraryDependencies ++= Seq(
-  "io.prediction" %% "core" % pioVersion.value % "provided",
-  "org.apache.spark" %% "spark-core" % "1.3.1" % "provided",
-  "org.apache.spark" %% "spark-mllib" % "1.3.1" % "provided",
+  "io.prediction" % "core_2.10" % pioVersion.value % "provided",
+  "org.apache.spark" %% "spark-core" % "1.4.1" % "provided",
+  "org.apache.spark" %% "spark-mllib" % "1.4.1" % "provided",
+  "com.github.fommil.netlib" % "all" % "1.1.2" pomOnly(),
+  "com.github.johnlangford" % "vw-jni" % "8.0.0",
   "org.xerial.snappy" % "snappy-java" % "1.1.1.7"
 )
+
+mergeStrategy in assembly <<= (mergeStrategy in assembly) { (old) =>
+  {
+    case y if y.startsWith("doc") => MergeStrategy.discard
+    case x => old(x)
+  }
+}
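The new mergeStrategy override presupposes the sbt-assembly plugin, which PredictionIO templates conventionally load from project/assembly.sbt; that file is not part of this diff. A minimal sketch of it, assuming a 0.11.x plugin release (the version matching the old-style <<= syntax above is an assumption):

    // project/assembly.sbt (sketch; the exact plugin version is an assumption)
    addSbtPlugin("com.eed3si9n" % "sbt-assembly" % "0.11.2")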

data/Twitter140sample.txt

Lines changed: 16000 additions & 0 deletions
Large diffs are not rendered by default.

data/import_eventserver.py

Lines changed: 46 additions & 0 deletions

@@ -0,0 +1,46 @@
+"""
+Import sample data for classification engine
+"""
+
+import predictionio
+import argparse
+
+def import_events(client, file):
+    f = open(file, 'r')
+    count = 0
+    print "Importing data..."
+    for line in f:
+        data = line.rstrip('\r\n').split(",")
+        plan = data[0]
+        # Not strictly CSV: after the first comma, no longer delimiting
+        text = ",".join(data[1:])
+        client.create_event(
+            event="$set",
+            entity_type="user",
+            entity_id=str(count),  # use the count num as user ID
+            properties={
+                "text": text,
+                "category": plan,
+                "label": int(plan)
+            }
+        )
+        count += 1
+    f.close()
+    print "%s events are imported." % count
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser(
+        description="Import sample data for classification engine")
+    parser.add_argument('--access_key', default='invald_access_key')
+    parser.add_argument('--url', default="http://localhost:7070")
+    parser.add_argument('--file', default="./data/Twitter140sample.txt")
+
+    args = parser.parse_args()
+    print args
+
+    client = predictionio.EventClient(
+        access_key=args.access_key,
+        url=args.url,
+        threads=5,
+        qsize=500)
+    import_events(client, args.file)
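With an Event Server running on the default port, the importer is invoked along the lines of python data/import_eventserver.py --access_key <your_access_key>. The --access_key default above is a deliberate placeholder, so the key of the registered PredictionIO app must be passed explicitly; --url and --file only need overriding for a non-default Event Server or a different data file.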

engine.json

Lines changed: 12 additions & 5 deletions

@@ -4,20 +4,27 @@
   "engineFactory": "org.template.textclassification.TextClassificationEngine",
   "datasource": {
     "params": {
-      "appName": "MyTextApp"
+      "appName": "smallerData"
     }
   },
   "preparator": {
     "params": {
-      "nGram": 2,
-      "numFeatures": 15000
+      "nGram": 1,
+      "numFeatures": 500,
+      "SPPMI": false
     }
   },
   "algorithms": [
     {
-      "name": "nb",
+      "name": "bid-lr",
       "params": {
-        "lambda": 0.25
+        "maxIter": 1,
+        "regParam": 0.00000005,
+        "stepSize": 5.0,
+        "bitPrecision": 22,
+        "modelName": "model.vw",
+        "namespace": "n",
+        "ngram": 1
       }
     }
   ]
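PredictionIO deserializes each algorithm's params block into that algorithm's Params case class, and "SPPMI": false presumably toggles the shifted positive PMI weighting named in the commit message. Note that the "bid-lr" block carries several VW-flavoured keys even though BIDMachLRAlgorithmParams (defined below) declares only regParam, so the remaining keys appear to be ignored by the BIDMach path. A hypothetical case class that would bind every key in the block; the field names come from the JSON above, while the types and the VW glosses in the comments are assumptions:

    import io.prediction.controller.Params

    // Sketch only: a Params case class mirroring each key under "bid-lr" above.
    // The actual BIDMachLRAlgorithmParams in this commit declares just regParam.
    case class VWStyleParams(
      maxIter: Int,          // SGD passes
      regParam: Double,      // regularization weight
      stepSize: Double,      // SGD learning rate
      bitPrecision: Int,     // feature-hash bits (2^22 buckets here)
      modelName: String,     // on-disk VW model file
      namespace: String,     // VW feature namespace
      ngram: Int             // n-gram expansion
    ) extends Params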

getnativepath.java

Lines changed: 7 additions & 0 deletions

@@ -0,0 +1,7 @@
+public class getnativepath {
+    public static void main(String [] args)
+    {
+        String v = System.getProperty("java.library.path");
+        System.out.print(v);
+    }
+}
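This throwaway utility presumably exists to debug native-library loading for the new JNI dependencies (vw-jni, BIDMach's CUDA/MKL backends): it prints java.library.path, the list of directories the JVM searches for native libraries, so one can check where the JVM will look for them.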
BIDMachLRAlgorithm.scala (the file path was not captured in this render; inferred from the package and class names below)

Lines changed: 178 additions & 0 deletions
@@ -0,0 +1,178 @@
+package org.template.textclassification
+
+import java.io.{InputStreamReader, BufferedReader, ByteArrayInputStream, Serializable}
+
+import BIDMat.{CMat,CSMat,DMat,Dict,FMat,FND,GMat,GDMat,GIMat,GLMat,GSMat,GSDMat,HMat,IDict,Image,IMat,LMat,Mat,SMat,SBMat,SDMat}
+import BIDMat.MatFunctions._
+import BIDMat.SciFunctions._
+import BIDMat.Solvers._
+import BIDMat.Plotting._
+import BIDMach.Learner
+import BIDMach.models.{FM,GLM,KMeans,KMeansw,LDA,LDAgibbs,Model,NMF,SFA,RandomForest}
+import BIDMach.networks.{DNN}
+import BIDMach.datasources.{DataSource,MatDS,FilesDS,SFilesDS}
+import BIDMach.mixins.{CosineSim,Perplexity,Top,L1Regularizer,L2Regularizer}
+import BIDMach.updaters.{ADAGrad,Batch,BatchNorm,IncMult,IncNorm,Telescoping}
+import BIDMach.causal.{IPTW}
+
+import io.prediction.controller.{P2LAlgorithm, Params}
+import org.apache.spark.SparkContext
+import org.apache.spark.ml.classification.LogisticRegression
+import org.apache.spark.mllib.linalg.{DenseVector, SparseVector}
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.DataFrame
+
+case class BIDMachLRAlgorithmParams (
+  regParam : Double
+) extends Params
+
+
+class BIDMachLRAlgorithm(
+  val sap: BIDMachLRAlgorithmParams
+) extends P2LAlgorithm[PreparedData, NativeLRModel, Query, PredictedResult] {
+  // Train your model.
+  def train(sc: SparkContext, pd: PreparedData): NativeLRModel = {
+    new BIDMachLRModel(sc, pd, sap.regParam)
+  }
+
+  // Prediction method for trained model.
+  def predict(model: NativeLRModel, query: Query): PredictedResult = {
+    model.predict(query.text)
+  }
+
+}
+
+class BIDMachLRModel (
+  sc : SparkContext,
+  pd : PreparedData,
+  regParam : Double
+) extends Serializable with NativeLRModel {
+
+  private val labels: Seq[Double] = pd.categoryMap.keys.toSeq
+
+  val data = prepareDataFrame(sc, pd, labels)
+
+  private val lrModels = fitLRModels
+
+  def fitLRModels: Seq[(Double, LREstimate)] = {
+
+    Mat.checkMKL
+    Mat.checkCUDA
+    if (Mat.hasCUDA > 0) GPUmem
+
+    // 3. Create a logistic regression model for each class.
+    val lrModels: Seq[(Double, LREstimate)] = labels.map(
+      label => {
+        val lab = label.toInt.toString
+
+        val (categories, features) = getFMatsFromData(lab, data)
+
+        val mm: Learner = trainGLM(features, FMat(categories))
+
+        test(categories, features, mm)
+        val modelmat = FMat(mm.modelmat)
+        val weightSize = size(modelmat)._2 - 1
+
+        val weights = modelmat(1, 0 to weightSize)
+
+        val weightArray = (for (i <- 0 to weightSize - 1) yield weights(0, i).toDouble).toArray
+
+        // Return (label, feature coefficients, intercept term).
+        (label, LREstimate(weightArray, weights(0, weightSize)))
+      }
+    )
+    lrModels
+  }
+
+  def predict(text : String): PredictedResult = {
+    predict(text, pd, lrModels)
+  }
+
+  def trainGLM(traindata: SMat, traincats: FMat): Learner = {
+    //min(traindata, 1, traindata) // the first "traindata" argument is the input, the other is output
+
+    val (mm, mopts) = GLM.learner(traindata, traincats, GLM.logistic)
+    mopts.what
+
+    mopts.lrate = 0.1
+    mopts.reg1weight = regParam
+    mopts.batchSize = 1000
+    mopts.npasses = 250
+    mopts.autoReset = false
+    mopts.addConstFeat = true
+    mm.train
+    mm
+  }
+
+  def getFMatsFromData(lab: String, data: DataFrame): (FMat, SMat) = {
+    val features = data.select(lab, "features")
+
+    val sparseVectorsWithRowIndices = (for (r <- features) yield (r.getAs[SparseVector](1), r.getAs[Double](0))).zipWithIndex
+
+    val triples = for {
+      ((vector, innerLabel), rowIndex) <- sparseVectorsWithRowIndices
+      (index, value) <- vector.indices zip vector.values
+    } yield ((rowIndex.toInt, index, value), innerLabel)
+
+    val catTriples = for {
+      ((vector, innerLabel), rowIndex) <- sparseVectorsWithRowIndices
+    } yield (rowIndex.toInt, innerLabel.toInt, 1.0)
+
+    val cats = catTriples
+    val feats = triples.map(x => x._1)
+
+    val numRows = cats.count().toInt
+
+    val catsMat = loadFMatTxt(cats, numRows)
+
+    val featsMat = loadFMatTxt(feats, numRows)
+
+    println(featsMat)
+
+    (full(catsMat), featsMat)
+  }
+
+  // See https://github.com/BIDData/BIDMat/blob/master/src/main/scala/BIDMat/HMat.scala , method loadDMatTxt
+  def loadFMatTxt(cats: RDD[(Int, Int, Double)], nrows: Int): SMat = {
+
+    val rows = cats.map(x => x._1).collect()
+    val cols = cats.map(x => x._2).collect()
+    val vals = cats.map(x => x._3).collect()
+
+
+    println("LOADING")
+
+    sparse(icol(cols.toList), icol(rows.toList), col(vals.toList))
+  }
+
+  def test(categories: DMat, features: SMat, mm: Learner): Unit = {
+    val testdata = features
+    val testcats = categories
+
+    //min(testdata, 1, testdata)
+
+    val predcats = zeros(testcats.nrows, testcats.ncols)
+
+
+
+    val (nn, nopts) = GLM.predictor(mm.model, testdata, predcats)
+
+
+
+    nopts.addConstFeat = true
+    nn.predict
+
+
+    computeAccuracy(FMat(testcats), predcats)
+  }
+
+  def computeAccuracy(testcats: FMat, predcats: FMat): Unit = {
+    //println(testcats)
+    //println(predcats)
+
+    val lacc = (predcats ∙→ testcats + (1 - predcats) ∙→ (1 - testcats)) / predcats.ncols
+    lacc.t
+    println(mean(lacc))
+  }
+
+}
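fitLRModels is a one-vs-rest scheme: one binary GLM.logistic learner is trained per class label, addConstFeat = true appends a constant feature whose weight serves as the intercept, and each trained model is reduced to an LREstimate of coefficients plus intercept. A minimal scoring sketch under those assumptions; the LREstimate field names and the sigmoid/argmax decision rule are assumptions, and only the (label, LREstimate) pairing comes from the code above:

    // Sketch: one-vs-rest scoring over the per-label estimates from fitLRModels.
    object LRScoringSketch {
      // Assumed shape of LREstimate, matching LREstimate(weightArray, intercept) above.
      case class LREstimate(coefficients: Array[Double], intercept: Double)

      // Sigmoid of the affine margin w.x + b for one binary model.
      def score(x: Array[Double], est: LREstimate): Double = {
        val margin = est.coefficients.zip(x).map { case (w, xi) => w * xi }.sum + est.intercept
        1.0 / (1.0 + math.exp(-margin))
      }

      // Predict the label whose binary model responds most strongly.
      def predict(x: Array[Double], models: Seq[(Double, LREstimate)]): Double =
        models.maxBy { case (_, est) => score(x, est) }._1
    }

computeAccuracy takes the same binary view: predcats ∙→ testcats row-dots the agreeing positives and (1 - predcats) ∙→ (1 - testcats) the agreeing negatives, so dividing by predcats.ncols gives a per-row accuracy whose mean is printed.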

src/main/scala/org/template/textclassification/DataSource.scala

Lines changed: 17 additions & 17 deletions

Most hunks in this file are whitespace-only: this render dropped leading whitespace, so several -/+ pairs below show identical text and differ only in indentation. The substantive changes are the entityType/eventNames filters and the label test in the event-reading code.

@@ -17,9 +17,9 @@ import org.apache.spark.rdd.RDD
 // cross validation.
 
 case class DataSourceParams(
-  appName: String,
-  evalK: Option[Int]
-) extends Params
+  appName: String,
+  evalK: Option[Int]
+) extends Params
 
 
 
@@ -28,8 +28,8 @@ case class DataSourceParams(
 // readEval method.
 
 class DataSource (
-  val dsp : DataSourceParams
-) extends PDataSource[TrainingData, EmptyEvaluationInfo, Query, ActualResult] {
+  val dsp : DataSourceParams
+) extends PDataSource[TrainingData, EmptyEvaluationInfo, Query, ActualResult] {
 
   @transient lazy val logger = Logger[this.type]
 
@@ -39,15 +39,15 @@ class DataSource (
     //Get RDD of Events.
     PEventStore.find(
       appName = dsp.appName,
-      entityType = Some("content"), // specify data entity type
-      eventNames = Some(List("e-mail")) // specify data event name
+      entityType = Some("user"), // specify data entity type
+      eventNames = Some(List("$set")) // specify data event name
 
       // Convert collected RDD of events to an RDD of Observation
       // objects.
     )(sc).map(e => {
       val label : String = e.properties.get[String]("label")
       Observation(
-        if (label == "spam") 1.0 else 0.0,
+        if (label == "1") 1.0 else 0.0,
         e.properties.get[String]("text"),
         label
       )
@@ -62,7 +62,7 @@ class DataSource (
       entityType = Some("resource"),
      eventNames = Some(List("stopwords"))
 
-      //Convert collected RDD of strings to a string set.
+      //Convert collected RDD of strings to a string set.
    )(sc)
      .map(e => e.properties.get[String]("word"))
      .collect
@@ -92,7 +92,7 @@ class DataSource (
      val train = new TrainingData(
        data.filter(_._2 % dsp.evalK.get != k).map(_._1),
        readStopWords
-      ((sc)))
+      ((sc)))
 
      // Prepare test data for fold.
      val test = data.filter(_._2 % dsp.evalK.get == k)
@@ -108,17 +108,17 @@ class DataSource (
 // 3. Observation class serving as a wrapper for both our
 // data's class label and document string.
 case class Observation(
-  label : Double,
-  text : String,
-  category : String
-) extends Serializable
+  label : Double,
+  text : String,
+  category : String
+) extends Serializable
 
 // 4. TrainingData class serving as a wrapper for all data
 // read in from the Event Server.
 class TrainingData(
-  val data : RDD[Observation],
-  val stopWords : Set[String]
-) extends Serializable with SanityCheck {
+  val data : RDD[Observation],
+  val stopWords : Set[String]
+) extends Serializable with SanityCheck {
 
 // Sanity check to make sure your data is being fed in correctly.
src/main/scala/org/template/textclassification/Engine.scala

Lines changed: 3 additions & 1 deletion

@@ -40,8 +40,10 @@ object TextClassificationEngine extends EngineFactory {
       classOf[DataSource],
       classOf[Preparator],
       Map(
+        "VWlogisticSGD" -> classOf[VowpalLogisticRegressionWithSGDAlgorithm],
         "nb" -> classOf[NBAlgorithm],
-        "lr" -> classOf[LRAlgorithm]
+        "lr" -> classOf[LRAlgorithm],
+        "bid-lr" -> classOf[BIDMachLRAlgorithm]
       ), classOf[Serving]
     )
   }
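The keys of this Map are what the "name" field in engine.json resolves against, so the "bid-lr" entry is what routes the configuration shown earlier to BIDMachLRAlgorithm. The VowpalLogisticRegressionWithSGDAlgorithm registered under "VWlogisticSGD" presumably lives in one of the changed files not rendered in this view.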
