-
Notifications
You must be signed in to change notification settings - Fork 0
/
KNN.py
83 lines (62 loc) · 1.89 KB
/
KNN.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
'''
AUTHOR:
Thiago Alexandre Domingues de Souza
LANGUAGE:
Python3
INPUT:
The input (read from standard input) should have the following format. The first line consists of an integer k, the number of nearest neighbors to consult.
On the second line, two coordinates x and y represent the test point. On the third line, an integer n represents
the number of items in the training set. Then follow n lines, each consisting of two numbers xt and yt followed by a label,
representing the training coordinates and the corresponding category.
OUTPUT:
Three lines: the test point, the value of k, and the most-voted class among the k nearest neighbors.
FORMAT:
k
x y
n
xt yt category
SAMPLE INPUT:
2
250 250
8
1 2 one
1 3 one
1 5 one
1 6 one
1 7 one
1 8 one
200 200 two
200 201 two
SAMPLE OUTPUT:
Test example: (250.0, 250.0)
k: 2
Category: two
'''
class KNN():
    """Brute-force k-nearest-neighbours classifier for 2-D points."""

    def squared_distance(self, a, b):
        """Return the squared Euclidean distance between 2-D points a and b.

        Only the first two coordinates of each point are used. The square
        root is deliberately skipped: it is monotonic, so ranking points by
        squared distance yields the same order as ranking by distance.
        """
        dx = a[0] - b[0]
        dy = a[1] - b[1]
        return dx * dx + dy * dy

    def classify(self, k, training, test):
        """Return the most common category among the k points nearest to test.

        Args:
            k: number of neighbours that vote; must be at least 1. If k
                exceeds the size of the training set, every item votes.
            training: iterable of ((x, y), category) items.
            test: (x, y) point to classify.

        Returns:
            The category with the most votes. When several categories are
            tied, the one whose first voter is closest to test wins.

        Raises:
            ValueError: if k is smaller than 1 or training is empty.
        """
        # A negative k would otherwise slice "all but the last |k|" items
        # and silently vote over the wrong neighbours.
        if k < 1:
            raise ValueError("k must be at least 1, got %r" % (k,))

        # training = ((x, y), category); sorted() is stable, so equidistant
        # items keep their input order.
        nearest = sorted(
            training,
            key=lambda item: self.squared_distance(item[0], test),
        )[:k]
        if not nearest:
            raise ValueError("training set must not be empty")

        # class_votes maps category -> number of votes, in first-seen order
        class_votes = {}
        for item in nearest:
            category = item[1]
            class_votes[category] = class_votes.get(category, 0) + 1

        # max() returns the first maximal key in iteration order, which is
        # the same winner a stable descending sort would put first.
        return max(class_votes, key=class_votes.get)
def main():
    """Read a KNN problem from standard input and print the prediction.

    Expected input: k on the first line, the test point "x y" on the second,
    the training-set size n on the third, then n lines of "xt yt category".
    """
    k = int(input())
    x, y = map(float, input().split())
    n = int(input())

    # Each training item is ((x, y), category), the shape KNN.classify expects.
    training = []
    for _ in range(n):
        xt, yt, category = input().split()
        training.append(((float(xt), float(yt)), category))

    knn = KNN()
    print("Test example:", (x, y))
    print("k:", k)
    print("Category:", knn.classify(k, training, (x, y)))


# Only run the driver when executed as a script, so that importing this
# module to reuse KNN does not block on standard input.
if __name__ == "__main__":
    main()