Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions squeeze/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@
PHATE = getattr(_rust_backend, "PHATE", None)
TriMap = getattr(_rust_backend, "TriMap", None)
PaCMAP = getattr(_rust_backend, "PaCMAP", None)
PLSCAN = getattr(_rust_backend, "PLSCAN", None)
PLSCANBackbone = getattr(_rust_backend, "PLSCANBackbone", None)

try:
with catch_warnings():
Expand Down Expand Up @@ -116,6 +118,8 @@ def __init__(self, **_kwds: object) -> None:
"PHATE",
"TriMap",
"PaCMAP",
"PLSCAN",
"PLSCANBackbone",
# Composition utilities
"AdaptiveDR",
"DRPipeline",
Expand Down
22 changes: 22 additions & 0 deletions squeeze/strategies.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
PHATE = getattr(_rust_backend, "PHATE", None)
TriMap = getattr(_rust_backend, "TriMap", None)
PaCMAP = getattr(_rust_backend, "PaCMAP", None)
PLSCANBackbone = getattr(_rust_backend, "PLSCANBackbone", None)


@dataclass
Expand Down Expand Up @@ -194,6 +195,27 @@ def _register_defaults(self):
)
)

# PLSCAN backbone (landmarks + scale selection + soft must-link)
if PLSCANBackbone is not None:
self.register(
Strategy(
name="plscan_backbone",
algorithm_class=PLSCANBackbone,
default_params={
"n_components": 2,
"min_samples": 5,
"rep_strategy": "high_prob",
"reps_per_cluster": 1,
"neighbor_scale": 1.0,
"must_link_weight": 0.1,
"interpolation_k": 3,
"restrict_to_cluster": True,
},
description="PLSCAN-backed landmark spectral embedding with soft must-link regularization",
category="hybrid",
)
)

def register(self, strategy: Strategy) -> None:
"""Register a new strategy."""
self._strategies[strategy.name.lower()] = strategy
Expand Down
39 changes: 39 additions & 0 deletions squeeze/tests/test_plscan_backbone.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
import numpy as np

from squeeze import PLSCAN, PLSCANBackbone, get_strategy


def test_plscan_cluster_shapes(iris) -> None:
x = iris.data.astype(np.float64)
clusterer = PLSCAN(min_samples=5)
labels = clusterer.fit_predict(x)

assert labels.shape == (x.shape[0],)
assert clusterer.probabilities_.shape == (x.shape[0],)
assert np.all((clusterer.probabilities_ >= 0.0) & (clusterer.probabilities_ <= 1.0))
assert clusterer.trace_min_size_.shape == clusterer.trace_persistence_.shape


def test_plscan_backbone_fit_transform_shapes(iris) -> None:
x = iris.data.astype(np.float64)
reducer = PLSCANBackbone(
n_components=2,
min_samples=5,
rep_strategy="high_prob",
reps_per_cluster=1,
must_link_weight=0.1,
interpolation_k=3,
restrict_to_cluster=True,
)
embedding = reducer.fit_transform(x)

assert embedding.shape == (x.shape[0], 2)
assert reducer.labels_.shape == (x.shape[0],)
assert reducer.probabilities_.shape == (x.shape[0],)
assert reducer.rep_indices_.ndim == 1
assert reducer.trace_min_size_.shape == reducer.trace_persistence_.shape


def test_strategy_registry_has_plscan_backbone() -> None:
strategy = get_strategy("plscan_backbone")
assert strategy.name == "plscan_backbone"
4 changes: 4 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ pub mod phate;
pub mod trimap;
pub mod pacmap;
pub mod umap;
pub mod plscan;
pub mod plscan_backbone;

#[cfg(not(test))]
#[pyo3::pymodule]
Expand All @@ -37,6 +39,8 @@ fn _hnsw_backend(_py: pyo3::Python, m: &pyo3::Bound<'_, pyo3::types::PyModule>)
m.add_class::<trimap::TriMap>()?;
m.add_class::<pacmap::PaCMAP>()?;
m.add_class::<umap::UMAP>()?;
m.add_class::<plscan::PLSCAN>()?;
m.add_class::<plscan_backbone::PLSCANBackbone>()?;

Ok(())
}
Loading