Skip to content

Commit

Permalink
Merge branch 'dvisockas-master'
Browse files Browse the repository at this point in the history
  • Loading branch information
igrigorik committed Nov 22, 2015
2 parents eaa9a2f + 2155771 commit 15394f8
Show file tree
Hide file tree
Showing 7 changed files with 229 additions and 132 deletions.
46 changes: 33 additions & 13 deletions examples/continuous-id3.rb
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,25 @@
require 'decisiontree'
include DecisionTree

# ---Continuous-----------------------------------------------------------------------------------------
# ---Continuous---

# Read in the training data
training, attributes = [], nil
File.open('data/continuous-training.txt','r').each_line { |line|
data = line.strip.chomp('.').split(',')
# `attributes` must exist before the block below: a variable first assigned
# inside a block is local to that block, so the header row captured by
# `attributes ||= data` would otherwise be lost once the loop finishes.
training = []
attributes = nil
File.open('data/continuous-training.txt', 'r').each_line do |line|
data = line.strip.chomp('.').split(',')
attributes ||= data
training.push(data.collect {|v| (v == 'healthy') || (v == 'colic') ? (v == 'healthy' ? 1 : 0) : v.to_f})
}
training_data = data.collect do |v|
case v
when 'healthy'
1
when 'colic'
0
else
v.to_f
end
end
training.push(training_data)
end

# Remove the attribute row from the training data
training.shift
Expand All @@ -19,15 +29,25 @@
dec_tree = ID3Tree.new(attributes, training, 1, :continuous)
dec_tree.train

#---- Test the tree....
# ---Test the tree---

# Read in the test cases
# Note: omit the attribute line (first line), we know the labels from the training data
# Note: omit the attribute line (first line), we know the labels from the training data
test = []
File.open('data/continuous-test.txt','r').each_line { |line|
data = line.strip.chomp('.').split(',')
test.push(data.collect {|v| (v == 'healthy') || (v == 'colic') ? (v == 'healthy' ? 1 : 0) : v.to_f})
}
# Parse every test case: 'healthy' becomes 1, 'colic' becomes 0 and every
# other field is converted to a Float (same mapping as the training data).
# File.foreach closes the file once iteration finishes, whereas
# File.open(...).each_line leaves the handle open.
File.foreach('data/continuous-test.txt') do |line|
  data = line.strip.chomp('.').split(',')
  test_data = data.collect do |v|
    case v
    when 'healthy'
      1
    when 'colic'
      0
    else
      v.to_f
    end
  end
  test.push(test_data)
end

# Let the tree predict the output and compare it to the true specified value
test.each { |t| predict = dec_tree.predict(t); puts "Predict: #{predict} ... True: #{t.last}"}
test.each do |t|
predict = dec_tree.predict(t)
puts "Predict: #{predict} ... True: #{t.last}"
end
48 changes: 36 additions & 12 deletions examples/discrete-id3.rb
Original file line number Diff line number Diff line change
@@ -1,15 +1,25 @@
require 'rubygems'
require 'decisiontree'

# ---Discrete-----------------------------------------------------------------------------------------
# ---Discrete---

# Read in the training data
training, attributes = [], nil
File.open('data/discrete-training.txt','r').each_line { |line|
# `attributes` must exist before the block below: a variable first assigned
# inside a block is local to that block, so the header row captured by
# `attributes ||= data` would otherwise be lost once the loop finishes.
training = []
attributes = nil
File.open('data/discrete-training.txt', 'r').each_line do |line|
data = line.strip.split(',')
attributes ||= data
training.push(data.collect {|v| (v == 'will buy') || (v == "won't buy") ? (v == 'will buy' ? 1 : 0) : v})
}
training_data = data.collect do |v|
case v
when 'will buy'
1
when "won't buy"
0
else
v
end
end
training.push(training_data)
end

# Remove the attribute row from the training data
training.shift
Expand All @@ -18,17 +28,31 @@
dec_tree = DecisionTree::ID3Tree.new(attributes, training, 1, :discrete)
dec_tree.train

#---- Test the tree....
# ---Test the tree---

