Skip to content

Commit

Permalink
Add K-Means.
Browse files Browse the repository at this point in the history
  • Loading branch information
trekhleb committed Dec 21, 2018
1 parent 23a5165 commit 1bd9e3d
Show file tree
Hide file tree
Showing 2 changed files with 159 additions and 51 deletions.
85 changes: 75 additions & 10 deletions homemade/k_means/k_means.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,20 +5,48 @@ class KMeans:
"""K-Means Class"""

def __init__(self, data, num_clusters):
    """Remember the training set and the requested number of clusters.

    :param data: training dataset.
    :param num_clusters: number of clusters the dataset should be split into.
    """
    # Both values are stored untouched; clustering itself happens in train().
    self.data, self.num_clusters = data, num_clusters

def train(self, max_iterations):
    """Cluster the stored dataset with the K-Means algorithm.

    :param max_iterations: maximum number of training iterations.
    :return: tuple ``(centroids, closest_centroids_ids)`` holding the final
        centroids and, for every training example, the id of the centroid it
        was assigned to during the last iteration.
    """

    # Pick the starting centroids from the training set.
    centroids = KMeans.centroids_init(self.data, self.num_clusters)

    # Placeholder for the assignments; it is returned as allocated (i.e.
    # uninitialized) only when max_iterations is 0.
    num_examples = self.data.shape[0]
    closest_centroids_ids = np.empty((num_examples, 1))

    # Run K-Means: alternate the assignment step and the update step.
    for _ in range(max_iterations):
        # Assignment step: find the closest centroid for every example.
        closest_centroids_ids = KMeans.centroids_find_closest(self.data, centroids)

        # Update step: move every centroid to the mean of its examples.
        centroids = KMeans.centroids_compute(
            self.data,
            closest_centroids_ids,
            self.num_clusters,
        )

    return centroids, closest_centroids_ids

@staticmethod
def centroids_init(data, num_clusters):
    """Pick random training examples to serve as the initial centroids.

    :param data: training dataset, one example per row.
    :param num_clusters: number of clusters into which we want to break the dataset.
    :return: the rows of ``data`` selected as starting centroids
        (``num_clusters`` rows when the dataset has at least that many examples).
    """

    # Get number of training examples.
    num_examples = data.shape[0]

    # Shuffle the example indices so the selection is random and no example
    # can be chosen twice.
    random_ids = np.random.permutation(num_examples)

    # Take exactly the first K shuffled examples as centroids.
    centroids = data[random_ids[:num_clusters], :]

    return centroids

@staticmethod
def centroids_find_closest(data, centroids):
    """Compute the centroid membership for every training example.

    :param data: training dataset, one example per row.
    :param centroids: centroid points, one centroid per row.
    :return: array of shape (num_examples, 1); entry i is the zero-based index
        of the centroid with the smallest squared Euclidean distance to
        example i (the lowest index wins on ties).
    """

    # Get number of training examples.
    num_examples = data.shape[0]

    # One assignment per example, kept as a column vector.
    closest_centroids_ids = np.zeros((num_examples, 1))

    for example_index in range(num_examples):
        # Squared distance from this example to every centroid at once.
        distance_difference = centroids - data[example_index, :]
        distances = np.sum(distance_difference ** 2, axis=1)

        # The square root is skipped: it does not change which one is smallest.
        closest_centroids_ids[example_index] = np.argmin(distances)

    return closest_centroids_ids

@staticmethod
def centroids_compute(data, closest_centroids_ids, num_clusters):
    """Recompute every centroid as the mean of the examples assigned to it.

    :param data: training dataset, one example per row.
    :param closest_centroids_ids: closest centroid id for each training example.
    :param num_clusters: number of clusters.
    :return: array of shape (num_clusters, num_features) with the new centroids.
    """

    # Only the feature count is needed, but data must still be two-dimensional.
    _, num_features = data.shape

    # Row k of the result is filled with the mean of cluster k.
    updated_centroids = np.zeros((num_clusters, num_features))

    for cluster_index in range(num_clusters):
        # Boolean mask over the examples that were assigned to this cluster.
        belongs_to_cluster = (closest_centroids_ids == cluster_index).flatten()
        cluster_members = data[belongs_to_cluster, :]

        # NOTE(review): a cluster with no assigned examples makes this an
        # empty selection - confirm how callers expect that case to behave.
        updated_centroids[cluster_index] = np.mean(cluster_members, axis=0)

    return updated_centroids
Loading

0 comments on commit 1bd9e3d

Please sign in to comment.