Skip to content

Commit bb40ef0

Browse files
committed
Add average perceptron implementation
1 parent f0e2b7e commit bb40ef0

File tree

9 files changed

+613
-16
lines changed

9 files changed

+613
-16
lines changed
Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,124 @@
1+
defmodule Classifiers.Perceptron.Average do
  @moduledoc """
  Averaged perceptron classifier whose state is held in an `Agent`.

  The state is a `%Classifiers.Perceptron.Average{}` struct:

    * `edges`   - `%{label => %{feature => integer}}`, the raw perceptron
      scores adjusted after every misprediction during training.
    * `weights` - `%{label => %{feature => number}}`, the running average of
      `edges` over the rows seen so far; used by `predict/2` and
      `predict_one/2`.
    * `count`   - number of training rows folded into `weights`.
    * `epoch`   - initialised to `0`; not read or written by this module.

  A "feature" is a `{value, position}` tuple, as produced by
  `Enum.with_index/1` on a row.
  """

  defstruct weights: %{},
            edges: %{},
            count: 0,
            epoch: 0

  @doc """
  Get a new classifier pid.

  Starts a linked `Agent` holding an empty classifier struct and returns its
  pid. Raises a `MatchError` if the agent cannot be started.
  """
  def new do
    {:ok, pid} = Agent.start_link(fn -> %__MODULE__{} end)
    pid
  end

  @doc """
  Fit a stream of data to an existing classifier.

  Expects an enumerable of rows, each row being a list of the form:

      [feature_1, feature_2, ..., feature_n, class]

  The last element of a row is its class label; every preceding element is a
  feature value. Rows are folded into the agent's state in batches of 10 (one
  agent update per batch).

  Returns `:ok`.
  """
  def fit(stream, pid) do
    # NOTE(review): Stream.chunk/2 discards a trailing batch of fewer than 10
    # rows, so those rows are never trained on. Kept as-is because the
    # existing expectations were produced with this batching - confirm before
    # switching to a variant that keeps the remainder.
    stream
    |> Stream.chunk(10)
    |> Enum.each(fn chunk ->
      Agent.update(pid, fn classifier ->
        Enum.reduce(chunk, classifier, &train_row/2)
      end)
    end)
  end

  @doc """
  Predict the class for one set of features.

  `features` is a list of feature values in the same column order as the
  training rows, without the trailing class label. Returns the predicted
  label, or `nil` when the classifier has no averaged weights yet.
  """
  def predict_one(features, pid) do
    pid
    |> classifier()
    |> make_prediction(Enum.with_index(features), false)
  end

  @doc """
  Predict the classes for a stream of features

  The classifier state is read from the agent once, when this function is
  called; the returned stream is lazy and yields one predicted label per
  input row (`nil` when the classifier has no averaged weights yet).
  """
  def predict(stream, pid) do
    snapshot = classifier(pid)

    Stream.map(stream, fn row ->
      make_prediction(snapshot, Enum.with_index(row), false)
    end)
  end

  # Folds one training row into the classifier: first adjust the raw edges
  # according to the (mis)prediction, then refresh the averaged weights.
  defp train_row(row, classifier) do
    label = List.last(row)
    features = row |> Enum.drop(-1) |> Enum.with_index()

    classifier
    |> update_edges(label, features)
    |> average_weights()
  end

  # Perceptron update on the raw edges for one labelled row.
  defp update_edges(classifier, label, features) do
    case make_prediction(classifier, features, true) do
      nil ->
        # No edges at all yet: seed the table with this row's features.
        seeded = Enum.into(features, %{}, fn feature -> {feature, 1} end)
        %{classifier | edges: Map.put(classifier.edges, label, seeded)}

      ^label ->
        classifier

      prediction ->
        # Misprediction: push the true label up and the predicted label down.
        edges =
          classifier.edges
          |> Map.update(label, %{}, &shift_edges(&1, features, 1))
          |> Map.update(prediction, %{}, &shift_edges(&1, features, -1))

        %{classifier | edges: edges}
    end
  end

  # Adds `step` to the edge of every given feature; a feature seen for the
  # first time is initialised to 0 (Map.update/4 default) rather than `step`.
  defp shift_edges(label_edges, features, step) do
    Enum.reduce(features, label_edges, fn feature, acc ->
      Map.update(acc, feature, 0, &(&1 + step))
    end)
  end

  # Recomputes the running average of every edge and bumps the row count.
  defp average_weights(%{count: count, edges: edges, weights: weights} = classifier) do
    averaged =
      Enum.reduce(edges, weights, fn {label, label_edges}, acc ->
        current = Map.get(acc, label, %{})

        target =
          Enum.reduce(label_edges, current, fn {feature, edge}, row ->
            # A feature without a weight yet starts at 0; existing weights
            # are averaged with the current edge value.
            Map.update(row, feature, 0, fn weight ->
              (count * weight + edge) / (count + 1)
            end)
          end)

        # A label without weights yet starts as an empty map; otherwise the
        # freshly averaged row is merged over the stored one.
        Map.update(acc, label, %{}, fn stored -> Map.merge(stored, target) end)
      end)

    %{classifier | count: count + 1, weights: averaged}
  end

  # The third argument selects the table to score against: `true` uses the
  # raw `edges` (training), `false` uses the averaged `weights` (prediction).
  defp make_prediction(%{edges: edges}, features, true), do: best_label(edges, features)
  defp make_prediction(%{weights: weights}, features, false), do: best_label(weights, features)

  # Returns the label whose row gives the highest summed score for
  # `features`, or nil when the table is empty. Ties go to the first label
  # in enumeration order (Enum.max_by/2).
  defp best_label(table, _features) when map_size(table) == 0, do: nil

  defp best_label(table, features) do
    {label, _row} =
      Enum.max_by(table, fn {_label, row} ->
        Enum.reduce(features, 0, fn feature, total ->
          total + Map.get(row, feature, 0)
        end)
      end)

    label
  end

  # Reads the current classifier struct out of the agent.
  defp classifier(pid) do
    Agent.get(pid, fn state -> state end)
  end
