Skip to content

Commit

Permalink
Making code easier to read
Browse files Browse the repository at this point in the history
  • Loading branch information
Wise Monk committed Jun 26, 2015
1 parent eaa9a2f commit 310b73b
Showing 1 changed file with 11 additions and 9 deletions.
20 changes: 11 additions & 9 deletions lib/decisiontree/id3_tree.rb
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,17 @@ def self.load_from_file(filename)
end

class Array
# Returns a new array holding the last element of every item in the receiver
# (the block calls `last` on each item, so each item must respond to `last`).
# NOTE(review): presumably each item is a data row whose final element is its
# class label -- confirm against callers.
# NOTE(review): this span looks like a rendered diff without +/- markers: the
# one-line definition appears to be the pre-change form and the multi-line
# definition directly below it the reformatted replacement.
def classification; collect { |v| v.last }; end
def classification
collect { |v| v.last }
end

# calculate information entropy
def entropy
return 0 if empty?

info = {}
total = 0
each {|i| info[i] = !info[i] ? 1 : (info[i] + 1); total += 1}
each { |i| info[i] = !info[i] ? 1 : (info[i] + 1); total += 1}

result = 0
info.each do |symbol, count|
Expand All @@ -42,11 +44,11 @@ def initialize(attributes, data, default, type)
end

# Trains on the given rows and stores the result of id3_train in @tree.
#
# data       - rows to train on; each row is sliced and read with `last`, so
#              rows are array-like with the classification as the final
#              element. Defaults to @data.
# attributes - attribute names; each is converted with to_s. Defaults to
#              @attributes.
# default    - passed through unchanged to initialize and id3_train. Defaults
#              to @default.
#
# Returns the value assigned to @tree.
#
# NOTE(review): the two pairs of near-identical statements below differ only
# in block-brace spacing; they look like the before/after lines of a rendered
# diff rather than intentional repetition.
def train(data=@data, attributes=@attributes, default=@default)
# Attribute names are normalised to strings before anything else uses them.
attributes = attributes.map {|e| e.to_s}
attributes = attributes.map { |e| e.to_s}
# Re-runs the constructor with the stringified attributes and the current @type.
# NOTE(review): initialize's body is not visible here -- presumably it resets
# the instance state; confirm.
initialize(attributes, data, default, @type)

# Remove samples with same attributes leaving most common classification
# Step 1 (inject): group rows by everything except the last element and count
#   how often each last element occurs within the group.
# Step 2 (map): rebuild one row per group as key + [label with highest count];
#   sort_by is ascending, so `.last.first` picks the most frequent label.
data2 = data.inject({}) {|hash, d| hash[d.slice(0..-2)] ||= Hash.new(0); hash[d.slice(0..-2)][d.last] += 1; hash }.map{|key,val| key + [val.sort_by{ |k, v| v }.last.first]}
data2 = data.inject({}) { |hash, d| hash[d.slice(0..-2)] ||= Hash.new(0); hash[d.slice(0..-2)][d.last] += 1; hash }.map{ |key,val| key + [val.sort_by{ |k, v| v }.last.first]}

# id3_train is defined outside this view; it receives the de-duplicated rows.
@tree = id3_train(data2, attributes, default)
end
Expand All @@ -57,8 +59,8 @@ def type(attribute)

# Picks the scoring routine for an attribute based on type(attribute).
#
# attribute - the attribute whose type is looked up via type(attribute).
#
# Returns a three-argument proc that forwards its arguments to id3_discrete
# (for :discrete) or id3_continuous (for :continuous). The `case` has no
# `else`, so any other type yields nil.
#
# The local `fitness` is never read afterwards; the method's return value is
# simply the value of the `case` expression.
#
# NOTE(review): each `when` arm appears twice, differing only in spacing --
# this looks like the before/after lines of a rendered diff. As written, only
# the first matching arm can ever run.
def fitness_for(attribute)
case type(attribute)
when :discrete; fitness = proc{|a,b,c| id3_discrete(a,b,c)}
when :continuous; fitness = proc{|a,b,c| id3_continuous(a,b,c)}
when :discrete; fitness = proc { |a,b,c| id3_discrete(a,b,c) }
when :continuous; fitness = proc { |a,b,c| id3_continuous(a,b,c) }
end
end

Expand Down Expand Up @@ -122,7 +124,7 @@ def id3_continuous(data, attributes, attribute)
# Scores a split of `data` on one discrete attribute.
#
# data       - rows; each row is indexed by column position and, via
#              classification, read with `last`.
# attributes - ordered attribute names; the position of `attribute` in this
#              list selects the column to read from each row.
# attribute  - the attribute to evaluate; must be present in `attributes`
#              (otherwise `index` returns nil).
#
# Returns a two-element array:
#   [entropy of data's classifications - weighted entropy after the split,
#    column index of the attribute].
#
# NOTE(review): the `remainder` assignment appears twice, differing only in
# block-brace spacing -- this looks like the before/after lines of a rendered
# diff; both compute the same value.
def id3_discrete(data, attributes, attribute)
# Distinct values seen in the attribute's column; sort requires them to be
# mutually comparable.
values = data.collect { |d| d[attributes.index(attribute)] }.uniq.sort
# One partition (array of rows) per distinct value.
partitions = values.collect { |val| data.select { |d| d[attributes.index(attribute)] == val } }
# Sum over partitions of (share of rows) * (entropy of that partition's
# classifications). In the inject block the first parameter (`i`) is the
# running total and the second (`s`) the current term; `s+=i` evaluates to
# their sum, which becomes the next running total.
remainder = partitions.collect {|p| (p.size.to_f / data.size) * p.classification.entropy}.inject(0) {|i,s| s+=i }
remainder = partitions.collect { |p| (p.size.to_f / data.size) * p.classification.entropy}.inject(0) { |i,s| s+=i }

[data.classification.entropy - remainder, attributes.index(attribute)]
end
Expand Down Expand Up @@ -281,7 +283,7 @@ def prune(data=@prune_data)
end
end
end
@rules = @rules.sort_by{|r| -r.accuracy(data)}
@rules = @rules.sort_by{ |r| -r.accuracy(data) }
end

def to_s
Expand Down Expand Up @@ -320,7 +322,7 @@ def predict(test)
predictions[p] += accuracy unless p.nil?
end
return @default, 0.0 if predictions.empty?
winner = predictions.sort_by {|k,v| -v}.first
winner = predictions.sort_by { |k,v| -v}.first
return winner[0], winner[1].to_f / @classifiers.size.to_f
end
end
Expand Down

0 comments on commit 310b73b

Please sign in to comment.