Skip to content

Commit

Permalink
added some bug fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
AmitaiYacobi committed Jun 29, 2023
1 parent 779834a commit a7a7671
Show file tree
Hide file tree
Showing 4 changed files with 9 additions and 25 deletions.
24 changes: 4 additions & 20 deletions examples/cluster_mnist.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,33 +36,17 @@ def main():
y_train = torch.cat([y_train, y_test])

spectralnet = SpectralNet(
n_clusters=2,
should_use_ae=False,
should_use_siamese=False,
spectral_batch_size=712,
spectral_epochs=40,
spectral_is_local_scale=False,
spectral_n_nbg=8,
spectral_scale_k=2,
spectral_lr=1e-2,
spectral_hiddens=[128, 128, 2],
n_clusters=10,
should_use_ae=True,
should_use_siamese=True,
)

# spectralnet = SpectralNet(
# n_clusters=10,
# should_use_ae=True,
# should_use_siamese=True,
# ae_epochs=2,
# siamese_epochs=2,
# spectral_epochs=2,
# )
spectralnet.fit(x_train)
cluster_assignments = spectralnet.predict(x_train)
embeddings = spectralnet.embeddings_

if y_train is not None:
y = y_train.detach().cpu().numpy()
acc_score = Metrics.acc_score(cluster_assignments, y, n_clusters=2)
acc_score = Metrics.acc_score(cluster_assignments, y, n_clusters=10)
nmi_score = Metrics.nmi_score(cluster_assignments, y)
print(f"ACC: {np.round(acc_score, 3)}")
print(f"NMI: {np.round(nmi_score, 3)}")
Expand Down
4 changes: 2 additions & 2 deletions examples/cluster_twomoons.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@

from data import load_data

from ..src.spectralnet import Metrics
from ..src.spectralnet import SpectralNet
from spectralnet import Metrics
from spectralnet import SpectralNet


class InvalidMatrixException(Exception):
Expand Down
6 changes: 3 additions & 3 deletions setup.cfg
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
[metadata]
name = spectralnet
version = 0.0.4
version = 0.0.5
author = Amitai
description = Spectral Clustering Using Deep Neural Networks
long_description = file: README.md
long_description_content_type = text/markdown
url = https://github.com/AmitaiYacobi/SpectralNet.git
url = https://github.com/shaham-lab/SpectralNet.git
project_urls =
Bug Tracker = https://github.com/AmitaiYacobi/SpectralNet/issues
Bug Tracker = https://github.com/shaham-lab/SpectralNet/issues
classifiers =
Programming Language :: Python :: 3
License :: OSI Approved :: MIT License
Expand Down
Empty file added src/tests/__init__.py
Empty file.

0 comments on commit a7a7671

Please sign in to comment.