Skip to content
2 changes: 1 addition & 1 deletion caiman/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
from caiman.base.movies import movie, load, load_movie_chain, _load_behavior, play_movie
from caiman.base.timeseries import concatenate
from caiman.cluster import start_server, stop_server
from caiman.keras_model_arch import keras_cnn_model_from_pickle
from caiman.mmapping import load_memmap, save_memmap, save_memmap_each, save_memmap_join
from caiman.pytorch_model_arch import PyTorchCNN
from caiman.summary_images import local_correlations

__version__ = importlib.metadata.version('caiman')
37 changes: 20 additions & 17 deletions caiman/components_evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import numpy as np
import os
import peakutils
import pickle
import scipy
from scipy.sparse import csc_matrix
from scipy.stats import norm
Expand All @@ -14,8 +15,8 @@
import warnings

import caiman
from caiman.keras_model_arch import keras_cnn_model_from_pickle
from caiman.paths import caiman_datadir
from caiman.pytorch_model_arch import PyTorchCNN
import caiman.utils.stats

try:
Expand Down Expand Up @@ -274,21 +275,30 @@ def evaluate_components_CNN(A,
logger.info("GPU run not requested, disabling use of GPUs")
os.environ['CUDA_VISIBLE_DEVICES'] = '-1'

os.environ["KERAS_BACKEND"] = "torch"
try:
import keras_core as keras
except ImportError:
import keras
if model_name is None:
model_name = os.path.join(caiman_datadir(), 'model', 'cnn_model')

logger.info('Using Torch')

logger.info('Using Keras 3.0 with PyTorch backend')

if loaded_model is None:
if os.path.isfile(os.path.join(caiman_datadir(), 'model', 'pytorch-models', model_name + ".pt")):
model_file = os.path.join(caiman_datadir(), 'model', 'pytorch-models', model_name + ".pt")
elif os.path.isfile(model_name + ".pt"):
model_file = model_name + ".pt"
if os.path.isfile(os.path.join(caiman_datadir(), model_name + ".pkl")):
with open(os.path.join(caiman_datadir(), model_name + ".pkl"), 'rb') as f:
pickle_data = pickle.load(f)
elif os.path.isfile(model_name + ".pkl"):
with open(model_name + ".pkl", 'rb') as f:
pickle_data = pickle.load(f)
else:
raise FileNotFoundError(f"File for requested model {model_name} not found")
logger.info(f"Using model: {model_file}")
loaded_model = PyTorchCNN()
loaded_model.load_state_dict(torch.load(model_file))

logger.info(f"USING MODEL (Keras 3.0 API from Pickle)")
loaded_model = keras_cnn_model_from_pickle(pickle_data, keras)

logger.debug("Loaded model from disk")

Expand All @@ -302,16 +312,9 @@ def evaluate_components_CNN(A,
half_crop[1]:com[1] + half_crop[1]] for mm, com in zip(A.tocsc().T, coms)
]
final_crops = np.array([cv2.resize(im / np.linalg.norm(im), (patch_size, patch_size)) for im in crop_imgs])

# Numpy to PyTorch and add a channel dimension using unsqueeze
final_crops = torch.tensor(final_crops, dtype=torch.float32).unsqueeze(1)

# Pass the preprocessed image crops through the model to get predictions
with torch.no_grad():
predictions = loaded_model(final_crops)
predictions = loaded_model.predict(final_crops[:, :, :, np.newaxis], batch_size=32, verbose=1)

predictions_numpy = predictions.cpu().numpy()
return predictions_numpy, final_crops
return predictions, final_crops

