Skip to content

Commit

Permalink
Added some example files
Browse files Browse the repository at this point in the history
  • Loading branch information
DavidMChan committed May 11, 2020
1 parent 43c9a93 commit 97af49e
Show file tree
Hide file tree
Showing 3 changed files with 63 additions and 2 deletions.
31 changes: 31 additions & 0 deletions examples/cifar.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
from keras.datasets import cifar10
import time
import numpy as np
from tsnecuda import TSNE
import matplotlib.pyplot as plt

(x_train, y_train), (x_test, y_test) = cifar10.load_data()

print(y_train.shape)
print(x_train.shape)

time_start = time.time()
tsne = TSNE(n_iter=5000, verbose=1, perplexity=10000, num_neighbors=128)
tsne_results = tsne.fit_transform(x_train.reshape(x_train.shape[0],np.prod(x_train.shape[1:])))


print(tsne_results.shape)

# Create the figure
fig = plt.figure( figsize=(8,8) )
ax = fig.add_subplot(1, 1, 1, title='TSNE' )

# Create the scatter
ax.scatter(
x=tsne_results[:,0],
y=tsne_results[:,1],
c=y_train,
cmap=plt.cm.get_cmap('Paired'),
alpha=0.4,
s=0.5)
plt.show()
30 changes: 30 additions & 0 deletions examples/mnist.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
from keras.datasets import mnist
import time
from tsnecuda import TSNE
import matplotlib.pyplot as plt

(x_train, y_train), (x_test, y_test) = mnist.load_data()

print(y_train.shape)
print(x_train.shape)

time_start = time.time()
tsne = TSNE(n_iter=700, verbose=1, num_neighbors=4)
tsne_results = tsne.fit_transform(x_train.reshape(60000,-1))


print(tsne_results.shape)

# Create the figure
fig = plt.figure( figsize=(8,8) )
ax = fig.add_subplot(1, 1, 1, title='TSNE' )

# Create the scatter
ax.scatter(
x=tsne_results[:,0],
y=tsne_results[:,1],
c=y_train,
cmap=plt.cm.get_cmap('Paired'),
alpha=0.4,
s=0.5)
plt.show()
4 changes: 2 additions & 2 deletions src/fit_tsne.cu
Original file line number Diff line number Diff line change
Expand Up @@ -143,8 +143,8 @@ void tsnecuda::RunTsne(tsnecuda::Options &opt,
cudaDeviceSynchronize();
knn_squared_distances_device.clear();
knn_squared_distances_device.shrink_to_fit();
knn_indices_long_device.clear();
knn_indices_long_device.shrink_to_fit();
// knn_indices_long_device.clear();
// knn_indices_long_device.shrink_to_fit();
delete[] knn_squared_distances;
delete[] knn_indices;

Expand Down

0 comments on commit 97af49e

Please sign in to comment.