Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

faster reading of h5ad file (~18X faster) #3365

Draft
wants to merge 7 commits into
base: main
Choose a base branch
from
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
243 changes: 237 additions & 6 deletions src/scanpy/readwrite.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from pathlib import Path, PurePath
from typing import TYPE_CHECKING

import anndata
import anndata.utils
import h5py
import numpy as np
Expand All @@ -32,6 +33,12 @@
read_mtx,
read_text,
)
import multiprocessing as mp
import threading
from dataclasses import dataclass

import numba
import scipy
from anndata import AnnData
from matplotlib.image import imread

Expand All @@ -45,6 +52,14 @@

from ._utils import Empty

# dtype of the CSR column-index array once it is viewed through numpy.
indices_type = np.int64
# ctypes typecode of the shared buffer backing that array. "q" (long long) is
# 64 bits on every platform; "l" (long) is only 32 bits on Windows/LLP64 and
# would not line up with the np.int64 view created over the same buffer.
indices_shm_type = "q"

# Per-worker semaphores; both are replaced with lists inside `fastload`.
semDataLoaded = None  # will be initialized later
semDataCopied = None  # will be initialized later

# Number of elements each worker stages per hand-off.
thread_workload = 4_000_000  # experimented value

# .gz and .bz2 suffixes are also allowed for text formats
text_exts = {
"csv",
Expand All @@ -66,6 +81,220 @@
"""Available file formats for reading data. """


def get_1d_index(row: int, col: int, num_cols: int) -> int:
    """Flatten a ``(row, col)`` position into a row-major offset.

    Parameters
    ----------
    row
        Row index in the 2D array.
    col
        Column index in the 2D array.
    num_cols
        Number of columns in the 2D array.

    Returns
    -------
    The corresponding index into the flattened (1D) array.
    """
    row_offset = num_cols * row
    return row_offset + col

Check warning on line 96 in src/scanpy/readwrite.py

View check run for this annotation

Codecov / codecov/patch

src/scanpy/readwrite.py#L96

Added line #L96 was not covered by tests


@dataclass
class LoadHelperData:
    """Bundle of arguments handed to one ``_load_helper`` worker process."""

    # Index of the worker; selects its share of the data and its staging slot.
    i: int
    # Total number of workers the data is divided between.
    k: int
    # Total number of stored values to read (last entry of ``indptr``).
    datalen: int
    # Shared staging buffer for ``X/data`` (viewed as float32 by the worker).
    dataArray: mp.Array
    # Shared staging buffer for ``X/indices``.
    indicesArray: mp.Array
    # Per-worker start offset of the chunk currently staged.
    startsArray: mp.Array
    # Per-worker end offset (non-inclusive) of the chunk currently staged.
    endsArray: mp.Array


def _load_helper(fname: str, helper_data: LoadHelperData):
    """Worker-process body: stream one share of ``X`` into the staging buffers.

    Worker ``i`` of ``k`` owns the element range
    ``[i * datalen // k, (i + 1) * datalen // k)`` of ``X/data`` and
    ``X/indices``. It reads that range in chunks of at most
    ``thread_workload`` elements into its own slot of the shared buffers,
    publishes the chunk bounds in ``startsArray[i]`` / ``endsArray[i]`` and
    then hands the slot over to the copying side through the module-level
    semaphores.

    Parameters
    ----------
    fname
        Path of the h5ad file; it is opened read-only inside the worker.
    helper_data
        Worker index, worker count, total length and the shared buffers.

    Notes
    -----
    ``semDataLoaded`` / ``semDataCopied`` are read from module globals, so the
    worker must see the lists that ``fastload`` created in the parent.
    NOTE(review): that presumably only holds when the process is started by
    forking -- confirm for other start methods.
    """
    i = helper_data.i
    k = helper_data.k
    datalen = helper_data.datalen

    # numpy views over the shared buffers (no copies are made here).
    dataA = np.frombuffer(helper_data.dataArray, dtype=np.float32)
    indicesA = np.frombuffer(helper_data.indicesArray, dtype=indices_type)
    startsA = np.frombuffer(helper_data.startsArray, dtype=np.int64)
    endsA = np.frombuffer(helper_data.endsArray, dtype=np.int64)

    share_start = i * datalen // k
    share_end = (i + 1) * datalen // k
    slot_start = i * thread_workload
    n_rounds = datalen // (k * thread_workload) + 1

    # The context manager closes the file even if a read fails.
    with h5py.File(fname, "r") as f:
        x_data = f["X"]["data"]
        x_indices = f["X"]["indices"]
        for j in range(n_rounds):
            # compute start, end
            s = share_start + j * thread_workload
            e = min(s + thread_workload, share_end)
            length = e - s
            startsA[i] = s
            endsA[i] = e

            # The last round can be empty; skip the reads but keep the
            # handshake, because the copying side runs the same number of
            # rounds.
            if length > 0:
                source = np.s_[s:e]
                dest = np.s_[slot_start : slot_start + length]
                x_data.read_direct(dataA, source, dest)
                x_indices.read_direct(indicesA, source, dest)

            # coordinate with copy threads
            semDataLoaded[i].release()  # done data load
            semDataCopied[i].acquire()  # wait until data copied

Check warning on line 143 in src/scanpy/readwrite.py

View check run for this annotation

Codecov / codecov/patch

src/scanpy/readwrite.py#L142-L143

Added lines #L142 - L143 were not covered by tests


def _waitload(i):
    """Block until worker ``i`` has released its "data loaded" semaphore."""
    loaded = semDataLoaded[i]
    loaded.acquire()

Check warning on line 147 in src/scanpy/readwrite.py

View check run for this annotation

Codecov / codecov/patch

src/scanpy/readwrite.py#L147

Added line #L147 was not covered by tests


def _signalcopy(i):
    """Release worker ``i``'s "data copied" semaphore so it can load again."""
    copied = semDataCopied[i]
    copied.release()

Check warning on line 151 in src/scanpy/readwrite.py

View check run for this annotation

Codecov / codecov/patch

src/scanpy/readwrite.py#L151

Added line #L151 was not covered by tests


@dataclass
class CopyData:
    """Arrays used by ``_fast_copy`` to move staged chunks to their final place."""

    # Destination for the values of ``X``.
    data: np.ndarray
    # Staging buffer the values are copied from.
    dataA: np.ndarray
    # Destination for the column indices of ``X``.
    indices: np.ndarray
    # Staging buffer the column indices are copied from.
    indicesA: np.ndarray
    # Per-worker start offset of the chunk currently staged.
    startsA: np.ndarray
    # Per-worker end offset (non-inclusive) of the chunk currently staged.
    endsA: np.ndarray


def _fast_copy(copy_data: CopyData, k: int, m: int):
    """Copy staged chunks into the final ``data`` / ``indices`` arrays.

    One thread is started per worker. Each thread repeats ``m`` times: wait
    until its worker has staged a chunk, copy that chunk from the staging
    buffers to ``[starts[i], ends[i])`` of the destination arrays, then signal
    the worker that the slot may be reused. The call returns once every
    thread has finished.

    Parameters
    ----------
    copy_data
        Destination arrays, staging buffers and the per-worker chunk bounds.
    k
        Number of workers, and therefore of copy threads.
    m
        Number of load/copy rounds each thread performs.
    """
    data = copy_data.data
    dataA = copy_data.dataA
    indices = copy_data.indices
    indicesA = copy_data.indicesA
    starts = copy_data.startsA
    ends = copy_data.endsA

    def thread_fun(i, m):
        # Worker ``i`` always stages into the same slot of the shared buffers.
        slot_start = i * thread_workload
        for _ in range(m):
            # The semaphore helpers are called directly: this function is not
            # jit-compiled, so wrapping them in ``numba.objmode()`` did nothing.
            _waitload(i)
            lo = starts[i]
            hi = ends[i]
            slot_end = slot_start + (hi - lo)
            data[lo:hi] = dataA[slot_start:slot_end]
            indices[lo:hi] = indicesA[slot_start:slot_end]
            _signalcopy(i)

    threads = [threading.Thread(target=thread_fun, args=(i, m)) for i in range(k)]
    for thread in threads:
        thread.start()
    for thread in threads:
        thread.join()

Check warning on line 191 in src/scanpy/readwrite.py

View check run for this annotation

Codecov / codecov/patch

src/scanpy/readwrite.py#L187-L191

Added lines #L187 - L191 were not covered by tests


def _read_str_frame(group, index_len=None):
    """Build a DataFrame from the string datasets of an ``obs``/``var`` group.

    ``_index`` (when present) becomes the index, limited to its first
    ``index_len`` entries if that is given; every other key becomes a column.
    """
    if "_index" in group:
        index_ds = group["_index"].asstr()
        index_values = index_ds[...] if index_len is None else index_ds[0:index_len]
        frame = pd.DataFrame(index=pd.Series(index_values))
    else:
        frame = pd.DataFrame()
    for name in group:
        if name == "_index":
            continue
        frame[name] = group[name].asstr()[...]
    return frame


def fastload(fname, backed):
    """Read an h5ad file, loading ``X`` with several processes in parallel.

    Worker processes read ``X/data`` and ``X/indices`` into shared staging
    buffers while threads in this process copy the staged chunks into the
    final arrays.

    Only ``X`` (read as CSR ``data``/``indices``/``indptr``), ``obs`` and
    ``var`` are loaded on this path, and every ``obs``/``var`` entry is read
    as a string dataset.
    NOTE(review): other contents of the file are not carried over, and
    non-string or grouped (e.g. categorical) columns are presumably not
    supported by ``asstr()`` -- confirm before relying on this path.

    Parameters
    ----------
    fname
        Path of the h5ad file.
    backed
        Forwarded to ``read_h5ad`` when the fallback is taken. The parallel
        path always returns an in-memory object and opens the file read-only.

    Returns
    -------
    An ``AnnData`` built from ``X``, ``obs`` and ``var``, or the result of
    ``read_h5ad`` when ``X`` holds fewer than ``thread_workload`` values or
    processes cannot be forked on this platform.

    Raises
    ------
    AssertionError
        If ``X``, ``var`` or ``obs`` is missing from the file.
    """
    global semDataLoaded
    global semDataCopied

    # Always open read-only: `backed` is an AnnData option, not an h5py mode,
    # and nothing is written here. The context manager closes the file on
    # every exit path.
    with h5py.File(fname, "r") as f:
        for required in ("X", "var", "obs"):
            if required not in f:
                # Raised explicitly so the check survives `python -O`.
                raise AssertionError(f"'{required}' is missing from f")

        obs_group = f["obs"]
        # The number of observations is taken from the first obs entry.
        rows = obs_group[list(obs_group.keys())[0]].size
        # load index pointers
        indptr = f["X"]["indptr"][0 : rows + 1]
        datalen = int(indptr[-1])

        # The workers find the semaphores through module globals, which only
        # reach them when the process is forked.
        use_fast_path = (
            datalen >= thread_workload and "fork" in mp.get_all_start_methods()
        )
        if use_fast_path:
            dfobs = _read_str_frame(obs_group, rows)
            dfvar = _read_str_frame(f["var"])

    if not use_fast_path:
        return read_h5ad(fname, backed=backed)

    # Pin the start method instead of depending on the platform default.
    ctx = mp.get_context("fork")
    k = numba.get_num_threads()
    slots = k * thread_workload

    # Shared staging buffers: one slot of `thread_workload` elements per worker.
    dataArray = ctx.Array("f", slots, lock=False)
    indicesArray = ctx.Array(indices_shm_type, slots, lock=False)
    # "q" is 64 bits on every platform, matching the np.int64 views below.
    startsArray = ctx.Array("q", k, lock=False)  # start index of data read
    endsArray = ctx.Array("q", k, lock=False)  # end index (noninclusive)

    semDataLoaded = [ctx.Semaphore(0) for _ in range(k)]
    semDataCopied = [ctx.Semaphore(0) for _ in range(k)]

    dataA = np.frombuffer(dataArray, dtype=np.float32)
    indicesA = np.frombuffer(indicesArray, dtype=indices_type)
    startsA = np.frombuffer(startsArray, dtype=np.int64)
    endsA = np.frombuffer(endsArray, dtype=np.int64)
    data = np.empty(datalen, dtype=np.float32)
    indices = np.empty(datalen, dtype=indices_type)

    procs = [
        ctx.Process(
            target=_load_helper,
            args=(
                fname,
                LoadHelperData(
                    i=i,
                    k=k,
                    datalen=datalen,
                    dataArray=dataArray,
                    indicesArray=indicesArray,
                    startsArray=startsArray,
                    endsArray=endsArray,
                ),
            ),
        )
        for i in range(k)
    ]
    for p in procs:
        p.start()

    copy_data = CopyData(
        data=data,
        dataA=dataA,
        indices=indices,
        indicesA=indicesA,
        startsA=startsA,
        endsA=endsA,
    )
    # Same round count as the loop in `_load_helper`.
    _fast_copy(copy_data, k, datalen // (k * thread_workload) + 1)

    for p in procs:
        p.join()

    # The arrays are attached to an empty matrix rather than passed to the
    # constructor, so they are used exactly as read.
    X = scipy.sparse.csr_matrix((0, 0))
    X.data = data
    X.indices = indices
    X.indptr = indptr
    X._shape = (rows, dfvar.shape[0])

    # create AnnData
    adata = anndata.AnnData(X, dfobs, dfvar)
    return adata

Check warning on line 295 in src/scanpy/readwrite.py

View check run for this annotation

Codecov / codecov/patch

src/scanpy/readwrite.py#L294-L295

Added lines #L294 - L295 were not covered by tests


# --------------------------------------------------------------------------------
# Reading and Writing data files and AnnData objects
# --------------------------------------------------------------------------------
Expand All @@ -82,7 +311,7 @@
)
def read(
filename: Path | str,
backed: Literal["r", "r+"] | None = None,
backed: Literal["r", "r+"] | None = "r+",
*,
sheet: str | None = None,
ext: str | None = None,
Expand Down Expand Up @@ -162,7 +391,7 @@
f"ending on one of the available extensions {avail_exts} "
"or pass the parameter `ext`."
)
return read_h5ad(filename, backed=backed)
return fastload(filename, backed)

Check warning on line 394 in src/scanpy/readwrite.py

View check run for this annotation

Codecov / codecov/patch

src/scanpy/readwrite.py#L394

Added line #L394 was not covered by tests


@old_positionals("genome", "gex_only", "backup_url")
Expand Down Expand Up @@ -337,7 +566,9 @@
(
feature_metadata_name,
dsets[feature_metadata_name].astype(
bool if feature_metadata_item.dtype.kind == "b" else str
bool
if feature_metadata_item.dtype.kind == "thread_workload"
else str
),
)
for feature_metadata_name, feature_metadata_item in f["matrix"][
Expand Down Expand Up @@ -778,7 +1009,7 @@
# read hdf5 files
if ext in {"h5", "h5ad"}:
if sheet is None:
return read_h5ad(filename, backed=backed)
return fastload(filename, backed)
else:
logg.debug(f"reading sheet {sheet} from file {filename}")
return read_hdf(filename, sheet)
Expand All @@ -790,7 +1021,7 @@
path_cache = path_cache.with_suffix("")
if cache and path_cache.is_file():
logg.info(f"... reading from cache file {path_cache}")
return read_h5ad(path_cache)
return fastload(path_cache, backed)

Check warning on line 1024 in src/scanpy/readwrite.py

View check run for this annotation

Codecov / codecov/patch

src/scanpy/readwrite.py#L1024

Added line #L1024 was not covered by tests

if not is_present:
raise FileNotFoundError(f"Did not find file {filename}.")
Expand Down Expand Up @@ -1028,7 +1259,7 @@
total = resp.info().get("content-length", None)
with (
tqdm(
unit="B",
unit="thread_workload",
unit_scale=True,
miniters=1,
unit_divisor=1024,
Expand Down
Loading