Skip to content

Commit

Permalink
Merge pull request #10 from mvinyard/bug-fixes
Browse files Browse the repository at this point in the history
Bug fixes
  • Loading branch information
mvinyard authored May 5, 2023
2 parents 471defb + 2e15c92 commit 5dc9243
Show file tree
Hide file tree
Showing 7 changed files with 43 additions and 12 deletions.
12 changes: 6 additions & 6 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@


# -- specify package version: --------------------------------------------------
__version__ = "0.0.20"
__version__ = "0.0.21"


# -- import packages: ----------------------------------------------------------
Expand All @@ -19,8 +19,8 @@
# -- run setup: ----------------------------------------------------------------
setuptools.setup(
name="torch-adata",
version="0.0.20",
python_requires=">3.7.0",
version="0.0.21",
python_requires=">3.9.0",
author="Michael E. Vinyard",
author_email="mvinyard@g.harvard.edu",
url=None,
Expand All @@ -29,10 +29,10 @@
description="torch-adata",
packages=setuptools.find_packages(),
install_requires=[
"anndata>=0.8",
"anndata>=0.9.1",
"licorice_font>=0.0.3",
"pytorch-lightning>=1.7.7",
"torch>=1.12",
"lightning>=2.0.1",
"torch>=2.0",
],
classifiers=[
"Development Status :: 2 - Pre-Alpha",
Expand Down
2 changes: 1 addition & 1 deletion torch_adata/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@


# -- specify package version: --------------------------------------------------
__version__ = "0.0.20"
__version__ = "0.0.21"


# -- import modules: -----------------------------------------------------------
Expand Down
9 changes: 8 additions & 1 deletion torch_adata/_core/_core_ancilliary/_identity_msg.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,14 @@ def _groupby_msg(dataset):
if not dataset._grouped_by:
return
msg_g1 = lf.font_format("Grouped by", ["BOLD"])
msg_g2 = lf.font_format(dataset._grouped_by, ["RED"])
if isinstance(dataset._grouped_by, list):
sub_g1 = []
for gb in dataset._grouped_by:
sub_g1.append(lf.font_format(gb, ["RED"]))
msg_g2 = "-".join(sub_g1)
else:
msg_g2 = lf.font_format(dataset._grouped_by, ["RED"])

return "{}: '{}' with attributes:".format(msg_g1, msg_g2)

def annotate_attr_size(dataset, attr_name_set):
Expand Down
5 changes: 3 additions & 2 deletions torch_adata/_core/_lightning/_lightning_anndata_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@


# -- import packages: --------------------------------------------------------------------
from pytorch_lightning import LightningDataModule
from lightning import LightningDataModule
from licorice_font import font_format
import pandas as pd
import anndata
Expand All @@ -35,7 +35,8 @@ def __init__(
train_val_split=[0.8, 0.2],
n_predict=2000,
use_key="X_pca",
groupby="Time point", # TODO: make optional
obs_keys=[],
groupby=None,
train_key="train",
val_key="val",
test_key="test",
Expand Down
3 changes: 2 additions & 1 deletion torch_adata/_tools/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,5 @@
from ._idx import idx
from ._dummy_batch import dummy_batch
from ._base_lightning_data_module import BaseLightningDataModule
from ._anndataset_split import AnnDatasetSplit
from ._anndataset_split import AnnDatasetSplit
from ._fetch_data import fetch
2 changes: 1 addition & 1 deletion torch_adata/_tools/_base_lightning_data_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@


# -- import packages: ----------------------------------------------------------
from pytorch_lightning import LightningDataModule
from lightning import LightningDataModule
from torch.utils.data import DataLoader
import anndata

Expand Down
22 changes: 22 additions & 0 deletions torch_adata/_tools/_fetch_data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@

__module_name__ = "_fetch_data.py"
__doc__ = """Data fetching module. Key module called within torch_adata.AnnDataset.__init__()."""
__author__ = ", ".join(["Michael E. Vinyard"])
__email__ = ", ".join(["vinyard@g.harvard.edu"])


# -- import packages: --------------------------------------------------------------------
import torch
import anndata


# -- import local dependencies: ----------------------------------------------------------
from .._core._core_ancilliary._fetch_data import Fetch


# -- fetch X from adata: -----------------------------------------------------------------
def fetch(adata: anndata.AnnData, use_key: str)->torch.Tensor:

f = Fetch(adata)
return f.X(use_key)

0 comments on commit 5dc9243

Please sign in to comment.