Skip to content

Commit

Permalink
pre-commit
Browse files Browse the repository at this point in the history
  • Loading branch information
tszfungc committed Sep 6, 2023
1 parent 3fdc211 commit cb2c0bd
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 10 deletions.
4 changes: 2 additions & 2 deletions src/hamsta/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,7 @@ def fit(
# apply filter
rotated_Z = rotated_Z[S_filter]
intercept_design = intercept_design[S_filter, :]
intercept_design = intercept_design[: , jnp.sum(intercept_design, axis=0)>0]
intercept_design = intercept_design[:, jnp.sum(intercept_design, axis=0) > 0]

# group intercept into multiple var components
# bin_idx = np.arange(S.shape[0]) // self.intercept_blksize
Expand Down Expand Up @@ -372,7 +372,7 @@ def _jackknife(
S=S[selected_index],
jackknife=False,
intercept_design=intercept_design[selected_index],
complete=False
complete=False,
)
pseudo_val = (
num_blocks * param_full
Expand Down
13 changes: 8 additions & 5 deletions src/hamsta/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,10 @@


def read_bed(fname: str):
bedf = open(fname, 'r')
bedf = open(fname, "r")
return [list(map(int, line.split("\t"))) for line in bedf.read().splitlines()]


def read_singular_val(svdprefix, svdprefix_chr, nS):
if svdprefix is not None:
S_ = np.load(f"{svdprefix}.SVD.S.npy")[:nS]
Expand Down Expand Up @@ -244,18 +245,20 @@ def read_zarr(
def read_nc(
fname: str,
ancestry: str,
exclude: str=None,
exclude: str = None,
) -> Tuple[jnp.ndarray, pd.DataFrame]:

ds = xr.open_dataset(fname).load()

#LA_matrix = jnp.array(ds.locanc.sel(ancestry=ancestry).sum(dim="ploidy"))
# LA_matrix = jnp.array(ds.locanc.sel(ancestry=ancestry).sum(dim="ploidy"))
ds_LA = ds.locanc.sel(ancestry=ancestry).sum(dim="ploidy")
if exclude is not None:
exclude_region = read_bed(exclude)
extract = np.logical_and.reduce(
[~np.logical_and(start <= ds_LA['marker'], ds_LA['marker'] < end)
for _, start, end in exclude_region]
[
~np.logical_and(start <= ds_LA["marker"], ds_LA["marker"] < end)
for _, start, end in exclude_region
]
)
ds_LA = ds_LA[extract, :]

Expand Down
5 changes: 2 additions & 3 deletions src/hamsta/preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,10 @@ def SVD(
number of components computed in truncated
Returns:
``(U, S)`` in SVD of ``X = U * S @ Vh``, where X is A/sqrt(N) with A standardized
``(U, S)`` in SVD of ``X = U * S @ Vh``, where X is A/sqrt(N) with A standardized # noqa: E501
"""

if LAD is not None:
A_std = LAD
else:
Expand All @@ -59,7 +59,6 @@ def SVD(
else:
U, S, _ = randomized_svd(A_std, n_components=k, random_state=None)


if LAD is not None:
S = jnp.sqrt(S)
else:
Expand Down

0 comments on commit cb2c0bd

Please sign in to comment.