def evaluate_components(Y: np.ndarray,
traces: np.ndarray,
Expand Down
74 changes: 74 additions & 0 deletions caiman/keras_model_arch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
#!/usr/bin/env python
"""
Contains the model architecture for cnn_model.pkl and cnn_model_online.pkl. The files
cnn_model.pkl and cnn_model_online.pkl contain the model weights. The weight files are
used to load the weights into the model architecture.
"""

import os
os.environ["KERAS_BACKEND"] = "torch"

try:
import keras_core as keras
except ImportError:
import keras


def keras_cnn_model_from_pickle(pickle_data, keras):
"""Build a Keras model from pickle data format using Functional API."""
try:
# Use Functional API which is more reliable for pre-loaded weights
inputs = keras.layers.Input(shape=(50, 50, 1), name='input_layer')

# Conv Block 1
x = keras.layers.Conv2D(32, (3, 3), activation='relu', padding='valid', name='conv2d_20')(inputs)
x = keras.layers.Conv2D(32, (3, 3), activation='relu', padding='valid', name='conv2d_21')(x)
x = keras.layers.MaxPooling2D((2, 2), name='max_pooling2d_10')(x)
x = keras.layers.Dropout(0.25, name='dropout_15')(x)

# Conv Block 2
x = keras.layers.Conv2D(64, (3, 3), activation='relu', padding='same', name='conv2d_22')(x)
x = keras.layers.Conv2D(64, (3, 3), activation='relu', padding='valid', name='conv2d_23')(x)
x = keras.layers.MaxPooling2D((2, 2), name='max_pooling2d_11')(x)
x = keras.layers.Dropout(0.25, name='dropout_16')(x)

# Dense Block
x = keras.layers.Flatten(name='flatten_5')(x)
x = keras.layers.Dense(512, activation='relu', name='dense_15')(x)
x = keras.layers.Dropout(0.5, name='dropout_17')(x)
outputs = keras.layers.Dense(2, activation='softmax', name='dense_16')(x)

# Create model
model = keras.Model(inputs=inputs, outputs=outputs, name='cnn_model')

# Build the model to initialize weights
model.build(input_shape=(None, 50, 50, 1))

# Set weights from pickle data
if 'weights' in pickle_data:
weights = pickle_data['weights']
if len(weights) == 12: # 6 layers × 2 weights each
# Get only trainable layers (skip dropout, pooling, flatten)
trainable_layers = [layer for layer in model.layers if len(layer.weights) > 0]

if len(trainable_layers) == 6: # Should be 4 conv + 2 dense
weight_idx = 0
for layer in trainable_layers:
if len(layer.weights) == 2: # kernel and bias
kernel_weight = weights[weight_idx]
bias_weight = weights[weight_idx + 1]
layer.set_weights([kernel_weight, bias_weight])
weight_idx += 2
else:
# Fallback: set all weights at once
model.set_weights(weights)
else:
raise ValueError(f"Expected 12 weight arrays, got {len(weights)}")
else:
raise ValueError("No weights found in pickle data")

return model

except Exception as e:
raise ValueError(f"Failed to build Keras model from pickle: {e}")

52 changes: 0 additions & 52 deletions caiman/pytorch_model_arch.py

This file was deleted.

51 changes: 28 additions & 23 deletions caiman/source_extraction/cnmf/online_cnmf.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,14 @@
from multiprocessing import cpu_count
import numpy as np
import os
import pickle
from scipy.ndimage import percentile_filter
from scipy.sparse import coo_matrix, csc_matrix, spdiags, hstack
from scipy.stats import norm
from sklearn.decomposition import NMF
from skimage.morphology import disk
from sklearn.preprocessing import normalize
import torch
from torch.utils.data import DataLoader, TensorDataset
# Removed PyTorch imports since we're using Keras with PyTorch backend now
from time import time

import caiman
Expand All @@ -40,7 +40,8 @@
high_pass_filter_space, sliding_window,
register_translation_3d, apply_shifts_dft)
import caiman.paths
from caiman.pytorch_model_arch import PyTorchCNN
from caiman.paths import caiman_datadir
from caiman.keras_model_arch import keras_cnn_model_from_pickle
from caiman.source_extraction.cnmf.cnmf import CNMF
from caiman.source_extraction.cnmf.estimates import Estimates
from caiman.source_extraction.cnmf.initialization import imblur, initialize_components, hals, downscale
Expand Down Expand Up @@ -359,11 +360,26 @@ def _prepare_object(self, Yr, T, new_dims=None, idx_components=None):
loaded_model = None
self.params.set('online', {'sniper_mode': False})
else:
logger.info('Using Torch')
path = self.params.get('online', 'path_to_model').split(".")[:-1]
model_path = '.'.join(path + ['pt'])
loaded_model = PyTorchCNN()
loaded_model.load_state_dict(torch.load(model_path))
os.environ["KERAS_BACKEND"] = "torch"
try:
import keras_core as keras
except ImportError:
import keras
Comment on lines +363 to +367
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It'd be good to move this to being a static import; dynamic imports are hard to reason about.


logger.info('Using Keras with PyTorch backend')
model_name = self.params.get('online', 'path_to_model').split(".")[0] # Remove extension
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wondering if it might be better to have model_name have no path elements, and then build the $CAIMAN_DATA/model/$MODEL_NAME full path from that right before load

Although the right load logic around that might be hard to get right; provided this works I can adjust things in this direction later


