Skip to content

Commit

Permalink
Minimal KNN
Browse files Browse the repository at this point in the history
  • Loading branch information
eriklindernoren committed Jan 22, 2018
1 parent 193032c commit da90aa1
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 18 deletions.
26 changes: 9 additions & 17 deletions mlfromscratch/supervised_learning/k_nearest_neighbors.py
Original file line number Diff line number Diff line change
@@ -14,29 +14,21 @@ class KNN():
# Constructor: stores k on the instance. predict() later slices the
# distance-sorted training samples with [:self.k], so k is the number of
# neighbours that take part in the vote (5 unless the caller overrides it).
def __init__(self, k=5):
self.k = k

# NOTE(review): this is a diff rendering without +/- markers. The two `def`
# lines and the two `counts = ...` lines are alternative versions of the same
# statement: one takes a two-column array and reads the labels from column 1,
# the other takes the labels directly. Presumably the first of each pair is
# the removed line and the second the added one — confirm against the commit.
def _vote(self, neighbors):
def _vote(self, neighbor_labels):
""" Return the most common class among the neighbor samples """
# Labels are cast to int and tallied with np.bincount, so entry c of `counts`
# is the number of neighbours carrying label c.
# NOTE(review): np.bincount only accepts non-negative integers, so this
# assumes class labels are >= 0 (a -1/+1 encoding would raise) — confirm
# against callers.
counts = np.bincount(neighbors[:, 1].astype('int'))
counts = np.bincount(neighbor_labels.astype('int'))
# argmax returns the first maximal position, so a tie goes to the smallest label.
return counts.argmax()

# Classify every sample in X_test by majority vote among the self.k training
# samples closest to it. Distances come from euclidean_distance, which is
# defined outside this view. Returns y_pred, allocated with
# np.empty(X_test.shape[0]): one slot per test sample, NumPy's default float dtype.
# NOTE(review): this is a diff rendering without +/- markers. The block that
# fills the two-column `neighbors` array and the block starting at
# `idx = np.argsort(...)` are two implementations of the same loop body;
# presumably the former is the removed code and the latter its replacement —
# confirm against the commit before treating either as current.
def predict(self, X_test, X_train, y_train):
y_pred = np.empty(X_test.shape[0])
# Determine the class of each sample
for i, test_sample in enumerate(X_test):
# Two columns [distance, label], for each observed sample
neighbors = np.empty((X_train.shape[0], 2))
# Calculate the distance from each observed sample to the
# sample we wish to predict
for j, observed_sample in enumerate(X_train):
distance = euclidean_distance(test_sample, observed_sample)
label = y_train[j]
# Add neighbor information
neighbors[j] = [distance, label]
# Sort the list of observed samples from lowest to highest distance
# and select the k first
k_nearest_neighbors = neighbors[neighbors[:, 0].argsort()][:self.k]
# Get the most common class among the neighbors
label = self._vote(k_nearest_neighbors)
y_pred[i] = label
# Sort the training samples by their distance to the test sample and get the K nearest
idx = np.argsort([euclidean_distance(test_sample, x) for x in X_train])[:self.k]
# Extract the labels of the K nearest neighboring training samples
# NOTE(review): the comprehension below reuses the name `i` of the enumerate
# index above. Python 3 keeps comprehension variables local, so y_pred[i]
# still addresses the current test sample; under Python 2 the comprehension
# would overwrite `i` and the assignment below would hit the wrong slot.
k_nearest_neighbors = np.array([y_train[i] for i in idx])
# Label sample as the most common class label
y_pred[i] = self._vote(k_nearest_neighbors)

return y_pred

2 changes: 1 addition & 1 deletion mlfromscratch/unsupervised_learning/dcgan.py
Original file line number Diff line number Diff line change
@@ -168,6 +168,6 @@ def save_imgs(self, epoch):

# Script entry point: builds a DCGAN (defined outside this view) and trains it.
# NOTE(review): this is a diff rendering without +/- markers. The two
# train(...) calls differ only in batch_size (32 vs 64); presumably the first
# is the removed line and the second the added one — confirm against the commit.
if __name__ == '__main__':
dcgan = DCGAN()
dcgan.train(epochs=200000, batch_size=32, save_interval=50)
dcgan.train(epochs=200000, batch_size=64, save_interval=50)


0 comments on commit da90aa1

Please sign in to comment.