Optimization with Huge Speedup: Dask Dataframes Instead of Manual Dataframe Iterations #240


Open: wants to merge 85 commits into base: main
85 commits
5de22c6
init
selmanozleyen May 20, 2025
7332e5c
add tests
selmanozleyen May 20, 2025
0be2e68
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 20, 2025
1429cab
resolve comments
selmanozleyen Jun 12, 2025
d850771
Squashed commit of the following:
selmanozleyen Jun 13, 2025
6f0d747
Revert "resolve comments"
selmanozleyen Jun 13, 2025
d856def
Revert "Revert "resolve comments""
selmanozleyen Jun 13, 2025
79f0b0d
Revert "Squashed commit of the following:"
selmanozleyen Jun 13, 2025
36cc534
fix TokenAttention
MUCDK May 20, 2025
68262d3
fix tests
MUCDK May 20, 2025
a904bc3
fix tests
MUCDK May 20, 2025
a37d659
revert previous precommit changes
MUCDK May 20, 2025
c153919
update scvi dependency
MUCDK May 22, 2025
517b346
skip gene emb tests
MUCDK May 22, 2025
1e59976
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 22, 2025
8379630
skip more tests
MUCDK May 22, 2025
b1566b5
skip more tests
MUCDK May 22, 2025
c34ceb8
skip more tests
MUCDK May 23, 2025
20c80a2
skip more tests
MUCDK May 23, 2025
47557f5
enable some more tests
MUCDK May 23, 2025
80cba8e
deprecate 3.10
MUCDK May 23, 2025
d2c3ecf
deprecate 3.10
MUCDK May 23, 2025
a8d64b6
deprecate 3.10
MUCDK May 23, 2025
47013a7
enable more tests again
MUCDK May 23, 2025
2011137
skip some tests again
MUCDK May 23, 2025
ea05e1a
only enable test_cellflow_with_validation
MUCDK May 23, 2025
3f593b9
enable all tests
MUCDK May 23, 2025
8f3256a
add combosciplex
MUCDK May 9, 2025
d6c71b1
add combosciplex
MUCDK May 23, 2025
5e781ef
Add solver to callback call
LeonStadelmann Apr 21, 2025
ec367af
Add typing and solver for on_train_end
LeonStadelmann Apr 24, 2025
2edf07e
Added test for custom callbacks
LeonStadelmann May 1, 2025
bfb8230
Rename flow
LeonStadelmann May 1, 2025
4ef9dc7
Added source data parameter + stochastic test
LeonStadelmann May 12, 2025
856e579
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 12, 2025
786f03f
fix docs
LeonStadelmann May 17, 2025
261af6b
[pre-commit.ci] pre-commit autoupdate
pre-commit-ci[bot] May 19, 2025
0224332
pass rng from cellflow.predict
MUCDK May 26, 2025
aa5f02e
fix callback/
MUCDK May 26, 2025
2c679f2
[pre-commit.ci] pre-commit autoupdate
pre-commit-ci[bot] May 26, 2025
5c2db55
init
selmanozleyen May 20, 2025
0a2ee61
support both dataloaders
selmanozleyen May 20, 2025
64ee583
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 20, 2025
365c5d6
throw error when there isn't any valid indices
selmanozleyen Jun 12, 2025
4f42290
add tests
selmanozleyen Jun 12, 2025
d636ecb
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 12, 2025
6186113
remove unused fixture
selmanozleyen Jun 12, 2025
1ad1637
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 12, 2025
23eb759
document better
selmanozleyen Jun 12, 2025
cf85305
update error message
selmanozleyen Jun 12, 2025
ebe9506
add return types
selmanozleyen Jun 12, 2025
efade17
init
selmanozleyen May 13, 2025
5799c87
condition keys left to add
selmanozleyen May 13, 2025
f2b74ed
needs testing
selmanozleyen May 13, 2025
0e8016b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 13, 2025
1f054e8
update implementation so that two functions predict and predict batch…
selmanozleyen May 15, 2025
ae04ce0
Adjust docs
LeonStadelmann May 17, 2025
014122e
Add batched predict test
LeonStadelmann May 18, 2025
61efd1e
Handle empty input
LeonStadelmann May 21, 2025
50be97b
add genot
selmanozleyen May 21, 2025
befcb42
Add testing for genot predict
LeonStadelmann May 23, 2025
2e9792e
Remove duplicate
LeonStadelmann May 23, 2025
2a2d1cd
revert duplicate removal
LeonStadelmann May 23, 2025
1e7e21e
Replace tree map
LeonStadelmann May 27, 2025
44de602
Fix type error
LeonStadelmann Jun 1, 2025
8986ce2
clarify documentation that this requires same number of cells
selmanozleyen Jun 12, 2025
415088d
When using PredictionSampler we can't use batched mode
selmanozleyen Jun 12, 2025
ece4b75
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 12, 2025
34ea85c
init
selmanozleyen May 20, 2025
b12bbcd
Merge branch 'main' into refactor/get-cond-perf
selmanozleyen Jun 13, 2025
7608635
add dask as dependency to see which tests fail
selmanozleyen Jun 15, 2025
83050d8
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 15, 2025
4eb4aa9
submit working version
selmanozleyen Jun 17, 2025
dddd0fb
remove get_condition_old
selmanozleyen Jun 17, 2025
939b054
also remove function itself
selmanozleyen Jun 17, 2025
c554c82
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 17, 2025
b3e3122
partially working version
selmanozleyen Jun 17, 2025
dd2886b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 17, 2025
6ef8c5e
version closest to working so far
selmanozleyen Jun 18, 2025
32d30b6
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 18, 2025
57cc3a4
passes except condition_data
selmanozleyen Jun 19, 2025
80f991f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 19, 2025
4e60fb2
some state
selmanozleyen Jun 19, 2025
b73c443
save state before working on changing something
selmanozleyen Jun 20, 2025
8ed5793
precommit
selmanozleyen Jun 20, 2025
1 change: 1 addition & 0 deletions pyproject.toml
@@ -34,6 +34,7 @@ dependencies = [
     "anndata",
     "cloudpickle",
     "coverage",
+    "dask",
     "diffrax",
     "flax",
     "orbax",
4 changes: 2 additions & 2 deletions src/cellflow/data/_data.py
@@ -162,8 +162,8 @@ class ValidationData(BaseDataMixin):
         int, tuple[str, ...]
     ] # (n_targets,), dictionary explaining perturbation_covariates_mask
     perturbation_idx_to_id: dict[int, Any]
-    condition_data: dict[str, ArrayLike] # (n_targets,) all embeddings for conditions
-    control_to_perturbation: dict[int, jax.Array] # mapping from control idx to target distribution idcs
+    condition_data: dict[str, np.ndarray] # (n_targets,) all embeddings for conditions
+    control_to_perturbation: dict[int, np.ndarray] # mapping from control idx to target distribution idcs
     max_combination_length: int
     null_value: Any
     data_manager: Any
5 changes: 2 additions & 3 deletions src/cellflow/data/_dataloader.py
@@ -5,7 +5,6 @@
 from typing import Any, Literal

 import jax
-import jax.numpy as jnp
 import numpy as np

 from cellflow.data._data import PredictionData, TrainingData, ValidationData
@@ -111,14 +110,14 @@ def _get_key(self, cond_idx: int) -> tuple[str, ...]:
         cov_combination = self._data.perturbation_idx_to_covariates[cond_idx] # type: ignore[attr-defined]
         return tuple(cov_combination[i] for i in range(len(cov_combination)))

-    def _get_perturbation_to_control(self, data: ValidationData | PredictionData) -> dict[int, int]:
+    def _get_perturbation_to_control(self, data: ValidationData | PredictionData) -> dict[int, np.ndarray]:
         d = {}
         for k, v in data.control_to_perturbation.items():
             for el in v:
                 d[el] = k
         return d

-    def _get_condition_data(self, cond_idx: int) -> jnp.ndarray:
+    def _get_condition_data(self, cond_idx: int) -> dict[str, np.ndarray]:
         return {k: v[[cond_idx], ...] for k, v in self._data.condition_data.items()} # type: ignore[attr-defined]

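As a small illustration of the two `_dataloader.py` helpers touched in this diff, here is a minimal sketch on made-up NumPy data. The toy dictionaries and the free functions `get_perturbation_to_control` and `get_condition_data` are hypothetical stand-ins for the methods, not code from this PR. It shows that inverting `control_to_perturbation` gives one control index per perturbation index, and that indexing with `[[cond_idx], ...]` keeps the leading axis.

```python
import numpy as np

# Toy data; keys, shapes, and values are assumptions for illustration only.
control_to_perturbation = {0: np.array([0, 1]), 1: np.array([2])}
condition_data = {"drug": np.arange(3 * 2 * 5, dtype=float).reshape(3, 2, 5)}


def get_perturbation_to_control(mapping: dict[int, np.ndarray]) -> dict[int, int]:
    """Invert control -> perturbations into perturbation -> control."""
    inverted = {}
    for control_idx, perturbation_idcs in mapping.items():
        for pert_idx in perturbation_idcs:
            # Cast NumPy integers to plain ints so the keys are ordinary Python ints.
            inverted[int(pert_idx)] = control_idx
    return inverted


def get_condition_data(data: dict[str, np.ndarray], cond_idx: int) -> dict[str, np.ndarray]:
    """Select one condition while keeping the leading axis (length 1)."""
    return {k: v[[cond_idx], ...] for k, v in data.items()}


pert_to_ctrl = get_perturbation_to_control(control_to_perturbation)
selected = get_condition_data(condition_data, cond_idx=1)
print(pert_to_ctrl)            # {0: 0, 1: 0, 2: 1}
print(selected["drug"].shape)  # (1, 2, 5)
```

Indexing with a list (`[cond_idx]`) rather than a bare integer is what preserves the first dimension, so downstream code can treat a single condition like a batch of size one.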
654 changes: 513 additions & 141 deletions src/cellflow/data/_datamanager.py

Large diffs are not rendered by default.

358 changes: 358 additions & 0 deletions tests.ipynb

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion tests/conftest.py
@@ -137,7 +137,7 @@ def adata_perturbation() -> ad.AnnData:
     for drug in adata.obs["drug1"].cat.categories:
         drug_emb[drug] = np.random.randn(5, 1)
     adata.uns["drug"] = drug_emb
-
+    print(adata.uns["drug"])
     cell_type_emb = {}
     for cell_type in adata.obs["cell_type"].cat.categories:
         cell_type_emb[cell_type] = np.random.randn(3, 1)
34 changes: 17 additions & 17 deletions tests/data/test_datamanager.py
@@ -1,5 +1,5 @@
 import anndata as ad
-import jax
+import numpy as np
 import pytest

 from cellflow.data._datamanager import DataManager
@@ -168,6 +168,7 @@ def test_get_train_data(
             perturbation_covariate_reps=perturbation_covariate_reps,
             sample_covariates=sample_covariates,
         )
+
         assert isinstance(dm, DataManager)
         assert dm._sample_rep == sample_rep
         assert dm._control_key == "control"
@@ -185,18 +186,18 @@ def test_get_train_data(
         assert train_data.n_controls == len(adata_perturbation.obs["cell_type"].cat.categories)

         assert isinstance(train_data.condition_data, dict)
-        assert isinstance(list(train_data.condition_data.values())[0], jax.Array)
+        assert isinstance(list(train_data.condition_data.values())[0], np.ndarray)
         assert train_data.max_combination_length == 1

         if sample_covariates == [] and perturbation_covariates == {"drug": ("drug1",)}:
             assert (
                 train_data.n_perturbations
                 == (len(adata_perturbation.obs["drug1"].cat.categories) - 1) * train_data.n_controls
             )
-        assert isinstance(train_data.cell_data, jax.Array)
-        assert isinstance(train_data.split_covariates_mask, jax.Array)
+        assert isinstance(train_data.cell_data, np.ndarray)
+        assert isinstance(train_data.split_covariates_mask, np.ndarray)
         assert isinstance(train_data.split_idx_to_covariates, dict)
-        assert isinstance(train_data.perturbation_covariates_mask, jax.Array)
+        assert isinstance(train_data.perturbation_covariates_mask, np.ndarray)
         assert isinstance(train_data.perturbation_idx_to_covariates, dict)
         assert isinstance(train_data.control_to_perturbation, dict)

@@ -222,7 +223,6 @@ def test_get_train_data_with_combinations(
             sample_covariates=["cell_type"],
             sample_covariate_reps={"cell_type": "cell_type"},
         )
-
         train_data = dm.get_train_data(adata_perturbation)

         assert ((train_data.perturbation_covariates_mask == -1) + (train_data.split_covariates_mask == -1)).all()
@@ -233,7 +233,7 @@ def test_get_train_data_with_combinations(
         assert train_data.n_controls == len(adata_perturbation.obs["cell_type"].cat.categories)

         assert isinstance(train_data.condition_data, dict)
-        assert isinstance(list(train_data.condition_data.values())[0], jax.Array)
+        assert isinstance(list(train_data.condition_data.values())[0], np.ndarray)
         assert train_data.max_combination_length == len(perturbation_covariates["drug"])

         for k in perturbation_covariates.keys():
@@ -253,10 +253,10 @@ def test_get_train_data_with_combinations(
             cov_name = adata_perturbation.obs[cov_key].values[0]
             assert train_data.condition_data[v].shape[2] == adata_perturbation.uns[k][cov_name].shape[0]

-        assert isinstance(train_data.cell_data, jax.Array)
-        assert isinstance(train_data.split_covariates_mask, jax.Array)
+        assert isinstance(train_data.cell_data, np.ndarray)
+        assert isinstance(train_data.split_covariates_mask, np.ndarray)
         assert isinstance(train_data.split_idx_to_covariates, dict)
-        assert isinstance(train_data.perturbation_covariates_mask, jax.Array)
+        assert isinstance(train_data.perturbation_covariates_mask, np.ndarray)
         assert isinstance(train_data.perturbation_idx_to_covariates, dict)
         assert isinstance(train_data.control_to_perturbation, dict)

@@ -319,16 +319,16 @@ def test_get_validation_data(

         val_data = dm.get_validation_data(adata_perturbation)

-        assert isinstance(val_data.cell_data, jax.Array)
-        assert isinstance(val_data.split_covariates_mask, jax.Array)
+        assert isinstance(val_data.cell_data, np.ndarray)
+        assert isinstance(val_data.split_covariates_mask, np.ndarray)
         assert isinstance(val_data.split_idx_to_covariates, dict)
-        assert isinstance(val_data.perturbation_covariates_mask, jax.Array)
+        assert isinstance(val_data.perturbation_covariates_mask, np.ndarray)
         assert isinstance(val_data.perturbation_idx_to_covariates, dict)
         assert isinstance(val_data.control_to_perturbation, dict)
         assert val_data.max_combination_length == len(perturbation_covariates["drug"])

         assert isinstance(val_data.condition_data, dict)
-        assert isinstance(list(val_data.condition_data.values())[0], jax.Array)
+        assert isinstance(list(val_data.condition_data.values())[0], np.ndarray)

         if sample_covariates == [] and perturbation_covariates == {"drug": ("drug1",)}:
             assert (
@@ -399,15 +399,15 @@ def test_get_prediction_data(
         adata_pred.obs["control"] = True
         pred_data = dm.get_prediction_data(adata_pred, covariate_data=adata_pred.obs, sample_rep=sample_rep)

-        assert isinstance(pred_data.cell_data, jax.Array)
-        assert isinstance(pred_data.split_covariates_mask, jax.Array)
+        assert isinstance(pred_data.cell_data, np.ndarray)
+        assert isinstance(pred_data.split_covariates_mask, np.ndarray)
         assert isinstance(pred_data.split_idx_to_covariates, dict)
         assert isinstance(pred_data.perturbation_idx_to_covariates, dict)
         assert isinstance(pred_data.control_to_perturbation, dict)
         assert pred_data.max_combination_length == len(perturbation_covariates["drug"])

         assert isinstance(pred_data.condition_data, dict)
-        assert isinstance(list(pred_data.condition_data.values())[0], jax.Array)
+        assert isinstance(list(pred_data.condition_data.values())[0], np.ndarray)

         if sample_covariates == [] and perturbation_covariates == {"drug": ("drug1",)}:
             assert (