if os.path.isfile(os.path.join(caiman_datadir(), model_name + ".pkl")):
with open(os.path.join(caiman_datadir(), model_name + ".pkl"), 'rb') as f:
pickle_data = pickle.load(f)
elif os.path.isfile(model_name + ".pkl"):
with open(model_name + ".pkl", 'rb') as f:
pickle_data = pickle.load(f)
else:
raise FileNotFoundError(f"File for requested model {model_name} not found")

logger.info(f"USING MODEL (Keras from Pickle)")
loaded_model = keras_cnn_model_from_pickle(pickle_data, keras)

self.loaded_model = loaded_model

Expand Down Expand Up @@ -2108,21 +2124,10 @@ def get_candidate_components(sv, dims, Yres_buf, min_num_trial=3, gSig=(5, 5),
Ain2 = np.reshape(Ain2,(-1,) + tuple(np.diff(ijSig_cnn).squeeze()),order= 'F')
Ain2 = np.stack([cv2.resize(ain,(patch_size ,patch_size)) for ain in Ain2])

final_crops = Ain2[:, :, :, np.newaxis]
final_crops_tensor = torch.tensor(final_crops, dtype=torch.float32).permute(0, 3, 1, 2)

#Create DataLoader for batching
dataset = TensorDataset(final_crops_tensor)
loader = DataLoader(dataset, batch_size=int(min_num_trial), shuffle=False)

loaded_model.eval()
all_predictions = []
with torch.no_grad():
for batch in loader:
outputs = loaded_model(batch[0])
all_predictions.append(outputs)

predictions = torch.cat(all_predictions).cpu().numpy()
final_crops = Ain2[:, :, :, np.newaxis] # Keep in Keras format (BHWC)

# Use Keras model prediction instead of PyTorch
predictions = loaded_model.predict(final_crops, batch_size=int(min_num_trial))
keep_cnn = list(np.where(predictions[:,0] > thresh_CNN_noisy)[0])
cnn_pos = Ain2[keep_cnn]
else:
Expand Down
4 changes: 2 additions & 2 deletions caiman/source_extraction/cnmf/params.py
Original file line number Diff line number Diff line change
Expand Up @@ -500,7 +500,7 @@ def __init__(self, fnames=None, dims=None, dxy=(1, 1),
opencv_codec: str, default: 'H264'
FourCC video codec for saving movie. Check http://www.fourcc.org/codecs.php

path_to_model: str, default: os.path.join(caiman_datadir(), 'model', 'cnn_model_online.h5')
path_to_model: str, default: os.path.join(caiman_datadir(), 'model', 'cnn_model_online.pkl')
Path to online CNN classifier

ring_CNN:
Expand Down Expand Up @@ -848,7 +848,7 @@ def __init__(self, fnames=None, dims=None, dxy=(1, 1),
'num_times_comp_updated': num_times_comp_updated,
'opencv_codec': 'H264', # FourCC video codec for saving movie. Check http://www.fourcc.org/codecs.php
'path_to_model': os.path.join(caiman_datadir(), 'model',
'cnn_model_online.h5'),
'cnn_model_online.pkl'),
'ring_CNN': False, # flag for using a ring CNN background model
'rval_thr': rval_thr, # space correlation threshold
'save_online_movie': False, # flag for saving online movie
Expand Down
37 changes: 37 additions & 0 deletions caiman/tests/test_keras.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
#!/usr/bin/env python

import numpy as np
import os
import pickle

from caiman.keras_model_arch import keras_cnn_model_from_pickle
from caiman.paths import caiman_datadir

os.environ["KERAS_BACKEND"] = "torch"
try:
import keras_core as keras
except ImportError:
import keras

def test_keras():
os.environ['CUDA_VISIBLE_DEVICES'] = '-1'

try:
model_name = os.path.join(caiman_datadir(), 'model', 'cnn_model')
model_file = model_name + ".pkl"
with open(model_file, 'rb') as f:
print('USING MODEL:' + model_file)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

print(f"Using model {model_file")

pickle_data = pickle.load(f)

loaded_model = keras_cnn_model_from_pickle(pickle_data, keras)
except:
raise Exception(f'NN model could not be loaded')

A = np.random.randn(10, 50, 50, 1)
try:
predictions = loaded_model.predict(A, batch_size=32)
except:
raise Exception('NN model could not be deployed. use_keras = ' + str(use_keras))

if __name__ == "__main__":
test_keras()
2 changes: 1 addition & 1 deletion caiman/tests/test_mrcnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,4 +95,4 @@ class InferenceConfig(Config):
print("\nTest passed successfully!")

if __name__ == "__main__":
test_mrcnn_pytorch()
test_mrcnn_pytorch()
Loading
Loading