Skip to content
247 changes: 241 additions & 6 deletions src/scanpy/readwrite.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from pathlib import Path, PurePath
from typing import TYPE_CHECKING

import anndata
import anndata.utils
import h5py
import numpy as np
Expand All @@ -32,6 +33,12 @@
read_mtx,
read_text,
)
import multiprocessing as mp
import threading
from dataclasses import dataclass

import numba
import scipy
from anndata import AnnData
from matplotlib.image import imread

Expand All @@ -46,6 +53,14 @@

from ._utils import Empty

indices_type = np.int64
indices_shm_type = "l"

semDataLoaded = None # will be initialized later
semDataCopied = None # will be initialized later

thread_workload = 4000000 # experimented value

# .gz and .bz2 suffixes are also allowed for text formats
text_exts = {
"csv",
Expand All @@ -67,6 +82,224 @@
"""Available file formats for reading data. """


def get_1d_index(row: int, col: int, num_cols: int) -> int:
"""
Convert 2D coordinates to 1D index.

Parameters:
row (int): Row index in the 2D array.
col (int): Column index in the 2D array.
num_cols (int): Number of columns in the 2D array.

Returns:
int: Corresponding 1D index.
"""
return row * num_cols + col

Check warning on line 97 in src/scanpy/readwrite.py

View check run for this annotation

Codecov / codecov/patch

src/scanpy/readwrite.py#L97

Added line #L97 was not covered by tests


@dataclass
class LoadHelperData:
i: int
k: int
datalen: int
dataArray: mp.Array
indicesArray: mp.Array
startsArray: mp.Array
endsArray: mp.Array


def _load_helper(fname: str, helper_data: LoadHelperData):
i = helper_data.i
k = helper_data.k
datalen = helper_data.datalen
dataArray = helper_data.dataArray
indicesArray = helper_data.indicesArray
startsArray = helper_data.startsArray
endsArray = helper_data.endsArray

Check warning on line 118 in src/scanpy/readwrite.py

View check run for this annotation

Codecov / codecov/patch

src/scanpy/readwrite.py#L112-L118

Added lines #L112 - L118 were not covered by tests

f = h5py.File(fname, "r")
dataA = np.frombuffer(dataArray, dtype=np.float32)
indicesA = np.frombuffer(indicesArray, dtype=indices_type)
startsA = np.frombuffer(startsArray, dtype=np.int64)
endsA = np.frombuffer(endsArray, dtype=np.int64)
for j in range(datalen // (k * thread_workload) + 1):

Check warning on line 125 in src/scanpy/readwrite.py

View check run for this annotation

Codecov / codecov/patch

src/scanpy/readwrite.py#L120-L125

Added lines #L120 - L125 were not covered by tests
# compute start, end
s = i * datalen // k + j * thread_workload
e = min(s + thread_workload, (i + 1) * datalen // k)
length = e - s
startsA[i] = s
endsA[i] = e

Check warning on line 131 in src/scanpy/readwrite.py

View check run for this annotation

Codecov / codecov/patch

src/scanpy/readwrite.py#L127-L131

Added lines #L127 - L131 were not covered by tests
# read direct
f["X"]["data"].read_direct(

Check warning on line 133 in src/scanpy/readwrite.py

View check run for this annotation

Codecov / codecov/patch

src/scanpy/readwrite.py#L133

Added line #L133 was not covered by tests
dataA, np.s_[s:e], np.s_[i * thread_workload : i * thread_workload + length]
)
f["X"]["indices"].read_direct(

Check warning on line 136 in src/scanpy/readwrite.py

View check run for this annotation

Codecov / codecov/patch

src/scanpy/readwrite.py#L136

Added line #L136 was not covered by tests
indicesA,
np.s_[s:e],
np.s_[i * thread_workload : i * thread_workload + length],
)

# coordinate with copy threads
semDataLoaded[i].release() # done data load
semDataCopied[i].acquire() # wait until data copied

Check warning on line 144 in src/scanpy/readwrite.py

View check run for this annotation

Codecov / codecov/patch

src/scanpy/readwrite.py#L143-L144

Added lines #L143 - L144 were not covered by tests


def _waitload(i):
semDataLoaded[i].acquire()

Check warning on line 148 in src/scanpy/readwrite.py

View check run for this annotation

Codecov / codecov/patch

src/scanpy/readwrite.py#L148

Added line #L148 was not covered by tests


def _signalcopy(i):
semDataCopied[i].release()

Check warning on line 152 in src/scanpy/readwrite.py

View check run for this annotation

Codecov / codecov/patch

src/scanpy/readwrite.py#L152

Added line #L152 was not covered by tests


@dataclass
class CopyData:
data: np.ndarray
dataA: np.ndarray
indices: np.ndarray
indicesA: np.ndarray
startsA: np.ndarray
endsA: np.ndarray


def _fast_copy(copy_data: CopyData, k: int, m: int):
# Access the arrays through copy_data
data = copy_data.data
dataA = copy_data.dataA
indices = copy_data.indices
indicesA = copy_data.indicesA
starts = copy_data.startsA
ends = copy_data.endsA

Check warning on line 172 in src/scanpy/readwrite.py

View check run for this annotation

Codecov / codecov/patch

src/scanpy/readwrite.py#L167-L172

Added lines #L167 - L172 were not covered by tests

def thread_fun(i, m):
for j in range(m):
with numba.objmode():
_waitload(i)
length = ends[i] - starts[i]
data[starts[i] : ends[i]] = dataA[

Check warning on line 179 in src/scanpy/readwrite.py

View check run for this annotation

Codecov / codecov/patch

src/scanpy/readwrite.py#L174-L179

Added lines #L174 - L179 were not covered by tests
i * thread_workload : i * thread_workload + length
]
indices[starts[i] : ends[i]] = indicesA[

Check warning on line 182 in src/scanpy/readwrite.py

View check run for this annotation

Codecov / codecov/patch

src/scanpy/readwrite.py#L182

Added line #L182 was not covered by tests
i * thread_workload : i * thread_workload + length
]
with numba.objmode():
_signalcopy(i)

Check warning on line 186 in src/scanpy/readwrite.py

View check run for this annotation

Codecov / codecov/patch

src/scanpy/readwrite.py#L185-L186

Added lines #L185 - L186 were not covered by tests

threads = [threading.Thread(target=thread_fun, args=(i, m)) for i in range(k)]
for thread in threads:
thread.start()
for thread in threads:
thread.join()

Check warning on line 192 in src/scanpy/readwrite.py

View check run for this annotation

Codecov / codecov/patch

src/scanpy/readwrite.py#L188-L192

Added lines #L188 - L192 were not covered by tests


def fastload(fname, backed):
f = h5py.File(fname, backed)
assert "X" in f, "'X' is missing from f"
assert "var" in f, "'var' is missing from f"
assert "obs" in f, "'obs' is missing from f"

if not scipy.sparse.issparse(f["X"]):
f.close()
return read_h5ad(fname, backed=backed)

# get obs dataframe
rows = f["obs"][list(f["obs"].keys())[0]].size

Check warning on line 206 in src/scanpy/readwrite.py

View check run for this annotation

Codecov / codecov/patch

src/scanpy/readwrite.py#L206

Added line #L206 was not covered by tests
# load index pointers, prepare shared arrays
indptr = f["X"]["indptr"][0 : rows + 1]
datalen = int(indptr[-1])

Check warning on line 209 in src/scanpy/readwrite.py

View check run for this annotation

Codecov / codecov/patch

src/scanpy/readwrite.py#L208-L209

Added lines #L208 - L209 were not covered by tests

if datalen < thread_workload:
f.close()
return read_h5ad(fname, backed=backed)
if "_index" in f["obs"]:
dfobsind = pd.Series(f["obs"]["_index"].asstr()[0:rows])
dfobs = pd.DataFrame(index=dfobsind)

Check warning on line 216 in src/scanpy/readwrite.py

View check run for this annotation

Codecov / codecov/patch

src/scanpy/readwrite.py#L211-L216

Added lines #L211 - L216 were not covered by tests
else:
dfobs = pd.DataFrame()
for k in f["obs"]:
if k == "_index":
continue
dfobs[k] = f["obs"][k].asstr()[...]

Check warning on line 222 in src/scanpy/readwrite.py

View check run for this annotation

Codecov / codecov/patch

src/scanpy/readwrite.py#L218-L222

Added lines #L218 - L222 were not covered by tests

# get var dataframe
if "_index" in f["var"]:
dfvarind = pd.Series(f["var"]["_index"].asstr()[...])
dfvar = pd.DataFrame(index=dfvarind)

Check warning on line 227 in src/scanpy/readwrite.py

View check run for this annotation

Codecov / codecov/patch

src/scanpy/readwrite.py#L225-L227

Added lines #L225 - L227 were not covered by tests
else:
dfvar = pd.DataFrame()
for k in f["var"]:
if k == "_index":
continue
dfvar[k] = f["var"][k].asstr()[...]

Check warning on line 233 in src/scanpy/readwrite.py

View check run for this annotation

Codecov / codecov/patch

src/scanpy/readwrite.py#L229-L233

Added lines #L229 - L233 were not covered by tests

f.close()
k = numba.get_num_threads()
dataArray = mp.Array(

Check warning on line 237 in src/scanpy/readwrite.py

View check run for this annotation

Codecov / codecov/patch

src/scanpy/readwrite.py#L235-L237

Added lines #L235 - L237 were not covered by tests
"f", k * thread_workload, lock=False
) # should be in shared memory
indicesArray = mp.Array(

Check warning on line 240 in src/scanpy/readwrite.py

View check run for this annotation

Codecov / codecov/patch

src/scanpy/readwrite.py#L240

Added line #L240 was not covered by tests
indices_shm_type, k * thread_workload, lock=False
) # should be in shared memory
startsArray = mp.Array("l", k, lock=False) # start index of data read
endsArray = mp.Array("l", k, lock=False) # end index (noninclusive) of data read

Check warning on line 244 in src/scanpy/readwrite.py

View check run for this annotation

Codecov / codecov/patch

src/scanpy/readwrite.py#L243-L244

Added lines #L243 - L244 were not covered by tests
global semDataLoaded
global semDataCopied
semDataLoaded = [mp.Semaphore(0) for _ in range(k)]
semDataCopied = [mp.Semaphore(0) for _ in range(k)]
dataA = np.frombuffer(dataArray, dtype=np.float32)
indicesA = np.frombuffer(indicesArray, dtype=indices_type)
startsA = np.frombuffer(startsArray, dtype=np.int64)
endsA = np.frombuffer(endsArray, dtype=np.int64)
data = np.empty(datalen, dtype=np.float32)
indices = np.empty(datalen, dtype=indices_type)

Check warning on line 254 in src/scanpy/readwrite.py

View check run for this annotation

Codecov / codecov/patch

src/scanpy/readwrite.py#L247-L254

Added lines #L247 - L254 were not covered by tests

procs = [

Check warning on line 256 in src/scanpy/readwrite.py

View check run for this annotation

Codecov / codecov/patch

src/scanpy/readwrite.py#L256

Added line #L256 was not covered by tests
mp.Process(
target=_load_helper,
args=(
fname,
LoadHelperData(
i=i,
k=k,
datalen=datalen,
dataArray=dataArray,
indicesArray=indicesArray,
startsArray=startsArray,
endsArray=endsArray,
),
),
)
for i in range(k)
]

for p in procs:
p.start()

Check warning on line 276 in src/scanpy/readwrite.py

View check run for this annotation

Codecov / codecov/patch

src/scanpy/readwrite.py#L275-L276

Added lines #L275 - L276 were not covered by tests

copy_data = CopyData(

Check warning on line 278 in src/scanpy/readwrite.py

View check run for this annotation

Codecov / codecov/patch

src/scanpy/readwrite.py#L278

Added line #L278 was not covered by tests
data=data,
dataA=dataA,
indices=indices,
indicesA=indicesA,
startsA=startsA,
endsA=endsA,
)

_fast_copy(copy_data, k, datalen // (k * thread_workload) + 1)

Check warning on line 287 in src/scanpy/readwrite.py

View check run for this annotation

Codecov / codecov/patch

src/scanpy/readwrite.py#L287

Added line #L287 was not covered by tests

for p in procs:
p.join()

Check warning on line 290 in src/scanpy/readwrite.py

View check run for this annotation

Codecov / codecov/patch

src/scanpy/readwrite.py#L289-L290

Added lines #L289 - L290 were not covered by tests

X = scipy.sparse.csr_matrix((0, 0))
X.data = data
X.indices = indices
X.indptr = indptr
X._shape = (rows, dfvar.shape[0])

Check warning on line 296 in src/scanpy/readwrite.py

View check run for this annotation

Codecov / codecov/patch

src/scanpy/readwrite.py#L292-L296

Added lines #L292 - L296 were not covered by tests

# create AnnData
adata = anndata.AnnData(X, dfobs, dfvar)
return adata

Check warning on line 300 in src/scanpy/readwrite.py

View check run for this annotation

Codecov / codecov/patch

src/scanpy/readwrite.py#L299-L300

Added lines #L299 - L300 were not covered by tests


# --------------------------------------------------------------------------------
# Reading and Writing data files and AnnData objects
# --------------------------------------------------------------------------------
Expand All @@ -83,7 +316,7 @@
)
def read(
filename: Path | str,
backed: Literal["r", "r+"] | None = None,
backed: Literal["r", "r+"] | None = "r+",
*,
sheet: str | None = None,
ext: str | None = None,
Expand Down Expand Up @@ -164,7 +397,7 @@
"or pass the parameter `ext`."
)
raise ValueError(msg)
return read_h5ad(filename, backed=backed)
return fastload(filename, backed=backed)

Check warning on line 400 in src/scanpy/readwrite.py

View check run for this annotation

Codecov / codecov/patch

src/scanpy/readwrite.py#L400

Added line #L400 was not covered by tests


@old_positionals("genome", "gex_only", "backup_url")
Expand Down Expand Up @@ -346,7 +579,9 @@
(
feature_metadata_name,
dsets[feature_metadata_name].astype(
bool if feature_metadata_item.dtype.kind == "b" else str
bool
if feature_metadata_item.dtype.kind == "thread_workload"
else str
),
)
for feature_metadata_name, feature_metadata_item in f["matrix"][
Expand Down Expand Up @@ -791,7 +1026,7 @@
# read hdf5 files
if ext in {"h5", "h5ad"}:
if sheet is None:
return read_h5ad(filename, backed=backed)
return fastload(filename, backed)
else:
logg.debug(f"reading sheet {sheet} from file {filename}")
return read_hdf(filename, sheet)
Expand All @@ -803,7 +1038,7 @@
path_cache = path_cache.with_suffix("")
if cache and path_cache.is_file():
logg.info(f"... reading from cache file {path_cache}")
return read_h5ad(path_cache)
return fastload(path_cache, backed)

Check warning on line 1041 in src/scanpy/readwrite.py

View check run for this annotation

Codecov / codecov/patch

src/scanpy/readwrite.py#L1041

Added line #L1041 was not covered by tests

if not is_present:
msg = f"Did not find file {filename}."
Expand Down Expand Up @@ -1044,7 +1279,7 @@
total = resp.info().get("content-length", None)
with (
tqdm(
unit="B",
unit="thread_workload",
unit_scale=True,
miniters=1,
unit_divisor=1024,
Expand Down
Loading