Commit 86d5fca — "Added some text."
1 parent 6dadb19

main.md — 81 additions, 45 deletions

@@ -10,11 +10,11 @@ The answer lies in the observation that many real-world datasets have a low intr

This is the topic of [**manifold learning**](http://en.wikipedia.org/wiki/Nonlinear_dimensionality_reduction), also called **nonlinear dimensionality reduction**, a branch of machine learning (more specifically, _unsupervised learning_). It is still an active area of research today to develop algorithms that can automatically recover a hidden structure in a high-dimensional dataset.

This post is an introduction to a popular dimensionality reduction algorithm: [**t-distributed stochastic neighbor embedding (t-SNE)**](http://en.wikipedia.org/wiki/T-distributed_stochastic_neighbor_embedding). Developed by [Laurens van der Maaten](http://lvdmaaten.github.io/) and [Geoffrey Hinton](http://www.cs.toronto.edu/~hinton/), this algorithm has been successfully applied to many real-world datasets. Here, we'll follow the original paper and describe the key mathematical concepts of the method, when applied to a toy dataset (handwritten digits). We'll use Python and the [scikit-learn](http://scikit-learn.org/stable/index.html) library.

## Visualizing handwritten digits

Let's first import a few libraries.

<pre data-code-language="python"
     data-executable="true"

@@ -55,25 +55,46 @@ from moviepy.video.io.bindings import mplfig_to_npimage
import moviepy.editor as mpy
</pre>

Now we load the classic *handwritten digits* dataset. It contains 1797 images with <span class="math-tex" data-type="tex">\\(8 \times 8=64\\)</span> pixels each.

<pre data-code-language="python"
     data-executable="true"
     data-type="programlisting">
digits = load_digits()
digits.data.shape
</pre>

<pre data-code-language="python"
     data-executable="true"
     data-type="programlisting">
print(digits['DESCR'])
</pre>

Here are the images:

<pre data-code-language="python"
     data-executable="true"
     data-type="programlisting">
nrows, ncols = 2, 5
plt.figure(figsize=(6,3))
plt.gray()
for i in range(ncols * nrows):
    ax = plt.subplot(nrows, ncols, i + 1)
    ax.matshow(digits.images[i,...])
    plt.xticks([]); plt.yticks([])
    plt.title(digits.target[i])
</pre>

Now let's run the t-SNE algorithm on the dataset. It takes just one line of code with scikit-learn.

<pre data-code-language="python"
     data-executable="true"
     data-type="programlisting">
digits_proj = TSNE().fit_transform(digits.data)
</pre>
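
By default, `TSNE()` projects onto two dimensions with its standard settings. If you need reproducible or tuned results, the estimator accepts a few parameters; the values below are illustrative, not the settings used in this post.

<pre data-code-language="python"
     data-executable="true"
     data-type="programlisting">
# Illustrative only: explicit (hypothetical) parameter choices.
# perplexity controls the effective number of neighbors,
# random_state fixes the stochastic initialization.
digits_proj = TSNE(n_components=2, perplexity=30.0,
                   random_state=0).fit_transform(digits.data)
</pre>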

Here is a utility function used to display the transformed dataset.

<pre data-code-language="python"
     data-executable="true"
     data-type="programlisting">

@@ -83,10 +104,10 @@ def scatter(x, colors):
    f = plt.figure(figsize=(8, 8))
    ax = plt.subplot(aspect='equal')
    sc = ax.scatter(x[:,0], x[:,1], lw=0, s=40,
                    c=palette[colors.astype(np.int)])
    plt.xlim(-25, 25)
    plt.ylim(-25, 25)
    ax.axis('off')

    txts = []
    for i in range(10):

@@ -108,9 +129,9 @@ scatter(digits_proj, digits.target);

## Mathematical framework

Now, let's explain how the algorithm works. First, a few definitions.

A **data point** is a point <span class="math-tex" data-type="tex">\\(x_i\\)</span> in the original **data space** <span class="math-tex" data-type="tex">\\(\mathbf{R}^D\\)</span>, where <span class="math-tex" data-type="tex">\\(D=64\\)</span> is the **dimensionality** of the data space. Every point is an image of a handwritten digit here. There are <span class="math-tex" data-type="tex">\\(N=1797\\)</span> points.

A **map point** is a point <span class="math-tex" data-type="tex">\\(y_i\\)</span> in the **map space** <span class="math-tex" data-type="tex">\\(\mathbf{R}^2\\)</span>. This space will contain our final representation of the dataset. There is a _bijection_ between the data points and the map points: every map point represents one of the original images.

@@ -124,7 +145,9 @@ Now, we define the similarity as a symmetrized version of the conditional simila

<span class="math-tex" data-type="tex">\\(p_{ij} = \frac{p_{j|i} + p_{i|j}}{2N}\\)</span>

We obtain a similarity matrix for our original dataset. What does this matrix look like?

We first reorder the data points according to the handwritten number they represent.

<pre data-code-language="python"
     data-executable="true"

@@ -135,11 +158,7 @@ y = np.hstack([digits.target[digits.target==i]
               for i in range(10)])
</pre>
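
The matching reordering of the data matrix itself falls in a part of the file not shown in this diff; presumably it follows the same pattern, along the lines of:

<pre data-code-language="python"
     data-executable="true"
     data-type="programlisting">
# Assumed counterpart of the reordering above (not shown in this hunk):
# stack the images digit by digit, so that the similarity matrix shows
# one block per digit.
X = np.vstack([digits.data[digits.target==i]
               for i in range(10)])
</pre>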

The following function computes the similarity with a constant sigma.

<pre data-code-language="python"
     data-executable="true"

@@ -150,15 +169,21 @@ def _joint_probabilities_constant_sigma(D, sigma):
    return P
</pre>
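
The body of that function sits outside this hunk. Under the standard t-SNE definitions (a Gaussian kernel on the squared distances, followed by the symmetrization above), a plausible sketch of what it computes might be:

<pre data-code-language="python"
     data-executable="true"
     data-type="programlisting">
def joint_probabilities_constant_sigma_sketch(D, sigma):
    # Illustrative only, not necessarily the code hidden by this diff.
    # D is the matrix of squared pairwise distances.
    P_cond = np.exp(-D / (2. * sigma ** 2))        # Gaussian similarities
    np.fill_diagonal(P_cond, 0.)                   # a point is not its own neighbor
    P_cond /= P_cond.sum(axis=1, keepdims=True)    # conditional p_{j|i}
    N = D.shape[0]
    return (P_cond + P_cond.T) / (2. * N)          # symmetrized p_{ij}
</pre>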

We now compute the similarity with a sigma depending on the data point (found via a binary search). This algorithm is implemented in scikit-learn's `_joint_probabilities` function.

<pre data-code-language="python"
     data-executable="true"
     data-type="programlisting">
# Pairwise distances between all data points.
D = pairwise_distances(X, squared=True)
P_constant = _joint_probabilities_constant_sigma(D, .002)
P_binary = _joint_probabilities(D, 30., False)
# The output of this function needs to be reshaped to a square matrix.
P_binary_s = squareform(P_binary)
</pre>
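
For intuition, the binary search on <span class="math-tex" data-type="tex">\\(\sigma_i\\)</span> aims at a fixed [perplexity](http://en.wikipedia.org/wiki/Perplexity) of the conditional distribution (presumably the `30.` in the call above). Here is a rough sketch of such a search for a single point, assuming `d_i` holds the squared distances from point <span class="math-tex" data-type="tex">\\(i\\)</span> to all the other points; this is not scikit-learn's implementation.

<pre data-code-language="python"
     data-executable="true"
     data-type="programlisting">
def binary_search_sigma_sketch(d_i, target_perplexity, n_steps=50, tol=1e-5):
    # Illustrative only: adjust sigma until 2**entropy(p_{.|i}) matches the
    # target perplexity.
    lo, hi, sigma = 0., np.inf, 1.
    for _ in range(n_steps):
        p = np.exp(-d_i / (2. * sigma ** 2))
        p /= p.sum()
        perplexity = 2. ** (-np.sum(p * np.log2(p + 1e-12)))
        if abs(perplexity - target_perplexity) < tol:
            break
        if perplexity > target_perplexity:   # sigma too large: shrink it
            hi = sigma
            sigma = (lo + hi) / 2.
        else:                                # sigma too small: grow it
            lo = sigma
            sigma = sigma * 2. if np.isinf(hi) else (lo + hi) / 2.
    return sigma
</pre>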

Let's display the distance matrix of the data points, and the similarity matrix with both a constant and variable sigma.

<pre data-code-language="python"
     data-executable="true"
     data-type="programlisting">

@@ -181,44 +206,48 @@ plt.axis('off')
plt.title("$p_{j|i}$ (binary search sigma)", fontdict={'fontsize': 16});
</pre>

We already observe the 10 groups in the data, corresponding to the 10 numbers.

Let's also define a similarity matrix for our map points.

<span class="math-tex" data-type="tex">\\(q_{ij} = \frac{f(\left| y_i - y_j\right|)}{\displaystyle\sum_{k \neq i} f(\left| y_i - y_k\right|)} \quad \textrm{with} \quad f(z) = \frac{1}{1+z^2}.\\)</span>

This is the same idea as for the data points, but with a different distribution ([**t-Student with one degree of freedom**](http://en.wikipedia.org/wiki/Student%27s_t-distribution), or [**Cauchy distribution**](http://en.wikipedia.org/wiki/Cauchy_distribution), instead of a Gaussian distribution). We'll elaborate on this choice later.

Whereas the data similarity matrix <span class="math-tex" data-type="tex">\\(\big(p_{ij}\big)\\)</span> is fixed, the map similarity matrix <span class="math-tex" data-type="tex">\\(\big(q_{ij}\big)\\)</span> depends on the map points. What we want is for these two matrices to be as close as possible. This would mean that similar data points yield similar map points.
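
To make this concrete, here is a small sketch (not scikit-learn's code) that evaluates this matrix for a hypothetical array `Y` of map points, following the formula above:

<pre data-code-language="python"
     data-executable="true"
     data-type="programlisting">
# Illustrative only: q_ij for map points Y of shape (N, 2).
n = 1. / (1. + pairwise_distances(Y, squared=True))  # Cauchy kernel f(|y_i - y_j|)
np.fill_diagonal(n, 0.)                               # the denominator sums over k != i
Q = n / n.sum(axis=1, keepdims=True)
</pre>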

## A physical analogy

Let's assume that our map points are all connected with springs. The stiffness of a spring connecting points <span class="math-tex" data-type="tex">\\(i\\)</span> and <span class="math-tex" data-type="tex">\\(j\\)</span> depends on the mismatch between the similarity of the two data points and the similarity of the two map points, that is, <span class="math-tex" data-type="tex">\\(p_{ij} - q_{ij}\\)</span>. Now, we let the system evolve according to the laws of physics. If two map points are far apart while the data points are close, they are attracted together. If they are nearby while the data points are dissimilar, they are repelled.

The final mapping is obtained when the equilibrium is reached.

## Algorithm

Remarkably, this analogy stems exactly from a natural mathematical algorithm. It corresponds to minimizing the [Kullback-Leibler divergence](http://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence) between the two distributions <span class="math-tex" data-type="tex">\\(\big(p_{ij}\big)\\)</span> and <span class="math-tex" data-type="tex">\\(\big(q_{ij}\big)\\)</span>:

<span class="math-tex" data-type="tex">\\(KL(P||Q) = \sum_{i, j} p_{ij} \, \log \frac{p_{ij}}{q_{ij}}.\\)</span>

This measures the distance between our two similarity matrices.
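
This score is also easy to evaluate numerically. A minimal sketch, assuming `P` and `Q` are the two square similarity matrices (for instance `P_binary_s` from above and a matrix `Q` computed as in the previous sketch):

<pre data-code-language="python"
     data-executable="true"
     data-type="programlisting">
# Illustrative only: KL(P || Q), with a small epsilon to avoid log(0).
EPS = np.finfo(np.double).eps
kl = np.sum(P * np.log(np.maximum(P, EPS) / np.maximum(Q, EPS)))
</pre>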

To minimize this score, we perform a gradient descent. The gradient can be computed analytically:

<span class="math-tex" data-type="tex">\\(\frac{\partial \, K\!L(P || Q)}{\partial y_i} = 4 \sum_j (p_{ij} - q_{ij}) g\left( \left| y_i - y_j\right| \right) u_{ij} \quad \textrm{where} \quad g(z) = \frac{z}{1+z^2}.\\)</span>

Here, <span class="math-tex" data-type="tex">\\(u_{ij}\\)</span> is a unit vector going from <span class="math-tex" data-type="tex">\\(y_j\\)</span> to <span class="math-tex" data-type="tex">\\(y_i\\)</span>. This gradient expresses the sum of all spring forces applied to map point <span class="math-tex" data-type="tex">\\(i\\)</span>.
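
As an illustration only, a naive <span class="math-tex" data-type="tex">\\(O(N^2)\\)</span> transcription of this gradient into NumPy could look as follows; the actual optimizer in scikit-learn adds momentum, adaptive gains, and other refinements.

<pre data-code-language="python"
     data-executable="true"
     data-type="programlisting">
def gradient_step_sketch(Y, P, learning_rate=100.):
    # Illustrative only: one plain gradient step on map points Y (N, 2),
    # given the (fixed) data similarities P (N, N).
    n = 1. / (1. + pairwise_distances(Y, squared=True))  # Cauchy kernel
    np.fill_diagonal(n, 0.)
    Q = n / n.sum(axis=1, keepdims=True)                 # map similarities
    diff = Y[:, np.newaxis, :] - Y[np.newaxis, :, :]     # y_i - y_j
    # (p_ij - q_ij) * g(|y_i - y_j|) * u_ij  ==  (p_ij - q_ij) * (y_i - y_j) / (1 + |y_i - y_j|^2)
    grad = 4. * np.sum(((P - Q) * n)[..., np.newaxis] * diff, axis=1)
    return Y - learning_rate * grad
</pre>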

Let's illustrate this process by creating an animation of the convergence. We'll have to [monkey-patch](http://en.wikipedia.org/wiki/Monkey_patch) the internal `_gradient_descent()` function from scikit-learn's t-SNE implementation in order to register the position of the map points at every iteration.

<pre data-code-language="python"
     data-executable="true"
     data-type="programlisting">
# This list will contain the positions of the map points at every iteration.
positions = []
def _gradient_descent(objective, p0, it, n_iter, n_iter_without_progress=30,
                      momentum=0.5, learning_rate=1000.0, min_gain=0.01,
                      min_grad_norm=1e-7, min_error_diff=1e-7, verbose=0,
                      args=[]):
    # The documentation of this function can be found in scikit-learn's code.
    p = p0.copy().ravel()
    update = np.zeros_like(p)
    gains = np.ones_like(p)

@@ -227,7 +256,9 @@ def _gradient_descent(objective, p0, it, n_iter, n_iter_without_progress=30,
    best_iter = 0

    for i in range(it, n_iter):
        # We append the current position.
        positions.append(p.copy())

        new_error, grad = objective(p, *args)
        error_diff = np.abs(new_error - error)

@@ -256,11 +287,12 @@ def _gradient_descent(objective, p0, it, n_iter, n_iter_without_progress=30,
sklearn.manifold.t_sne._gradient_descent = _gradient_descent
</pre>

Let's run the algorithm again, but this time saving all intermediate positions.

<pre data-code-language="python"
     data-executable="true"
     data-type="programlisting">
X_proj = TSNE().fit_transform(X)
</pre>

<pre data-code-language="python"

@@ -270,10 +302,12 @@ X_iter = np.dstack(position.reshape(-1, 2)
                   for position in positions)
</pre>

We create an animation using [MoviePy](http://zulko.github.io/moviepy/).

<pre data-code-language="python"
     data-executable="true"
     data-type="programlisting">
f, ax, sc, txts = scatter(X_iter[..., -1], y)

def make_frame_mpl(t):
    i = int(t*40)

@@ -287,10 +321,14 @@ def make_frame_mpl(t):

animation = mpy.VideoClip(make_frame_mpl,
                          duration=X_iter.shape[2]/40.)
animation.write_gif("tsne.gif", fps=20)
</pre>

<img src="tsne.gif" />

We can observe the different phases of the optimization. The details of the algorithm can be found in the original paper.
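
One of these phases comes from the *early exaggeration* trick described in the original paper: during the first iterations, all <span class="math-tex" data-type="tex">\\(p_{ij}\\)</span> are multiplied by a constant, which encourages tight, well-separated clusters early on. scikit-learn exposes this as a parameter of `TSNE`; the value below is purely illustrative.

<pre data-code-language="python"
     data-executable="true"
     data-type="programlisting">
# Illustrative only: a stronger early exaggeration (not the value used above).
X_proj = TSNE(early_exaggeration=8.0).fit_transform(X)
</pre>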

Let's also create an animation of the similarity matrix of the map points. We'll observe that it's getting closer and closer to the similarity matrix of the data points.

<pre data-code-language="python"
     data-executable="true"

@@ -302,13 +340,9 @@ Q = squareform(Q)
f = plt.figure(figsize=(6, 6))
ax = plt.subplot(aspect='equal')
im = ax.imshow(Q, interpolation='none', cmap=pal)
plt.axis('tight')
plt.axis('off')

def make_frame_mpl(t):
    i = int(t*40)
    n = 1. / (pdist(X_iter[..., i], "sqeuclidean") + 1)

@@ -319,15 +353,17 @@ def make_frame_mpl(t):

animation = mpy.VideoClip(make_frame_mpl,
                          duration=X_iter.shape[2]/40.)
animation.write_gif("tsne_matrix.gif", fps=20)
</pre>

<img src="tsne_matrix.gif" />

## The t-Student distribution

Let's now explain the choice of the t-Student distribution for the map points, while a normal distribution is used for the data points. It is well known that the volume of the <span class="math-tex" data-type="tex">\\(N\\)</span>-dimensional ball of radius <span class="math-tex" data-type="tex">\\(r\\)</span> scales as <span class="math-tex" data-type="tex">\\(r^N\\)</span>. When <span class="math-tex" data-type="tex">\\(N\\)</span> is large, if we pick random points uniformly in the ball, most points will be close to the surface, and very few will be near the center.

This is illustrated by the following simulation, showing the distribution of the distances of these points, for different dimensions.

<pre data-code-language="python"
     data-executable="true"
     data-type="programlisting">

@@ -352,24 +388,24 @@ for i, D in enumerate((2, 5, 10)):
    ax.set_title('D=%d' % D, loc='left')
</pre>
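
Most of that simulation lies outside this hunk. As a rough, self-contained sketch of the underlying idea, here is one standard way to sample points uniformly in a <span class="math-tex" data-type="tex">\\(D\\)</span>-dimensional ball and look at their distances to the center (an illustration under that assumption, not necessarily the hidden code).

<pre data-code-language="python"
     data-executable="true"
     data-type="programlisting">
# Illustrative only: sample npoints points uniformly in the unit D-ball.
npoints = 1000
D = 10
u = np.random.normal(size=(npoints, D))
u /= np.linalg.norm(u, axis=1)[:, None]          # uniform random directions
r = np.random.uniform(size=npoints) ** (1. / D)  # radii for a uniform ball
points = u * r[:, None]
distances = np.linalg.norm(points, axis=1)       # most of the mass is near r = 1
</pre>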

When reducing the dimensionality of a dataset, if we used the same Gaussian distribution for the data points and the map points, we could get an _imbalance_ among the neighbors of a given point. This imbalance would lead to an excess of attraction forces and a sometimes unappealing mapping. This is actually what happens in the original SNE algorithm, by Hinton and Roweis (2002).

The t-SNE algorithm works around this problem by using a t-Student with one degree of freedom (or Cauchy) distribution for the map points. This distribution has a much heavier tail than the Gaussian distribution, which _compensates_ the original imbalance. For a given data similarity between two data points, the two corresponding map points will need to be much further apart in order for their similarity to match the data similarity. This can be seen in the following plot.

<pre data-code-language="python"
     data-executable="true"
     data-type="programlisting">
z = np.linspace(0., 5., 1000)
gauss = np.exp(-z**2)
cauchy = 1/(1+z**2)
plt.plot(z, gauss, label='Gaussian distribution')
plt.plot(z, cauchy, label='Cauchy distribution')
plt.legend();
</pre>

## Conclusion

The t-SNE algorithm provides an effective method to visualize a complex dataset. It successfully uncovers hidden structures in the data, exposing natural clusters and smooth nonlinear variations along the dimensions. It has been implemented in many languages, including Python, and it can be easily used thanks to the scikit-learn library.

The references below link to some optimizations and improvements that can be made to the algorithm and implementations. In particular, the algorithm described here is quadratic in the number of samples, which makes it hard to scale to large datasets. One could for example obtain an <span class="math-tex" data-type="tex">\\(O(N \log N)\\)</span> complexity by using the Barnes-Hut algorithm to accelerate the N-body simulation via a quadtree or an octree.
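
Depending on your scikit-learn version, such an approximation may already be available directly through the estimator; a hedged sketch, with parameter names as in recent releases:

<pre data-code-language="python"
     data-executable="true"
     data-type="programlisting">
# Assumption: your scikit-learn version exposes the Barnes-Hut approximation
# through the 'method' parameter; 'angle' trades accuracy for speed.
digits_proj = TSNE(method='barnes_hut', angle=0.5).fit_transform(digits.data)
</pre>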