Skip to content

Commit

Permalink
Change type of var_names_g to np.ndarray (#202)
Browse files (browse the repository at this point in the history)
* test pyro 1.9.1

* var_names_g: np.ndarray
  • Loading branch information
ordabayevy authored Jun 12, 2024
1 parent 430fc0f commit e8163ee
Show file tree
Hide file tree
Showing 16 changed files with 37 additions and 46 deletions.
6 changes: 2 additions & 4 deletions cellarium/ml/models/geneformer.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
# Copyright Contributors to the Cellarium project.
# SPDX-License-Identifier: BSD-3-Clause

from collections.abc import Sequence

import numpy as np
import torch
from transformers import BertConfig, BertForMaskedLM
Expand Down Expand Up @@ -62,7 +60,7 @@ class Geneformer(CellariumModel, PredictMixin):

def __init__(
self,
var_names_g: Sequence[str],
var_names_g: np.ndarray,
hidden_size: int = 256,
num_hidden_layers: int = 6,
num_attention_heads: int = 4,
Expand All @@ -78,7 +76,7 @@ def __init__(
mlm_probability: float = 0.15,
) -> None:
super().__init__()
self.var_names_g = np.array(var_names_g)
self.var_names_g = var_names_g
# model configuration
config = {
"vocab_size": len(self.var_names_g) + 2, # number of genes + 2 for <mask> and <pad> tokens
Expand Down
5 changes: 2 additions & 3 deletions cellarium/ml/models/incremental_pca.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
# SPDX-License-Identifier: BSD-3-Clause

import math
from collections.abc import Sequence

import lightning.pytorch as pl
import numpy as np
Expand Down Expand Up @@ -44,13 +43,13 @@ class IncrementalPCA(CellariumModel, PredictMixin):

def __init__(
self,
var_names_g: Sequence[str],
var_names_g: np.ndarray,
n_components: int,
svd_lowrank_niter: int = 2,
perform_mean_correction: bool = False,
) -> None:
super().__init__()
self.var_names_g = np.array(var_names_g)
self.var_names_g = var_names_g
n_vars = len(self.var_names_g)
self.n_vars = n_vars
self.n_components = n_components
Expand Down
6 changes: 2 additions & 4 deletions cellarium/ml/models/logistic_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@
# SPDX-License-Identifier: BSD-3-Clause


from collections.abc import Sequence

import lightning.pytorch as pl
import numpy as np
import pyro
Expand Down Expand Up @@ -41,7 +39,7 @@ class LogisticRegression(CellariumModel):
def __init__(
self,
n_obs: int,
var_names_g: Sequence[str],
var_names_g: np.ndarray,
n_categories: int,
W_prior_scale: float = 1.0,
W_init_scale: float = 1.0,
Expand All @@ -52,7 +50,7 @@ def __init__(

# data
self.n_obs = n_obs
self.var_names_g = np.array(var_names_g)
self.var_names_g = var_names_g
self.n_vars = len(var_names_g)
self.n_categories = n_categories

Expand Down
5 changes: 2 additions & 3 deletions cellarium/ml/models/onepass_mean_var_std.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
# Copyright Contributors to the Cellarium project.
# SPDX-License-Identifier: BSD-3-Clause

from collections.abc import Sequence
from typing import Literal

import lightning.pytorch as pl
Expand Down Expand Up @@ -32,9 +31,9 @@ class OnePassMeanVarStd(CellariumModel):
var_names_g: The variable names schema for the input data validation.
"""

def __init__(self, var_names_g: Sequence[str], algorithm: Literal["naive", "shifted_data"] = "naive") -> None:
def __init__(self, var_names_g: np.ndarray, algorithm: Literal["naive", "shifted_data"] = "naive") -> None:
super().__init__()
self.var_names_g = np.array(var_names_g)
self.var_names_g = var_names_g
n_vars = len(self.var_names_g)
self.n_vars = n_vars
self.algorithm = algorithm
Expand Down
5 changes: 2 additions & 3 deletions cellarium/ml/models/probabilistic_pca.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
# SPDX-License-Identifier: BSD-3-Clause


from collections.abc import Sequence
from typing import Literal

import numpy as np
Expand Down Expand Up @@ -56,7 +55,7 @@ class ProbabilisticPCA(CellariumModel, PredictMixin):
def __init__(
self,
n_obs: int,
var_names_g: Sequence[str],
var_names_g: np.ndarray,
n_components: int,
ppca_flavor: Literal["marginalized", "linear_vae"],
mean_g: torch.Tensor | None = None,
Expand All @@ -67,7 +66,7 @@ def __init__(
super().__init__()

self.n_obs = n_obs
self.var_names_g = np.array(var_names_g)
self.var_names_g = var_names_g
n_vars = len(self.var_names_g)
self.n_vars = n_vars
self.n_components = n_components
Expand Down
5 changes: 2 additions & 3 deletions cellarium/ml/models/tdigest.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
# SPDX-License-Identifier: BSD-3-Clause

import os
from collections.abc import Sequence
from pathlib import Path
from typing import Any

Expand Down Expand Up @@ -33,9 +32,9 @@ class TDigest(CellariumModel):
var_names_g: The variable names schema for the input data validation.
"""

def __init__(self, var_names_g: Sequence[str]) -> None:
def __init__(self, var_names_g: np.ndarray) -> None:
super().__init__()
self.var_names_g = np.array(var_names_g)
self.var_names_g = var_names_g
n_vars = len(self.var_names_g)
self.n_vars = n_vars
self.tdigests = [crick.tdigest.TDigest() for _ in range(self.n_vars)]
Expand Down
5 changes: 2 additions & 3 deletions cellarium/ml/transforms/divide_by_scale.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
# Copyright Contributors to the Cellarium project.
# SPDX-License-Identifier: BSD-3-Clause

from collections.abc import Sequence

import numpy as np
import torch
Expand Down Expand Up @@ -31,11 +30,11 @@ class DivideByScale(nn.Module):
A value added to the denominator for numerical stability.
"""

def __init__(self, scale_g: torch.Tensor, var_names_g: Sequence[str], eps: float = 1e-6) -> None:
def __init__(self, scale_g: torch.Tensor, var_names_g: np.ndarray, eps: float = 1e-6) -> None:
super().__init__()
self.scale_g: torch.Tensor
self.register_buffer("scale_g", scale_g)
self.var_names_g = np.array(var_names_g)
self.var_names_g = var_names_g
assert_nonnegative("eps", eps)
self.eps = eps

Expand Down
5 changes: 2 additions & 3 deletions cellarium/ml/transforms/z_score.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
# Copyright Contributors to the Cellarium project.
# SPDX-License-Identifier: BSD-3-Clause

from collections.abc import Sequence

import numpy as np
import torch
Expand Down Expand Up @@ -37,15 +36,15 @@ def __init__(
self,
mean_g: torch.Tensor,
std_g: torch.Tensor,
var_names_g: Sequence[str],
var_names_g: np.ndarray,
eps: float = 1e-6,
) -> None:
super().__init__()
self.mean_g: torch.Tensor
self.std_g: torch.Tensor
self.register_buffer("mean_g", mean_g)
self.register_buffer("std_g", std_g)
self.var_names_g = np.array(var_names_g)
self.var_names_g = var_names_g
assert_nonnegative("eps", eps)
self.eps = eps

Expand Down
2 changes: 1 addition & 1 deletion examples/cli_workflow/ipca_train_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ model:
!CheckpointLoader
file_path: /tmp/test_examples/onepass/lightning_logs/version_0/checkpoints/epoch=0-step=2.ckpt
attr: model.var_names_g
convert_fn: numpy.ndarray.tolist
convert_fn: null
model:
class_path: cellarium.ml.models.IncrementalPCA
init_args:
Expand Down
2 changes: 1 addition & 1 deletion tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -450,7 +450,7 @@ def test_checkpoint_loader(tmp_path: Path) -> None:
!CheckpointLoader
file_path: {ckpt_path}
attr: model.var_names_g
convert_fn: numpy.ndarray.tolist
convert_fn: null
model: cellarium.ml.models.LogisticRegression
optim_fn: torch.optim.Adam
data:
Expand Down
6 changes: 3 additions & 3 deletions tests/test_geneformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,13 @@

def test_load_from_checkpoint_multi_device(tmp_path: Path):
n, g = 4, 3
var_names_g = [f"gene_{i}" for i in range(g)]
var_names_g = np.array([f"gene_{i}" for i in range(g)])
devices = int(os.environ.get("TEST_DEVICES", "1"))
# dataloader
train_loader = torch.utils.data.DataLoader(
BoringDataset(
np.arange(n * g).reshape(n, g),
var_names=np.array(var_names_g),
var_names=var_names_g,
),
collate_fn=collate_fn,
)
Expand Down Expand Up @@ -63,7 +63,7 @@ def test_load_from_checkpoint_multi_device(tmp_path: Path):

@pytest.mark.parametrize("perturb", ["activation", "deletion", "map", "none"])
def test_tokenize_with_perturbations(perturb: str):
var_names_g = ["a", "b", "c", "d"]
var_names_g = np.array(["a", "b", "c", "d"])
geneformer = Geneformer(var_names_g=var_names_g)
x_ng = torch.tensor([[4, 3, 2, 1]]) # sort order will be [a,b,c,d] and tokens will be [2,3,4,5]

Expand Down
9 changes: 5 additions & 4 deletions tests/test_ipca.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,22 +33,23 @@ def x_ng():
def test_incremental_pca_multi_device(x_ng: torch.Tensor, perform_mean_correction: bool, batch_size: int, k: int):
n, g = x_ng.shape
x_ng_centered = x_ng - x_ng.mean(dim=0)
var_names_g = np.array([f"gene_{i}" for i in range(g)])
devices = int(os.environ.get("TEST_DEVICES", "1"))
batch_size = batch_size // devices

# dataloader
train_loader = torch.utils.data.DataLoader(
BoringDataset(
(x_ng if perform_mean_correction else x_ng_centered).numpy(),
np.array([f"gene_{i}" for i in range(g)]),
var_names_g,
),
batch_size=batch_size,
shuffle=False,
collate_fn=collate_fn,
)
# model
ipca = IncrementalPCA(
var_names_g=[f"gene_{i}" for i in range(g)],
var_names_g=var_names_g,
n_components=k,
perform_mean_correction=perform_mean_correction,
)
Expand Down Expand Up @@ -88,13 +89,13 @@ def test_incremental_pca_multi_device(x_ng: torch.Tensor, perform_mean_correctio

def test_load_from_checkpoint_multi_device(tmp_path: Path):
n, g = 3, 2
var_names_g = [f"gene_{i}" for i in range(g)]
var_names_g = np.array([f"gene_{i}" for i in range(g)])
devices = int(os.environ.get("TEST_DEVICES", "1"))
# dataloader
train_loader = torch.utils.data.DataLoader(
BoringDataset(
np.random.randn(n, g),
np.array(var_names_g),
var_names_g,
),
collate_fn=collate_fn,
)
Expand Down
6 changes: 3 additions & 3 deletions tests/test_onepass_mean_var_std.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,13 +116,13 @@ def test_onepass_mean_var_std_multi_device(
@pytest.mark.parametrize("algorithm", ["naive", "shifted_data"])
def test_load_from_checkpoint_multi_device(tmp_path: Path, algorithm: Literal["naive", "shifted_data"]):
n, g = 3, 2
var_names_g = [f"gene_{i}" for i in range(g)]
var_names_g = np.array([f"gene_{i}" for i in range(g)])
devices = int(os.environ.get("TEST_DEVICES", "1"))
# dataloader
train_loader = torch.utils.data.DataLoader(
BoringDataset(
np.random.randn(n, g),
np.array(var_names_g),
var_names_g,
),
collate_fn=collate_fn,
)
Expand Down Expand Up @@ -167,7 +167,7 @@ def test_accuracy(mean: float, dtype: torch.dtype, algorithm: Literal["naive", "
std = 0.1
x = mean + std * torch.randn(n_trials, dtype=dtype)

onepass = OnePassMeanVarStd(var_names_g=["x"], algorithm=algorithm)
onepass = OnePassMeanVarStd(var_names_g=np.array(["x"]), algorithm=algorithm)
for chunk in x.split(1000):
onepass(x_ng=chunk[:, None], var_names_g=["x"])

Expand Down
10 changes: 5 additions & 5 deletions tests/test_ppca.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def test_probabilistic_pca_multi_device(
s = np.sqrt(0.5 * total_var / g)
ppca = ProbabilisticPCA(
n_obs=n,
var_names_g=[f"gene_{i}" for i in range(g)],
var_names_g=np.array([f"gene_{i}" for i in range(g)]),
n_components=k,
ppca_flavor=ppca_flavor,
mean_g=x_mean_g,
Expand Down Expand Up @@ -116,12 +116,12 @@ def test_probabilistic_pca_multi_device(
def test_variance_monitor(x_ng: np.ndarray):
n, g = x_ng.shape
k = 3
var_names_g = [f"gene_{i}" for i in range(g)]
var_names_g = np.array([f"gene_{i}" for i in range(g)])
# dataloader
train_loader = torch.utils.data.DataLoader(
BoringDataset(
x_ng,
np.array(var_names_g),
var_names_g,
),
batch_size=n // 2,
collate_fn=collate_fn,
Expand All @@ -144,13 +144,13 @@ def test_variance_monitor(x_ng: np.ndarray):

def test_load_from_checkpoint_multi_device(tmp_path: Path):
n, g = 3, 2
var_names_g = [f"gene_{i}" for i in range(g)]
var_names_g = np.array([f"gene_{i}" for i in range(g)])
devices = int(os.environ.get("TEST_DEVICES", "1"))
# dataloader
train_loader = torch.utils.data.DataLoader(
BoringDataset(
np.random.randn(n, g).astype(np.float32),
np.array(var_names_g),
var_names_g,
),
collate_fn=collate_fn,
)
Expand Down
4 changes: 2 additions & 2 deletions tests/test_tdigest.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,13 +119,13 @@ def test_load_from_checkpoint_multi_device():
# so we need to use a fixed shared directory
tmp_path = Path("/tmp/test_load_from_checkpoint_multi_device")
n, g = 4, 3
var_names_g = [f"gene_{i}" for i in range(g)]
var_names_g = np.array([f"gene_{i}" for i in range(g)])
devices = int(os.environ.get("TEST_DEVICES", "1"))
# dataloader
train_loader = torch.utils.data.DataLoader(
BoringDataset(
np.random.randn(n, g),
np.array(var_names_g),
var_names_g,
),
collate_fn=collate_fn,
)
Expand Down
2 changes: 1 addition & 1 deletion tests/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def log_normalize(x_ng: torch.Tensor):
y_ng = torch.log1p(target_count * x_ng / l_n1)
mean_g = y_ng.mean(dim=0)
std_g = y_ng.std(dim=0)
var_names_g = [f"gene_{i}" for i in range(g)]
var_names_g = np.array([f"gene_{i}" for i in range(g)])
transform = CellariumPipeline(
[
NormalizeTotal(target_count),
Expand Down

0 comments on commit e8163ee

Please sign in to comment.