# Read in the test cases
# Note: omit the attribute line (first line), we know the labels from the training data
# Note: omit the attribute line (first line), we know the labels from the training data
test = []
File.open('data/discrete-test.txt','r').each_line { |line| data = line.strip.split(',')
test.push(data.collect {|v| (v == 'will buy') || (v == "won't buy") ? (v == 'will buy' ? 1 : 0) : v})
}
# Parse every test case: 'will buy' becomes 1, "won't buy" becomes 0 and all
# other fields are kept as strings (same mapping as the training data).
# File.foreach closes the file once iteration finishes, whereas
# File.open(...).each_line leaves the handle open.
File.foreach('data/discrete-test.txt') do |line|
  data = line.strip.split(',')
  test_data = data.collect do |v|
    case v
    when 'will buy'
      1
    when "won't buy"
      0
    else
      v
    end
  end
  # Rows belong in `test`, not `training`: the prediction loop iterates over
  # `test`, which would otherwise stay empty.
  test.push(test_data)
end

# Let the tree predict the output and compare it to the true specified value
test.each { |t| predict = dec_tree.predict(t); puts "Predict: #{predict} ... True: #{t.last}"; }
test.each do |t|
predict = dec_tree.predict(t)
puts "Predict: #{predict} ... True: #{t.last}"
end

# Graph the tree, save to 'discrete.png'
dec_tree.graph("discrete")
dec_tree.graph('discrete')
16 changes: 7 additions & 9 deletions examples/simple.rb
Original file line number Diff line number Diff line change
Expand Up @@ -2,27 +2,25 @@

require 'rubygems'
require 'decisiontree'

attributes = ['Temperature']
training = [
[36.6, 'healthy'],
[37, 'sick'],
[38, 'sick'],
[36.7, 'healthy'],
[40, 'sick'],
[50, 'really sick'],
[50, 'really sick']
]

# Instantiate the tree, and train it based on the data (set default to 'sick')
dec_tree = DecisionTree::ID3Tree.new(attributes, training, 'sick', :continuous)
dec_tree.train

test = [37, 'sick']

decision = dec_tree.predict(test)
puts "Predicted: #{decision} ... True decision: #{test.last}";

# Graph the tree, save to 'tree.png'
dec_tree.graph("tree")

decision = dec_tree.predict(test)
puts "Predicted: #{decision} ... True decision: #{test.last}"

# Graph the tree, save to 'tree.png'
dec_tree.graph('tree')
29 changes: 29 additions & 0 deletions lib/core_extensions/array.rb
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
# Array helpers used when working with rows of training data.
class Array
  # Returns a new array holding the last element of every item.
  def classification
    map { |row| row.last }
  end

  # Information entropy (in bits) of the values held by the receiver.
  # Returns the Integer 0 for an empty array.
  def entropy
    return 0 if empty?

    # Count occurrences of each distinct value, in first-seen order.
    occurrences = Hash.new(0)
    each { |value| occurrences[value] += 1 }

    result(occurrences, length)
  end

  private

  # Sums -p * log2(p) over the counts in +info+, where p = count / total.
  # Entries whose count is not greater than zero contribute nothing.
  def result(info, total)
    info.inject(0) do |sum, (_symbol, count)|
      if count > 0
        share = count.to_f / total
        sum + (-share * Math.log(share) / Math.log(2.0))
      else
        sum
      end
    end
  end
end
9 changes: 9 additions & 0 deletions lib/core_extensions/object.rb
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
# Marshal-based persistence helpers available on every object.
class Object
  # Serializes the receiver with Marshal and writes it to +filename+,
  # truncating any existing content.
  #
  # The file is opened in binary mode because Marshal output is a raw byte
  # stream; text mode may translate newlines or transcode it on some
  # platforms and corrupt the dump.
  def save_to_file(filename)
    File.open(filename, 'wb') { |f| f << Marshal.dump(self) }
  end

  # Reads +filename+ as raw bytes and rebuilds the object stored in it.
  #
  # NOTE(review): Marshal.load can instantiate arbitrary objects, so only
  # load files that come from a trusted source.
  def self.load_from_file(filename)
    Marshal.load(File.binread(filename))
  end
end
2 changes: 2 additions & 0 deletions lib/decisiontree.rb
Original file line number Diff line number Diff line change
@@ -1 +1,3 @@
require File.dirname(__FILE__) + '/decisiontree/id3_tree.rb'
require 'core_extensions/object'
require 'core_extensions/array'
Loading

0 comments on commit 15394f8

Please sign in to comment.