Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
5ab86e6
Add NeighborFinder classes
candytaco Jan 26, 2026
3604c8b
factor out KNN exclusion radius checking to property
candytaco Jan 26, 2026
e7c6b10
update references to property
candytaco Jan 26, 2026
e5b7b66
factor out lib indices validation from FindNeighbors
candytaco Jan 26, 2026
6f8cc9a
factor actual number of neighbors to query for (knn_) out to property
candytaco Jan 26, 2026
a757306
factor out indexKNN index remapping to class method
candytaco Jan 26, 2026
5e04a06
add option to EDM classes to switch between kDTree and pairwise dista…
candytaco Jan 26, 2026
b4a0975
update CCM to make use of pairwise distance matrices in iteration
candytaco Jan 26, 2026
35c720b
update simplex indexing
candytaco Jan 26, 2026
2761149
add neighbor algorithm to API functions
candytaco Jan 26, 2026
c45c213
fix typos
candytaco Jan 26, 2026
5ef5011
fix typos, and pdist Simplex initialize distance matrix before CCM
candytaco Jan 26, 2026
260e572
fix indexing function argument order
candytaco Jan 26, 2026
17ed9c2
fix typo in CCM pdist implementation
candytaco Jan 26, 2026
87ce263
move FindNeighbors into EDM class file
candytaco Jan 26, 2026
28bb2b5
use single np.allclose with atol to check array equality instead of c…
candytaco Jan 26, 2026
9dc3ab8
update FindNeighbors docstring to not specifically reference KDTree
candytaco Jan 26, 2026
ac56ac9
Merge remote-tracking branch 'origin/main' into main
candytaco Jan 26, 2026
01bfd42
break apart KNN index-to lib index remapping function to separate ind…
candytaco Jan 28, 2026
8e96d56
Add Pdist neighbor finder option to take an exclusion mask
candytaco Jan 28, 2026
aecf7ab
add EDM method to build an exclusion mask for neighbors to ignore
candytaco Jan 28, 2026
06e54d9
fix a thing where degenerate neighbors using np.delete can potentiall…
candytaco Jan 28, 2026
1fb6b71
EDM class to give exclusion mask to pdist neighbor finder
candytaco Jan 28, 2026
7c17cd1
CCM to make use of embedded exclusion matrix
candytaco Jan 28, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 13 additions & 6 deletions src/pyEDM/API.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,8 @@ def Simplex( dataFrame = None,
verbose = False,
showPlot = False,
ignoreNan = True,
returnObject = False ):
returnObject = False,
neighbor_algorithm = 'kdtree'):
Comment on lines 107 to +110
Copy link

Copilot AI Jan 27, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A new public neighbor_algorithm option is added for Simplex/SMap, but the test suite does not appear to exercise neighbor_algorithm='pdist' for these APIs. Add a unit/integration test that runs EDM.Simplex(..., neighbor_algorithm='pdist') and EDM.SMap(..., neighbor_algorithm='pdist') on a small non-degenerate dataset (no tied neighbors) to validate the new code path and prevent regressions.

Copilot uses AI. Check for mistakes.
'''Simplex prediction of dataFrame target from columns.'''

# Instantiate SimplexClass object
Expand All @@ -128,7 +129,8 @@ def Simplex( dataFrame = None,
generateSteps = generateSteps,
generateConcat = generateConcat,
ignoreNan = ignoreNan,
verbose = verbose )
verbose = verbose,
neighbor_algorithm = neighbor_algorithm)

if generateSteps :
S.Generate()
Expand Down Expand Up @@ -166,7 +168,8 @@ def SMap( dataFrame = None,
ignoreNan = True,
showPlot = False,
verbose = False,
returnObject = False ):
returnObject = False,
neighbor_algorithm = 'kdtree'):
'''S-Map prediction of dataFrame target from columns.'''

# Validate solver if one was provided
Expand Down Expand Up @@ -206,7 +209,8 @@ def SMap( dataFrame = None,
generateSteps = generateSteps,
generateConcat = generateConcat,
ignoreNan = ignoreNan,
verbose = verbose )
verbose = verbose,
neighbor_algorithm = neighbor_algorithm)

if generateSteps :
S.Generate()
Expand Down Expand Up @@ -248,7 +252,9 @@ def CCM( dataFrame = None,
sequential = False,
verbose = False,
showPlot = False,
returnObject = False ) :
returnObject = False,
neighbor_algorithm = 'pdist'
) :
'''Convergent Cross Mapping.'''

# Instantiate CCMClass object
Expand All @@ -271,7 +277,8 @@ def CCM( dataFrame = None,
ignoreNan = ignoreNan,
mpMethod = mpMethod,
sequential = sequential,
verbose = verbose )
verbose = verbose,
neighbor_algorithm = neighbor_algorithm)