end

mix.exs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ defmodule Classifiers.Mixfile do
2626

2727
defp deps do
2828
[
29-
{:csv, "~> 0.2.0", only: :test},
29+
{:csv, "~> 1.0.0", only: :test},
3030
{:ex_doc, "~> 0.7.1", only: :docs},
3131
{:inch_ex, only: :docs},
3232
{:earmark, only: :docs}

mix.lock

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
%{"csv": {:hex, :csv, "0.2.2"},
1+
%{"csv": {:hex, :csv, "1.0.0"},
22
"earmark": {:hex, :earmark, "0.1.17"},
33
"ex_doc": {:hex, :ex_doc, "0.7.3"},
44
"inch_ex": {:hex, :inch_ex, "0.3.1"},

test/classifiers/naive_bayes/bernoulli_test.exs

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -31,29 +31,29 @@ defmodule ClassifiersTest.NaiveBayes.Bernoulli do
3131
assert conditional_probabilities["positive"] == [1.0, 1.0, 1.0/3.0]
3232
end
3333

34-
test "pnegativeict one works correctly", context do
34+
test "predict one works correctly", context do
3535
classifier = context[:classifier]
3636

37-
pnegativeiction = Classifiers.NaiveBayes.Bernoulli.predict_one([0,1,0], classifier)
38-
assert pnegativeiction == "positive"
37+
prediction = Classifiers.NaiveBayes.Bernoulli.predict_one([0,1,0], classifier)
38+
assert prediction == "positive"
3939

40-
pnegativeiction = Classifiers.NaiveBayes.Bernoulli.predict_one([1,1,0], classifier)
41-
assert pnegativeiction == "positive"
40+
prediction = Classifiers.NaiveBayes.Bernoulli.predict_one([1,1,0], classifier)
41+
assert prediction == "positive"
4242

43-
pnegativeiction = Classifiers.NaiveBayes.Bernoulli.predict_one([1,0,1], classifier)
44-
assert pnegativeiction == "negative"
43+
prediction = Classifiers.NaiveBayes.Bernoulli.predict_one([1,0,1], classifier)
44+
assert prediction == "negative"
4545

46-
pnegativeiction = Classifiers.NaiveBayes.Bernoulli.predict_one([0,0,1], classifier)
47-
assert pnegativeiction == "negative"
46+
prediction = Classifiers.NaiveBayes.Bernoulli.predict_one([0,0,1], classifier)
47+
assert prediction == "negative"
4848
end
4949

50-
test "pnegativeict works correctly", context do
51-
pnegativeictions = "naive_bayesian_test.csv"
50+
test "predict works correctly", context do
51+
predictions = "naive_bayesian_test.csv"
5252
|> Fixture.csv
5353
|> Classifiers.NaiveBayes.Bernoulli.predict(context[:classifier])
5454
|> Enum.to_list
5555

56-
assert pnegativeictions == ["positive", "positive", "negative", "negative"]
56+
assert predictions == ["positive", "positive", "negative", "negative"]
5757
end
5858

5959
end
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
defmodule ClassifiersTest.Perceptron.Average do
  use ExUnit.Case

  alias Classifiers.Perceptron.Average, as: Perceptron

  # Every test gets a fresh classifier fitted on the training fixture.
  setup do
    pid = Perceptron.new()
    training = Fixture.csv("average_perceptron_train.csv", num_pipes: 1)
    Perceptron.fit(training, pid)

    {:ok, classifier: pid}
  end

  # Reads the classifier struct held by the agent so tests can inspect it.
  defp get_classifier(pid) do
    Agent.get(pid, fn state -> state end)
  end

  test "fitting generates averaged weights for the given features", context do
    %{weights: %{"democrat" => w}} = get_classifier(context[:classifier])
    total = w |> Map.values() |> Enum.sum()

    assert map_size(w) == 48
    assert Float.round(total, 3) == 0.367
  end

  test "fitting generates edges for the given features", context do
    %{edges: %{"republican" => e}} = get_classifier(context[:classifier])
    total = e |> Map.values() |> Enum.sum()

    assert map_size(e) == 48
    assert total == 8
  end

  test "predict works correctly", context do
    expected =
      ~w(democrat democrat democrat republican democrat democrat republican democrat republican republican)

    predictions =
      "average_perceptron_test.csv"
      |> Fixture.csv(num_pipes: 1)
      |> Perceptron.predict(context[:classifier])
      |> Enum.to_list()

    assert predictions == expected
  end
end

test/support/fixture.exs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,11 @@ defmodule Fixture do
44
Path.join("./fixtures", filename) |> Path.expand(__DIR__)
55
end
66

7-
def csv(filename) do
7+
def csv(filename, options \\ []) do
88
filename
99
|> Fixture.path
1010
|> File.stream!
11-
|> CSV.decode
11+
|> CSV.decode(options)
1212
end
1313

1414
end
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
n,y,y,n,n,?,y,y,y,y,y,n,?,y,y,y
2+
n,n,y,n,n,n,y,y,n,y,y,n,n,n,y,?
3+
y,n,y,n,n,n,y,y,y,y,n,n,n,n,y,y
4+
n,n,n,y,y,y,y,y,n,y,n,y,y,y,n,y
5+
?,?,?,n,n,n,y,y,y,y,n,n,y,n,y,y
6+
y,n,y,n,?,n,y,y,y,y,n,y,n,?,y,y
7+
n,n,y,y,y,y,n,n,y,y,n,y,y,y,n,y
8+
n,n,y,n,n,n,y,y,y,y,n,n,n,n,n,y
9+
n,?,n,y,y,y,n,n,n,n,y,y,y,y,n,y
10+
n,n,n,y,y,y,?,?,?,?,n,y,y,y,n,y

0 commit comments

Comments
 (0)