Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 48 additions & 0 deletions src/tdamapper/_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
"""
import warnings

import numpy as np


def warn_deprecated(deprecated, substitute):
msg = f'{deprecated} is deprecated and will be removed in a future version. Use {substitute} instead.'
Expand All @@ -17,6 +19,52 @@ def warn_user(msg):
warnings.warn(msg, UserWarning, stacklevel=2)


class EstimatorMixin:

def _is_sparse(self, X):
# simple alternative use scipy.sparse.issparse
return hasattr(X, 'toarray')

def _validate_X_y(self, X, y):
if self._is_sparse(X):
raise ValueError('Sparse data not supported.')

X = np.asarray(X)
y = np.asarray(y)

if X.size == 0:
msg = f'0 feature(s) (shape={X.shape}) while a minimum of 1 is required.'
raise ValueError(msg)

if y.size == 0:
msg = f'0 feature(s) (shape={y.shape}) while a minimum of 1 is required.'
raise ValueError(msg)

if X.ndim == 1:
raise ValueError('1d-arrays not supported.')

if np.iscomplexobj(X) or np.iscomplexobj(y):
raise ValueError('Complex data not supported.')

if X.dtype == np.object_:
X = np.array(X, dtype=float)

if y.dtype == np.object_:
y = np.array(y, dtype=float)

if np.isnan(X).any() or np.isinf(X).any() or \
np.isnan(y).any() or np.isinf(y).any():
raise ValueError('NaNs or infinite values not supported.')

return X, y

def fit(self, X, y=None):
X, y = self._validate_X_y(X, y)
res = super().fit(X, y)
self.n_features_in_ = X.shape[1]
return res


class ParamsMixin:
"""
Mixin to add setters and getters for public parameters, compatible with
Expand Down
59 changes: 35 additions & 24 deletions src/tdamapper/clustering.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,12 @@

from tdamapper.core import mapper_connected_components, TrivialCover
import tdamapper.core
from tdamapper._common import ParamsMixin, clone, warn_deprecated
from tdamapper._common import (
ParamsMixin,
EstimatorMixin,
clone,
warn_deprecated,
)


class TrivialClustering(tdamapper.core.TrivialClustering):
Expand Down Expand Up @@ -33,7 +38,34 @@ def __init__(self, clustering=None, verbose=True):
super().__init__(clustering, verbose)


class MapperClustering(ParamsMixin):
class _MapperClustering:

def __init__(self, cover=None, clustering=None, n_jobs=1):
self.cover = cover
self.clustering = clustering
self.n_jobs = n_jobs

def fit(self, X, y=None):
cover = TrivialCover() if self.cover is None \
else self.cover
cover = clone(cover)
clustering = TrivialClustering() if self.clustering is None \
else self.clustering
clustering = clone(clustering)
n_jobs = self.n_jobs
y = X if y is None else y
itm_lbls = mapper_connected_components(
X,
y,
cover,
clustering,
n_jobs=n_jobs,
)
self.labels_ = [itm_lbls[i] for i, _ in enumerate(X)]
return self


class MapperClustering(EstimatorMixin, _MapperClustering, ParamsMixin):
"""
A clustering algorithm based on the Mapper graph.

Expand All @@ -60,25 +92,4 @@ class MapperClustering(ParamsMixin):
"""

def __init__(self, cover=None, clustering=None, n_jobs=1):
self.cover = cover
self.clustering = clustering
self.n_jobs = n_jobs

def fit(self, X, y=None):
cover = TrivialCover() if self.cover is None \
else self.cover
cover = clone(cover)
clustering = TrivialClustering() if self.clustering is None \
else self.clustering
clustering = clone(clustering)
n_jobs = self.n_jobs
y = X if y is None else y
itm_lbls = mapper_connected_components(
X,
y,
cover,
clustering,
n_jobs=n_jobs,
)
self.labels_ = [itm_lbls[i] for i, _ in enumerate(X)]
return self
super().__init__(cover=cover, clustering=clustering, n_jobs=n_jobs)
89 changes: 57 additions & 32 deletions src/tdamapper/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
from joblib import Parallel, delayed

from tdamapper.utils.unionfind import UnionFind
from tdamapper._common import ParamsMixin, clone
from tdamapper._common import ParamsMixin, EstimatorMixin, clone


ATTR_IDS = 'ids'
Expand Down Expand Up @@ -364,7 +364,53 @@ def apply(self, X):
yield list(range(0, len(X)))


class MapperAlgorithm(ParamsMixin):
class _MapperAlgorithm:

def __init__(
self,
cover=None,
clustering=None,
failsafe=True,
verbose=True,
n_jobs=1,
):
self.cover = cover
self.clustering = clustering
self.failsafe = failsafe
self.verbose = verbose
self.n_jobs = n_jobs