# Embedding of Forward & Reverse mapping
C.FwdMap.EmbedData()
Expand Down
49 changes: 33 additions & 16 deletions src/pyEDM/CCM.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@
from pandas import DataFrame, concat
from numpy import array, exp, fmax, divide, mean, nan, roll, sum, zeros
from numpy.random import default_rng
import numpy as np

from .NeighborFinder import PairwiseDistanceNeighborFinder
# local modules
from .Simplex import Simplex as SimplexClass
from .AuxFunc import ComputeError, IsIterable
Expand Down Expand Up @@ -34,7 +36,8 @@ def __init__( self,
ignoreNan = True,
mpMethod = None,
sequential = False,
verbose = False ):
verbose = False,
neighbor_algorithm = 'pdist'):
'''Initialize CCM.'''

# Assign parameters from API arguments
Expand All @@ -58,6 +61,7 @@ def __init__( self,
self.mpMethod = mpMethod
self.sequential = sequential
self.verbose = verbose
self.neighbor_algorithm = neighbor_algorithm

# Set full lib & pred
self.lib = self.pred = [ 1, self.Data.shape[0] ]
Expand Down Expand Up @@ -88,7 +92,8 @@ def __init__( self,
validLib = validLib,
noTime = noTime,
ignoreNan = ignoreNan,
verbose = verbose )
verbose = verbose,
neighbor_algorithm = neighbor_algorithm)

self.RevMap = SimplexClass( dataFrame = dataFrame,
columns = target,
Expand All @@ -104,7 +109,8 @@ def __init__( self,
validLib = validLib,
noTime = noTime,
ignoreNan = ignoreNan,
verbose = verbose )
verbose = verbose,
neighbor_algorithm = neighbor_algorithm)

#-------------------------------------------------------------------
# Methods
Expand Down Expand Up @@ -193,6 +199,9 @@ def CrossMap( self, direction ) :
libRhoMap = {} # Output dict libSize key : mean rho value
libStatMap = {} # Output dict libSize key : list of ComputeError dicts

if self.neighbor_algorithm == 'pdist':
S.FindNeighbors() # need to initialize the pairwise distance matrix

# Loop for library sizes
for libSize in self.libSizes :
rhos = zeros( self.sample )
Expand All @@ -202,31 +211,39 @@ def CrossMap( self, direction ) :
# Loop for subsamples
for s in range( self.sample ) :
# Generate library row indices for this subsample
rng_i = RNG.choice( lib_i, size = min( libSize, N_lib_i ),
replace = False )

S.lib_i = rng_i

S.FindNeighbors() # Depends on S.lib_i
if self.neighbor_algorithm == 'kdtree':
rng_i = RNG.choice( lib_i, size = min( libSize, N_lib_i ),
replace = False )
S.lib_i = rng_i
S.FindNeighbors() # Depends on S.lib_i
neighbor_distances = S.knn_distances
neighbor_indices = S.knn_neighbors
else:
rng_i = RNG.choice(np.arange(S.neighbor_finder.distanceMatrix.shape[0]),
size = min(libSize, N_lib_i),
replace = False)
d = S.neighbor_finder.distanceMatrix.copy()
mask = np.ones(d.shape[0], dtype = bool)
mask[rng_i] = False
d[mask, :] = np.inf # artificially make all the other ones far awa
Copy link

Copilot AI Jan 27, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Typo in comment: "far awa" should be "far away".

Suggested change
d[mask, :] = np.inf # artificially make all the other ones far awa
d[mask, :] = np.inf # artificially make all the other ones far away

Copilot uses AI. Check for mistakes.
neighbor_distances, raw_indices = PairwiseDistanceNeighborFinder.find_neighbors(d, S.knn)
neighbor_indices = S.map_knn_indices_to_library_indices(raw_indices)

# Code from Simplex:Project ---------------------------------
# First column is minimum distance of all N pred rows
minDistances = S.knn_distances[:,0]
minDistances = neighbor_distances[:,0]
# In case there is 0 in minDistances: minWeight = 1E-6
minDistances = fmax( minDistances, 1E-6 )

# Divide each column of N x k knn_distances by minDistances
scaledDistances = divide(S.knn_distances, minDistances[:,None])
scaledDistances = divide(neighbor_distances, minDistances[:,None])
weights = exp( -scaledDistances ) # Npred x k
weightRowSum = sum( weights, axis = 1 ) # Npred x 1

# Matrix of knn_neighbors + Tp defines library target values
knn_neighbors_Tp = S.knn_neighbors + self.Tp # Npred x k
knn_neighbors_Tp = neighbor_indices + self.Tp # Npred x k

libTargetValues = zeros( knn_neighbors_Tp.shape ) # Npred x k
for j in range( knn_neighbors_Tp.shape[1] ) :
libTargetValues[ :, j ][ :, None ] = \
S.targetVec[ knn_neighbors_Tp[ :, j ] ]
libTargetValues = S.targetVec[knn_neighbors_Tp].squeeze()
# Code from Simplex:Project ----------------------------------

# Projection is average of weighted knn library target values
Expand Down
Loading