Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 13 additions & 19 deletions squeeze/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,31 +22,24 @@

# Import Rust-based algorithms
try:
from ._hnsw_backend import (
PCA,
TSNE,
MDS,
Isomap,
LLE,
PHATE,
TriMap,
PaCMAP,
)
from . import _hnsw_backend as _rust_backend # type: ignore[import-not-found]
except ImportError as e:
warn(
f"Rust backend not available: {e}. Some algorithms may not be available.",
stacklevel=2,
category=ImportWarning,
)
# Create dummy classes
PCA = None
TSNE = None
MDS = None
Isomap = None
LLE = None
PHATE = None
TriMap = None
PaCMAP = None
_rust_backend = None

RustUMAP = getattr(_rust_backend, "UMAP", None)
PCA = getattr(_rust_backend, "PCA", None)
TSNE = getattr(_rust_backend, "TSNE", None)
MDS = getattr(_rust_backend, "MDS", None)
Isomap = getattr(_rust_backend, "Isomap", None)
LLE = getattr(_rust_backend, "LLE", None)
PHATE = getattr(_rust_backend, "PHATE", None)
TriMap = getattr(_rust_backend, "TriMap", None)
PaCMAP = getattr(_rust_backend, "PaCMAP", None)

try:
with catch_warnings():
Expand Down Expand Up @@ -111,6 +104,7 @@ def __init__(self, **_kwds: object) -> None:
__all__ = [
# Core UMAP
"UMAP",
"RustUMAP",
"AlignedUMAP",
"ParametricUMAP",
# Rust-based DR algorithms
Expand Down
43 changes: 24 additions & 19 deletions squeeze/strategies.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,28 +13,21 @@
from .umap_ import UMAP

try:
from ._hnsw_backend import (
PCA,
TSNE,
MDS,
Isomap,
LLE,
PHATE,
TriMap,
PaCMAP,
)

from . import _hnsw_backend as _rust_backend # type: ignore[import-not-found]
RUST_BACKEND_AVAILABLE = True
except ImportError:
RUST_BACKEND_AVAILABLE = False
PCA = None
TSNE = None
MDS = None
Isomap = None
LLE = None
PHATE = None
TriMap = None
PaCMAP = None
_rust_backend = None

RustUMAP = getattr(_rust_backend, "UMAP", None)
PCA = getattr(_rust_backend, "PCA", None)
TSNE = getattr(_rust_backend, "TSNE", None)
MDS = getattr(_rust_backend, "MDS", None)
Isomap = getattr(_rust_backend, "Isomap", None)
LLE = getattr(_rust_backend, "LLE", None)
PHATE = getattr(_rust_backend, "PHATE", None)
TriMap = getattr(_rust_backend, "TriMap", None)
PaCMAP = getattr(_rust_backend, "PaCMAP", None)


@dataclass
Expand Down Expand Up @@ -83,6 +76,18 @@ def _register_defaults(self):
if not RUST_BACKEND_AVAILABLE:
return

# Rust UMAP (minimal core implementation)
if RustUMAP is not None:
self.register(
Strategy(
name="umap_rust",
algorithm_class=RustUMAP,
default_params={"n_components": 2, "n_neighbors": 15},
description="UMAP (Rust backend; dense + exact kNN)",
category="nonlinear",
)
)

# PCA
self.register(
Strategy(
Expand Down
21 changes: 21 additions & 0 deletions squeeze/tests/test_rust_umap.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import numpy as np
import pytest


def test_rust_umap_fit_transform_shape() -> None:
from squeeze import RustUMAP

if RustUMAP is None:
pytest.skip("Rust backend not available")

rng = np.random.default_rng(42)
X = rng.normal(size=(80, 12)).astype(np.float64)

reducer = RustUMAP(n_components=2, n_neighbors=10, n_epochs=25, random_state=42)
emb = reducer.fit_transform(X)

assert emb.shape == (80, 2)
assert np.isfinite(emb).all()
# embedding is centered by the backend
assert np.allclose(emb.mean(axis=0), 0.0, atol=1e-3)

2 changes: 2 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ pub mod lle;
pub mod phate;
pub mod trimap;
pub mod pacmap;
pub mod umap;

#[cfg(not(test))]
#[pyo3::pymodule]
Expand All @@ -35,6 +36,7 @@ fn _hnsw_backend(_py: pyo3::Python, m: &pyo3::Bound<'_, pyo3::types::PyModule>)
m.add_class::<phate::PHATE>()?;
m.add_class::<trimap::TriMap>()?;
m.add_class::<pacmap::PaCMAP>()?;
m.add_class::<umap::UMAP>()?;

Ok(())
}
Loading
Loading