def fit(self, X, y=None):
self.__cover = TrivialCover() if self.cover is None \
else self.cover
self.__clustering = TrivialClustering() if self.clustering is None \
else self.clustering
self.__verbose = self.verbose
self.__failsafe = self.failsafe
if self.__failsafe:
self.__clustering = FailSafeClustering(
clustering=self.__clustering,
verbose=self.__verbose,
)
self.__cover = clone(self.__cover)
self.__clustering = clone(self.__clustering)
self.__n_jobs = self.n_jobs
y = X if y is None else y
self.graph_ = mapper_graph(
X,
y,
self.__cover,
self.__clustering,
n_jobs=self.__n_jobs,
)
return self

def fit_transform(self, X, y):
self.fit(X, y)
return self.graph_


class MapperAlgorithm(EstimatorMixin, _MapperAlgorithm, ParamsMixin):
"""
A class for creating and analyzing Mapper graphs.

Expand Down Expand Up @@ -412,11 +458,13 @@ def __init__(
verbose=True,
n_jobs=1,
):
self.cover = cover
self.clustering = clustering
self.failsafe = failsafe
self.verbose = verbose
self.n_jobs = n_jobs
super().__init__(
cover=cover,
clustering=clustering,
failsafe=failsafe,
verbose=verbose,
n_jobs=n_jobs,
)

def fit(self, X, y=None):
"""
Expand All @@ -431,29 +479,7 @@ def fit(self, X, y=None):
:type y: array-like of shape (n, k) or list-like of length n
:return: The object itself.
"""
self.__cover = TrivialCover() if self.cover is None \
else self.cover
self.__clustering = TrivialClustering() if self.clustering is None \
else self.clustering
self.__verbose = self.verbose
self.__failsafe = self.failsafe
if self.__failsafe:
self.__clustering = FailSafeClustering(
clustering=self.__clustering,
verbose=self.__verbose,
)
self.__cover = clone(self.__cover)
self.__clustering = clone(self.__clustering)
self.__n_jobs = self.n_jobs
y = X if y is None else y
self.graph_ = mapper_graph(
X,
y,
self.__cover,
self.__clustering,
n_jobs=self.__n_jobs,
)
return self
return super().fit(X, y)

def fit_transform(self, X, y):
"""
Expand All @@ -469,8 +495,7 @@ def fit_transform(self, X, y):
:return: The Mapper graph.
:rtype: :class:`networkx.Graph`
"""
self.fit(X, y)
return self.graph_
return super().fit_transform(X, y)


class FailSafeClustering(ParamsMixin):
Expand Down
64 changes: 6 additions & 58 deletions tests/test_unit_sklearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@

import numpy as np


from sklearn.utils.estimator_checks import check_estimator

from tdamapper.core import MapperAlgorithm
Expand All @@ -21,57 +20,6 @@ def euclidean(x, y):
return np.linalg.norm(x - y)


class ValidationMixin:

def _is_sparse(self, X):
# in alternative use scipy.sparse.issparse
return hasattr(X, 'toarray')

def _validate_X_y(self, X, y):
if self._is_sparse(X):
raise ValueError('Sparse data not supported.')

if X.size == 0:
msg = f'0 feature(s) (shape={X.shape}) while a minimum of 1 is required.'
raise ValueError(msg)

if y.size == 0:
msg = f'0 feature(s) (shape={y.shape}) while a minimum of 1 is required.'
raise ValueError(msg)

if X.ndim == 1:
raise ValueError('1d-arrays not supported.')

if np.iscomplexobj(X) or np.iscomplexobj(y):
raise ValueError('Complex data not supported.')

if X.dtype == np.object_:
X = np.array(X, dtype=float)

if y.dtype == np.object_:
y = np.array(y, dtype=float)

if np.isnan(X).any() or np.isinf(X).any() or \
np.isnan(y).any() or np.isinf(y).any():
raise ValueError('NaNs or infinite values not supported.')

return X, y

def fit(self, X, y=None):
X, y = self._validate_X_y(X, y)
res = super().fit(X, y)
self.n_features_in_ = X.shape[1]
return res


class MapperEstimator(ValidationMixin, MapperAlgorithm):
pass


class MapperClusteringEstimator(ValidationMixin, MapperClustering):
pass


class TestSklearn(unittest.TestCase):

setup_logging()
Expand All @@ -83,25 +31,25 @@ def run_tests(self, estimator):
check(est)

def test_trivial(self):
est = MapperEstimator()
est = MapperAlgorithm()
self.run_tests(est)

def test_ball(self):
est = MapperEstimator(cover=BallCover(metric=euclidean))
est = MapperAlgorithm(cover=BallCover(metric=euclidean))
self.run_tests(est)

def test_knn(self):
est = MapperEstimator(cover=KNNCover(metric=euclidean))
est = MapperAlgorithm(cover=KNNCover(metric=euclidean))
self.run_tests(est)

def test_cubical(self):
est = MapperEstimator(cover=CubicalCover())
est = MapperAlgorithm(cover=CubicalCover())
self.run_tests(est)

def test_clustering_trivial(self):
est = MapperClusteringEstimator()
est = MapperClustering()
self.run_tests(est)

def test_clustering_ball(self):
est = MapperClusteringEstimator(cover=BallCover(metric=euclidean))
est = MapperClustering(cover=BallCover(metric=euclidean))
self.run_tests(est)