Skip to content

Commit c8c60f8

Browse files
committed
Adding little functions to automatically add landmarks.
1 parent f123b91 commit c8c60f8

File tree

3 files changed

+1224
-64
lines changed

3 files changed

+1224
-64
lines changed

notebooks/MNIST_Landmarks.ipynb

Lines changed: 75 additions & 64 deletions
Large diffs are not rendered by default.

notebooks/SCRATCH_MNIST_Landmarks.ipynb

Lines changed: 1065 additions & 0 deletions
Large diffs are not rendered by default.

umap/parametric_umap.py

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,8 @@ def __init__(
114114
self.global_correlation_loss_weight = global_correlation_loss_weight
115115
self.landmark_loss_fn = landmark_loss_fn
116116
self.landmark_loss_weight = landmark_loss_weight
117+
self.prev_epoch_X = None
118+
self.window_vals = None
117119

118120
self.reconstruction_validation = (
119121
reconstruction_validation # holdout data for reconstruction acc
@@ -174,6 +176,16 @@ def fit(self, X, y=None, precomputed_distances=None, landmark_positions=None):
174176
The desired position in low-dimensional space of each sample in X.
175177
Points that are not landmarks should have nan coordinates.
176178
"""
179+
if (self.prev_epoch_X is not None)&(landmark_positions is None):
180+
# Add the landmark points for training, then make a landmark vector. NaN corresponds to no landmark information.
181+
landmark_positions = np.stack(
182+
[np.array([np.nan, np.nan])]*X.shape[0] + list(
183+
self.transform(
184+
self.prev_epoch_X
185+
)
186+
)
187+
)
188+
X = np.concatenate((X, self.prev_epoch_X))
177189

178190
if landmark_positions is not None:
179191
len_X = len(X)
@@ -230,6 +242,16 @@ def fit_transform(
230242
The desired position in low-dimensional space of each sample in X.
231243
Points that are not landmarks should have nan coordinates.
232244
"""
245+
if (self.prev_epoch_X is not None)&(landmark_positions is None):
246+
# Add the landmark points for training, then make a landmark vector. NaN corresponds to no landmark information.
247+
landmark_positions = np.stack(
248+
[np.array([np.nan, np.nan])]*X.shape[0] + list(
249+
self.transform(
250+
self.prev_epoch_X
251+
)
252+
)
253+
)
254+
X = np.concatenate((X, self.prev_epoch_X))
233255

234256
if landmark_positions is not None:
235257
len_X = len(X)
@@ -473,6 +495,68 @@ def save(self, save_location, verbose=True):
473495
if verbose:
474496
print("Pickle of ParametricUMAP model saved to {}".format(model_output))
475497

498+
def add_landmarks(self, X, sample_pct=0.01, sample_mode = "uniform", landmark_loss_weight = 0.01,curr_window_vals = 1.0, old_window_thresh = 0.0):
499+
"""Add some points from a dataset X as "landmarks" to be approximately preserved after retraining.
500+
501+
Parameters
502+
----------
503+
X : array, shape (n_samples, n_features)
504+
Old data to be retained.
505+
sample_pct : float, optional
506+
Percentage of old data to use as landmarks.
507+
sample_mode : str, optional
508+
Method for sampling points. Currently only "uniform" and "sliding_window" are supported.
509+
landmark_loss_weight : float, optional
510+
Multiplier for landmark loss function.
511+
curr_window_vals: array, shape (n_samples,) or float, optional
512+
In "sliding_window" mode, the window value to give to the current points.
513+
old_window_thresh: float, optional
514+
In "sliding_window" mode, points with values below this value are dropped.
515+
516+
"""
517+
self.sample_pct = sample_pct
518+
self.sample_mode = sample_mode
519+
self.landmark_loss_weight = landmark_loss_weight
520+
521+
if self.sample_mode == "uniform":
522+
self.prev_epoch_idx = list(np.random.choice(range(X.shape[0]), int(X.shape[0]*sample_pct), replace=False))
523+
self.prev_epoch_X = X[self.prev_epoch_idx]
524+
elif self.sample_mode == "sliding_window":
525+
if (self.window_vals is None)&(self.prev_epoch_X is not None):
526+
raise ValueError(
527+
"Use remove_landmarks to remove previous landmarks before adding sliding windows."
528+
)
529+
if type(curr_window_vals) is float:
530+
curr_window_vals = np.array([curr_window_vals]*X.shape[0])
531+
new_idx = list(np.random.choice(range(X.shape[0]), int(X.shape[0]*self.sample_pct), replace=False))
532+
new_X = X[new_idx]
533+
new_window_vals = curr_window_vals[new_idx]
534+
# update self.prev_epoch_idx, self.prev_epoch_X, self.window_vals by FIRST concatenating with the old values, THEN throwing away everything that fails the threshold.
535+
if self.window_vals is None:
536+
self.prev_epoch_idx = new_idx
537+
self.window_vals = new_window_vals
538+
self.prev_epoch_X = new_X
539+
else:
540+
print(self.prev_epoch_X.shape) # ZZX: Kill this before release.
541+
print(len(new_idx))
542+
self.prev_epoch_idx = self.prev_epoch_idx.extend(new_idx)
543+
self.window_vals = np.stack((self.window_vals,new_window_vals))
544+
self.prev_epoch_X = np.stack((self.prev_epoch_X,new_X))
545+
# Throw away indices if the window_vals are below old_window_thresh
546+
retained_inds = [x for x in range(len(self.window_vals)) if self.window_vals[x] >= old_window_thresh]
547+
self.prev_epoch_idx = list(np.array(self.prev_epoch_idx)[retained_inds])
548+
self.window_vals = self.window_vals[retained_inds]
549+
self.prev_epoch_X = self.prev_epoch_X[retained_inds]
550+
551+
else:
552+
raise ValueError(
553+
"Choice of sample_mode is not supported."
554+
)
555+
556+
557+
def remove_landmarks(self):
558+
self.prev_epoch_X = None
559+
476560
def to_ONNX(self, save_location):
477561
"""Exports trained parametric UMAP as ONNX."""
478562
# Extract encoder

0 commit comments

Comments
 (0)