Caiman PyTorch -> Ver 1.0.1 #1535
base: dev
Changes from all commits
fef0086
058cc7f
b0a2a46
09b53e1
9de6c5a
4addb54
94e2d39
0539ff7
379e10c
f27b711
036a7c1
f765314
**`caiman/keras_model_arch.py`** (new file, 74 lines added):

```python
#!/usr/bin/env python
"""
Contains the model architecture for cnn_model.pkl and cnn_model_online.pkl. The files
cnn_model.pkl and cnn_model_online.pkl contain the model weights. The weight files are
used to load the weights into the model architecture.
"""

import os
os.environ["KERAS_BACKEND"] = "torch"

try:
    import keras_core as keras
except ImportError:
    import keras


def keras_cnn_model_from_pickle(pickle_data, keras):
    """Build a Keras model from pickle data format using Functional API."""
    try:
        # Use Functional API which is more reliable for pre-loaded weights
        inputs = keras.layers.Input(shape=(50, 50, 1), name='input_layer')

        # Conv Block 1
        x = keras.layers.Conv2D(32, (3, 3), activation='relu', padding='valid', name='conv2d_20')(inputs)
        x = keras.layers.Conv2D(32, (3, 3), activation='relu', padding='valid', name='conv2d_21')(x)
        x = keras.layers.MaxPooling2D((2, 2), name='max_pooling2d_10')(x)
        x = keras.layers.Dropout(0.25, name='dropout_15')(x)

        # Conv Block 2
        x = keras.layers.Conv2D(64, (3, 3), activation='relu', padding='same', name='conv2d_22')(x)
        x = keras.layers.Conv2D(64, (3, 3), activation='relu', padding='valid', name='conv2d_23')(x)
        x = keras.layers.MaxPooling2D((2, 2), name='max_pooling2d_11')(x)
        x = keras.layers.Dropout(0.25, name='dropout_16')(x)

        # Dense Block
        x = keras.layers.Flatten(name='flatten_5')(x)
        x = keras.layers.Dense(512, activation='relu', name='dense_15')(x)
        x = keras.layers.Dropout(0.5, name='dropout_17')(x)
        outputs = keras.layers.Dense(2, activation='softmax', name='dense_16')(x)

        # Create model
        model = keras.Model(inputs=inputs, outputs=outputs, name='cnn_model')

        # Build the model to initialize weights
        model.build(input_shape=(None, 50, 50, 1))

        # Set weights from pickle data
        if 'weights' in pickle_data:
            weights = pickle_data['weights']
            if len(weights) == 12:  # 6 layers × 2 weights each
                # Get only trainable layers (skip dropout, pooling, flatten)
                trainable_layers = [layer for layer in model.layers if len(layer.weights) > 0]

                if len(trainable_layers) == 6:  # Should be 4 conv + 2 dense
                    weight_idx = 0
                    for layer in trainable_layers:
                        if len(layer.weights) == 2:  # kernel and bias
                            kernel_weight = weights[weight_idx]
                            bias_weight = weights[weight_idx + 1]
                            layer.set_weights([kernel_weight, bias_weight])
                            weight_idx += 2
                else:
                    # Fallback: set all weights at once
                    model.set_weights(weights)
            else:
                raise ValueError(f"Expected 12 weight arrays, got {len(weights)}")
        else:
            raise ValueError("No weights found in pickle data")

        return model

    except Exception as e:
        raise ValueError(f"Failed to build Keras model from pickle: {e}")
```
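For reference, a sketch of the pickle payload this loader expects: a dict with a `'weights'` key holding 12 arrays, ordered kernel-then-bias for each trainable layer. The shapes below are inferred from the architecture above (Keras `Conv2D` kernels are `(kh, kw, in_ch, out_ch)`; the flatten size works out to 10 × 10 × 64 = 6400 after the conv/pool stack), not taken from the PR:

```python
import numpy as np

# The module's own keras import can be reused by call sites.
from caiman.keras_model_arch import keras_cnn_model_from_pickle, keras

# Hypothetical payload with correctly shaped (all-zero) weights.
pickle_data = {
    'weights': [
        np.zeros((3, 3, 1, 32)),  np.zeros(32),   # conv2d_20 kernel, bias
        np.zeros((3, 3, 32, 32)), np.zeros(32),   # conv2d_21
        np.zeros((3, 3, 32, 64)), np.zeros(64),   # conv2d_22
        np.zeros((3, 3, 64, 64)), np.zeros(64),   # conv2d_23
        np.zeros((6400, 512)),    np.zeros(512),  # dense_15 (10*10*64 flattened)
        np.zeros((512, 2)),       np.zeros(2),    # dense_16
    ]
}

model = keras_cnn_model_from_pickle(pickle_data, keras)
model.summary()
```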
This file was deleted (evidently `caiman/pytorch_model_arch.py`: its `PyTorchCNN` import is removed in the diff below).
**`caiman/source_extraction/cnmf/online_cnmf.py`** (likely; the hunks below touch `_prepare_object` and `get_candidate_components`):

```diff
@@ -21,14 +21,14 @@
 from multiprocessing import cpu_count
 import numpy as np
 import os
+import pickle
 from scipy.ndimage import percentile_filter
 from scipy.sparse import coo_matrix, csc_matrix, spdiags, hstack
 from scipy.stats import norm
 from sklearn.decomposition import NMF
 from skimage.morphology import disk
 from sklearn.preprocessing import normalize
-import torch
-from torch.utils.data import DataLoader, TensorDataset
+# Removed PyTorch imports since we're using Keras with PyTorch backend now
 from time import time

 import caiman
```

```diff
@@ -40,7 +40,8 @@
                                      high_pass_filter_space, sliding_window,
                                      register_translation_3d, apply_shifts_dft)
 import caiman.paths
-from caiman.pytorch_model_arch import PyTorchCNN
+from caiman.paths import caiman_datadir
+from caiman.keras_model_arch import keras_cnn_model_from_pickle
 from caiman.source_extraction.cnmf.cnmf import CNMF
 from caiman.source_extraction.cnmf.estimates import Estimates
 from caiman.source_extraction.cnmf.initialization import imblur, initialize_components, hals, downscale
```
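One ordering constraint worth keeping in mind with this switch: Keras reads `KERAS_BACKEND` once, at first import, so the env var must be set before any `import keras` runs in the process. A minimal check (assumes Keras 3's `keras.backend.backend()` accessor):

```python
import os

# Must run before the first `import keras` anywhere in the process;
# the backend choice is fixed at import time.
os.environ["KERAS_BACKEND"] = "torch"

import keras

print(keras.backend.backend())  # expected: "torch"
```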
@@ -359,11 +360,26 @@ def _prepare_object(self, Yr, T, new_dims=None, idx_components=None): | |
| loaded_model = None | ||
| self.params.set('online', {'sniper_mode': False}) | ||
| else: | ||
| logger.info('Using Torch') | ||
| path = self.params.get('online', 'path_to_model').split(".")[:-1] | ||
| model_path = '.'.join(path + ['pt']) | ||
| loaded_model = PyTorchCNN() | ||
| loaded_model.load_state_dict(torch.load(model_path)) | ||
| os.environ["KERAS_BACKEND"] = "torch" | ||
| try: | ||
| import keras_core as keras | ||
| except ImportError: | ||
| import keras | ||
|
|
||
| logger.info('Using Keras with PyTorch backend') | ||
| model_name = self.params.get('online', 'path_to_model').split(".")[0] # Remove extension | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Wondering if it might be better to have model_name have no path elements, and then build the $CAIMAN_DATA/model/$MODEL_NAME full path from that right before load Although the right load logic around that might be hard to get right; provided this works I can adjust things in this direction later |
||
|
|
||
| if os.path.isfile(os.path.join(caiman_datadir(), model_name + ".pkl")): | ||
| with open(os.path.join(caiman_datadir(), model_name + ".pkl"), 'rb') as f: | ||
| pickle_data = pickle.load(f) | ||
| elif os.path.isfile(model_name + ".pkl"): | ||
| with open(model_name + ".pkl", 'rb') as f: | ||
| pickle_data = pickle.load(f) | ||
| else: | ||
| raise FileNotFoundError(f"File for requested model {model_name} not found") | ||
|
|
||
| logger.info(f"USING MODEL (Keras from Pickle)") | ||
| loaded_model = keras_cnn_model_from_pickle(pickle_data, keras) | ||
|
|
||
| self.loaded_model = loaded_model | ||
|
|
||
|
|
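A sketch of what that suggestion could look like; `resolve_model_path` is a hypothetical helper, not part of this PR, and `basename` + `splitext` also sidestep the pitfall that `split(".")[0]` truncates paths containing dots anywhere:

```python
import os

from caiman.paths import caiman_datadir


def resolve_model_path(path_to_model: str) -> str:
    """Hypothetical helper: reduce the configured model path to a bare name,
    then build the $CAIMAN_DATA/model/$MODEL_NAME path right before loading."""
    # basename + splitext is robust to dots elsewhere in the path,
    # unlike split(".")[0] applied to the full path.
    model_name = os.path.splitext(os.path.basename(path_to_model))[0]
    candidate = os.path.join(caiman_datadir(), 'model', model_name + '.pkl')
    if not os.path.isfile(candidate):
        raise FileNotFoundError(f"File for requested model {model_name} not found")
    return candidate
```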
@@ -2108,21 +2124,10 @@ def get_candidate_components(sv, dims, Yres_buf, min_num_trial=3, gSig=(5, 5), | |
| Ain2 = np.reshape(Ain2,(-1,) + tuple(np.diff(ijSig_cnn).squeeze()),order= 'F') | ||
| Ain2 = np.stack([cv2.resize(ain,(patch_size ,patch_size)) for ain in Ain2]) | ||
|
|
||
| final_crops = Ain2[:, :, :, np.newaxis] | ||
| final_crops_tensor = torch.tensor(final_crops, dtype=torch.float32).permute(0, 3, 1, 2) | ||
|
|
||
| #Create DataLoader for batching | ||
| dataset = TensorDataset(final_crops_tensor) | ||
| loader = DataLoader(dataset, batch_size=int(min_num_trial), shuffle=False) | ||
|
|
||
| loaded_model.eval() | ||
| all_predictions = [] | ||
| with torch.no_grad(): | ||
| for batch in loader: | ||
| outputs = loaded_model(batch[0]) | ||
| all_predictions.append(outputs) | ||
|
|
||
| predictions = torch.cat(all_predictions).cpu().numpy() | ||
| final_crops = Ain2[:, :, :, np.newaxis] # Keep in Keras format (BHWC) | ||
|
|
||
| # Use Keras model prediction instead of PyTorch | ||
| predictions = loaded_model.predict(final_crops, batch_size=int(min_num_trial)) | ||
| keep_cnn = list(np.where(predictions[:,0] > thresh_CNN_noisy)[0]) | ||
| cnn_pos = Ain2[keep_cnn] | ||
| else: | ||
|
|
||
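The dropped `.permute(0, 3, 1, 2)` is the key layout detail here: the PyTorch model consumed NCHW tensors, while Keras `predict` takes the crops in NHWC as-is, so the transpose is removed rather than ported. A standalone illustration of the two layouts (not from the PR):

```python
import numpy as np

crops = np.random.randn(4, 50, 50)       # (batch, height, width) image patches
nhwc = crops[:, :, :, np.newaxis]        # Keras layout: (B, H, W, C)
nchw = np.transpose(nhwc, (0, 3, 1, 2))  # former PyTorch layout: (B, C, H, W)

assert nhwc.shape == (4, 50, 50, 1)
assert nchw.shape == (4, 1, 50, 50)
```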
**New test module** (37 lines added; likely `caiman/tests/test_keras.py`):

```python
#!/usr/bin/env python

import numpy as np
import os
import pickle

from caiman.keras_model_arch import keras_cnn_model_from_pickle
from caiman.paths import caiman_datadir

os.environ["KERAS_BACKEND"] = "torch"
try:
    import keras_core as keras
except ImportError:
    import keras


def test_keras():
    os.environ['CUDA_VISIBLE_DEVICES'] = '-1'

    try:
        model_name = os.path.join(caiman_datadir(), 'model', 'cnn_model')
        model_file = model_name + ".pkl"
        with open(model_file, 'rb') as f:
            print('USING MODEL:' + model_file)
            pickle_data = pickle.load(f)

        loaded_model = keras_cnn_model_from_pickle(pickle_data, keras)
    except Exception as e:
        raise Exception('NN model could not be loaded') from e

    A = np.random.randn(10, 50, 50, 1)
    try:
        predictions = loaded_model.predict(A, batch_size=32)
    except Exception as e:
        raise Exception('NN model could not be deployed') from e


if __name__ == "__main__":
    test_keras()
```

> **Reviewer comment** (on the `print` line): `print(f"Using model {model_file}")`

> **Reviewer comment:** It'd be good to move this to being a static import; dynamic imports are hard to reason about.
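A sketch of the static-import arrangement being suggested: do the backend setup and the `keras_core`/`keras` fallback once, in a small module imported at top level, so call sites stop repeating the try/except. The module name `caiman/keras_backend.py` is hypothetical:

```python
# caiman/keras_backend.py (hypothetical module)
# Selects the backend exactly once, before keras is imported anywhere else.
import os

os.environ.setdefault("KERAS_BACKEND", "torch")

try:
    import keras_core as keras  # older standalone packaging of Keras 3
except ImportError:
    import keras

__all__ = ["keras"]
```

Call sites would then put `from caiman.keras_backend import keras` at the top of the file instead of importing inside `_prepare_object`.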