@@ -114,6 +114,8 @@ def __init__(
114114 self .global_correlation_loss_weight = global_correlation_loss_weight
115115 self .landmark_loss_fn = landmark_loss_fn
116116 self .landmark_loss_weight = landmark_loss_weight
117+ self .prev_epoch_X = None
118+ self .window_vals = None
117119
118120 self .reconstruction_validation = (
119121 reconstruction_validation # holdout data for reconstruction acc
@@ -174,6 +176,16 @@ def fit(self, X, y=None, precomputed_distances=None, landmark_positions=None):
174176 The desired position in low-dimensional space of each sample in X.
175177 Points that are not landmarks should have nan coordinates.
176178 """
179+ if (self .prev_epoch_X is not None )& (landmark_positions is None ):
180+ # Add the landmark points for training, then make a landmark vector. NaN corresponds to no landmark information.
181+ landmark_positions = np .stack (
182+ [np .array ([np .nan , np .nan ])]* X .shape [0 ] + list (
183+ self .transform (
184+ self .prev_epoch_X
185+ )
186+ )
187+ )
188+ X = np .concatenate ((X , self .prev_epoch_X ))
177189
178190 if landmark_positions is not None :
179191 len_X = len (X )
@@ -230,6 +242,16 @@ def fit_transform(
230242 The desired position in low-dimensional space of each sample in X.
231243 Points that are not landmarks should have nan coordinates.
232244 """
245+ if (self .prev_epoch_X is not None )& (landmark_positions is None ):
246+ # Add the landmark points for training, then make a landmark vector. NaN corresponds to no landmark information.
247+ landmark_positions = np .stack (
248+ [np .array ([np .nan , np .nan ])]* X .shape [0 ] + list (
249+ self .transform (
250+ self .prev_epoch_X
251+ )
252+ )
253+ )
254+ X = np .concatenate ((X , self .prev_epoch_X ))
233255
234256 if landmark_positions is not None :
235257 len_X = len (X )
@@ -473,6 +495,68 @@ def save(self, save_location, verbose=True):
473495 if verbose :
474496 print ("Pickle of ParametricUMAP model saved to {}" .format (model_output ))
475497
def add_landmarks(
    self,
    X,
    sample_pct=0.01,
    sample_mode="uniform",
    landmark_loss_weight=0.01,
    curr_window_vals=1.0,
    old_window_thresh=0.0,
):
    """Add some points from a dataset X as "landmarks" to be approximately preserved after retraining.

    A random subset of the rows of ``X`` is stored on the model as
    ``prev_epoch_X`` (with the sampled row indices in ``prev_epoch_idx``).

    Parameters
    ----------
    X : array, shape (n_samples, n_features)
        Old data to be retained.
    sample_pct : float, optional
        Fraction of the rows of ``X`` to sample as landmarks
        (``int(n_samples * sample_pct)`` rows are drawn without replacement).
    sample_mode : str, optional
        Method for sampling points. Currently only "uniform" and
        "sliding_window" are supported. "uniform" replaces any stored
        landmarks; "sliding_window" appends to the stored landmarks and
        then drops those whose window value is below ``old_window_thresh``.
    landmark_loss_weight : float, optional
        Multiplier for landmark loss function.
    curr_window_vals : array, shape (n_samples,) or scalar, optional
        In "sliding_window" mode, the window value to give to the current
        points. A scalar is broadcast to every row of ``X``.
    old_window_thresh : float, optional
        In "sliding_window" mode, points with window values below this
        value are dropped.

    Raises
    ------
    ValueError
        If ``sample_mode`` is not supported, if "sliding_window" is
        requested while landmarks without window values are stored, or if
        ``curr_window_vals`` is an array whose shape is not ``(n_samples,)``.
    """
    if sample_mode not in ("uniform", "sliding_window"):
        raise ValueError("Choice of sample_mode is not supported.")

    n_samples = X.shape[0]
    sliding = sample_mode == "sliding_window"

    if sliding:
        # Landmarks stored without window values (e.g. from "uniform" mode)
        # cannot be thresholded, so they must be cleared explicitly first.
        if self.window_vals is None and self.prev_epoch_X is not None:
            raise ValueError(
                "Use remove_landmarks to remove previous landmarks before adding sliding windows."
            )
        curr_window_vals = np.asarray(curr_window_vals)
        if curr_window_vals.ndim == 0:
            # Accept any scalar (float, int, NumPy scalar), not just float.
            curr_window_vals = np.full(n_samples, float(curr_window_vals))
        elif curr_window_vals.shape != (n_samples,):
            raise ValueError(
                "curr_window_vals must be a scalar or have shape (n_samples,)."
            )

    self.sample_pct = sample_pct
    self.sample_mode = sample_mode
    self.landmark_loss_weight = landmark_loss_weight

    new_idx = list(
        np.random.choice(n_samples, int(n_samples * sample_pct), replace=False)
    )
    new_X = X[new_idx]

    if not sliding:
        self.prev_epoch_idx = new_idx
        self.prev_epoch_X = new_X
        # Uniform landmarks carry no window values; clear stale ones so a
        # later "sliding_window" call hits the explicit guard above.
        self.window_vals = None
        return

    new_window_vals = curr_window_vals[new_idx]

    # FIRST concatenate with the old values, THEN throw away everything
    # that fails the threshold.
    if self.window_vals is None or self.prev_epoch_X is None:
        merged_idx = new_idx
        merged_vals = new_window_vals
        merged_X = new_X
    else:
        # list.extend returns None, so build the merged list explicitly.
        merged_idx = list(self.prev_epoch_idx) + new_idx
        # Concatenate along the sample axis; np.stack would add a new axis
        # and fail when the old and new landmark counts differ.
        merged_vals = np.concatenate((self.window_vals, new_window_vals))
        merged_X = np.concatenate((self.prev_epoch_X, new_X))

    keep = merged_vals >= old_window_thresh
    self.prev_epoch_idx = list(np.asarray(merged_idx, dtype=np.intp)[keep])
    self.window_vals = merged_vals[keep]
    self.prev_epoch_X = merged_X[keep]
555+
556+
def remove_landmarks(self):
    """Forget all stored landmark points.

    Clears the sampled landmark data together with the sliding-window
    bookkeeping, so that a later ``add_landmarks`` call (in any sample
    mode) starts from an empty state instead of trying to merge new
    points with landmark data that no longer exists.
    """
    self.prev_epoch_X = None
    # window_vals and prev_epoch_idx describe prev_epoch_X row-for-row;
    # leaving them set would desynchronise the three attributes.
    self.window_vals = None
    self.prev_epoch_idx = None
559+
476560 def to_ONNX (self , save_location ):
477561 """Exports trained parametric UMAP as ONNX."""
478562 # Extract encoder
0 commit comments