
Commit

fix import
gmauro committed Nov 19, 2024
1 parent 9745cb3 commit e140ea8
Showing 3 changed files with 76 additions and 72 deletions.
18 changes: 9 additions & 9 deletions .pre-commit-config.yaml
@@ -5,12 +5,12 @@ repos:
- id: check-yaml
- id: end-of-file-fixer
- id: trailing-whitespace
- repo: https://github.com/astral-sh/ruff-pre-commit
# Ruff version.
rev: v0.6.3
hooks:
# Run the linter.
- id: ruff
args: [ --fix ]
# Run the formatter.
- id: ruff-format
# - repo: https://github.com/astral-sh/ruff-pre-commit
# # Ruff version.
# rev: v0.6.3
# hooks:
# # Run the linter.
# - id: ruff
# args: [ --fix ]
# # Run the formatter.
# - id: ruff-format
60 changes: 29 additions & 31 deletions gwasstudio/cli/export.py
@@ -1,17 +1,17 @@
import click
import cloup
from gwasstudio import logger
from methods.locus_breaker import locus_breaker
from utils import process_write_chunk
import tiledb
from scipy import stats
import pandas as pd
import polars as pl
import tiledb

from gwasstudio import logger
from gwasstudio.methods.locus_breaker import locus_breaker
from gwasstudio.utils import process_write_chunk

help_doc = """
Exports data from a TileDB dataset.
"""


@cloup.command("export", no_args_is_help=True, help=help_doc)
@cloup.option_group(
"TileDB mandatory options",
@@ -24,25 +24,18 @@
cloup.option("--locusbreaker", default=False, is_flag=True, help="Option to run locusbreaker"),
cloup.option("--pvalue-sig", default=False, help="P-value threshold to use for filtering the data"),
cloup.option("--pvalue-limit", default=5.0, help="P-value threshold for loci borders"),
cloup.option("--hole-size", default=250000, help="Minimum pair-base distance between SNPs in different loci (default: 250000)")
cloup.option(
"--hole-size",
default=250000,
help="Minimum pair-base distance between SNPs in different loci (default: 250000)",
),
)
@cloup.option_group(
"Options for filtering using a genomic regions or a list of SNPs ids",
cloup.option("--snp-list", default="None", help="A txt file with a column containing the SNP ids")
cloup.option("--snp-list", default="None", help="A txt file with a column containing the SNP ids"),
)
@click.pass_context
def export(
ctx,
uri,
trait_id_file,
output_path,
pvalue_sig,
pvalue_limit,
hole_size,
snp_list,
locusbreaker

):
def export(ctx, uri, trait_id_file, output_path, pvalue_sig, pvalue_limit, hole_size, snp_list, locusbreaker):
cfg = ctx.obj["cfg"]
tiledb_unified = tiledb.open(uri, mode="r")
logger.info("TileDB dataset loaded")
@@ -51,31 +44,36 @@ def export(
if locusbreaker:
print("running locus breaker")
for trait in trait_id_list:
subset_SNPs_pd = tiledb_unified.query(dims=['CHR', 'POS', 'TRAITID'], attrs=['SNPID', 'ALLELE0', 'ALLELE1', 'BETA', 'SE', 'EAF', "MLOG10P"]).df[:, :, trait_id_list]
subset_SNPs_pd = tiledb_unified.query(
dims=["CHR", "POS", "TRAITID"], attrs=["SNPID", "ALLELE0", "ALLELE1", "BETA", "SE", "EAF", "MLOG10P"]
).df[:, :, trait_id_list]
results_lb = locus_breaker(subset_SNPs_pd)
logger.info(f"Saving locus-breaker output in {output_path}")
results_lb.to_csv(f,"{output_path}_{trait}.csv", index = False)
results_lb.to_csv(f"{output_path}_{trait}.csv", index=False)
return

# If snp_list is selected, run extract_snp
if snp_list != "None":
SNP_list = pd.read_csv(snp_list, dtype = {"CHR":str, "POS":int, "ALLELE0":str, "ALLELE1":str})
chromosome_dict = SNP_list.groupby('CHR')['POS'].apply(list).to_dict()
SNP_list = pd.read_csv(snp_list, dtype={"CHR": str, "POS": int, "ALLELE0": str, "ALLELE1": str})
chromosome_dict = SNP_list.groupby("CHR")["POS"].apply(list).to_dict()
unique_positions = list(set(pos for positions in chromosome_dict.values() for pos in positions))
with tiledb.open("/scratch/bruno.ariano/tiledb_decode_unified", "r") as A:
tiledb_iterator = A.query(
return_incomplete=True
).df[:, unique_positions, trait_id_list] # Replace with appropriate filters if necessary
tiledb_iterator = A.query(return_incomplete=True).df[
:, unique_positions, trait_id_list
] # Replace with appropriate filters if necessary
with open(output_path, mode="a") as f:
for chunk in tiledb_iterator:
# Convert the chunk to Polars format
process_write_chunk(chunk, SNP_list, f)

logger.info(f"Saved filtered summary statistics by SNPs in {output_path}")
exit()

if pvalue_sig:
subset_SNPs = tiledb_unified.query(cond=f"MLOGP10 > {pvalue_sig}", dims=['CHR','POS','TRAITID'], attrs=['SNPID','ALLELE0','ALLELE1','BETA', 'SE', 'EAF',"MLOG10P"]).df[:, trait_id_list, :]
subset_SNPs.to_csv(output_path, index = False)
subset_SNPs = tiledb_unified.query(
cond=f"MLOG10P > {pvalue_sig}",
dims=["CHR", "POS", "TRAITID"],
attrs=["SNPID", "ALLELE0", "ALLELE1", "BETA", "SE", "EAF", "MLOG10P"],
).df[:, trait_id_list, :]
subset_SNPs.to_csv(output_path, index=False)
logger.info(f"Saving filtered GWAS by regions and samples in {output_path}")
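
The export command above pulls slices out of the TileDB array using the dims/attrs query pattern shown in the diff. As a quick reference, here is a minimal, self-contained sketch of that pattern; the array URI and trait IDs are placeholders, while the dimension and attribute names are the ones used in export.py above:

import tiledb

# Placeholders -- point these at a real TileDB array and real trait IDs.
uri = "path/to/tiledb_array"
trait_id_list = ["trait_1", "trait_2"]

with tiledb.open(uri, mode="r") as arr:
    # Slice on (CHR, POS, TRAITID) and return the listed attributes as a pandas DataFrame.
    df = arr.query(
        dims=["CHR", "POS", "TRAITID"],
        attrs=["SNPID", "ALLELE0", "ALLELE1", "BETA", "SE", "EAF", "MLOG10P"],
    ).df[:, :, trait_id_list]
    df.to_csv("export_preview.csv", index=False)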

70 changes: 38 additions & 32 deletions gwasstudio/cli/ingest.py
@@ -1,8 +1,12 @@
import os

import click
import cloup
import pathlib
from utils import process_and_ingest
from utils import create_tiledb_schema
import pandas as pd
import tiledb

from gwasstudio.utils import create_tiledb_schema
from gwasstudio.utils import process_and_ingest

help_doc = """
Ingest data into a TileDB-unified dataset.
@@ -28,41 +32,43 @@
"--restart",
default=False,
help="Restart the ingestion from a previous run.",
)
),
)
@cloup.option_group(
"TileDB configurations",
cloup.option("--batch-size", default=5, help="The number of files to ingest in parallel."),
)
@click.pass_context
def ingest(ctx, input_path,checksum_path, attrs, uri, mem_budget_mb, threads, batch_size, restart):
# Parse checksum for mapping ids to files
checksum = pd.read_csv(file_path + "checksum.txt", sep = "\t", header=None)
checksum.columns = ["hash","filename"]
checksum_dict = pd.Series(checksum.hash.values,index=checksum.filename).to_dict()
def ingest(ctx, input_path, checksum_path, attrs, uri, mem_budget_mb, threads, batch_size, restart):
# Parse checksum for mapping ids to files
checksum = pd.read_csv(file_path + "checksum.txt", sep="\t", header=None)
checksum.columns = ["hash", "filename"]
checksum_dict = pd.Series(checksum.hash.values, index=checksum.filename).to_dict()

# Getting the file list and iterate through it using Dask
cfg = ctx.obj["cfg"]
if restart:
test_tiledb = tiledb.open(uri, "r")
arrow_table = test_tiledb.query(return_arrow=True, dims=['TRAITID'], attrs=[]).df[1, 1:10000000, :]
unique_arrow = (np.unique(arrow_table))
checksum_dict = pd.Series(checksum.filename.values,index=checksum.hash).to_dict()
file_list = []
checksum_dict_keys = checksum_dict.keys()
for record in checksum_dict_keys:
if record not in unique_arrow:
file_list.append(checksum_dict[record])
# Getting the file list and iterate through it using Dask
cfg = ctx.obj["cfg"]
if restart:
test_tiledb = tiledb.open(uri, "r")
arrow_table = test_tiledb.query(return_arrow=True, dims=["TRAITID"], attrs=[]).df[1, 1:10000000, :]
unique_arrow = np.unique(arrow_table)
checksum_dict = pd.Series(checksum.filename.values, index=checksum.hash).to_dict()
file_list = []
checksum_dict_keys = checksum_dict.keys()
for record in checksum_dict_keys:
if record not in unique_arrow:
file_list.append(checksum_dict[record])

# Process files in batches
else:
create_tiledb_schema(uri, cfg)
file_list = os.listdir(file_path)

for i in range(0, len(file_list), batch_size):
batch_files = file_list[i:i + batch_size]
tasks = [dask.delayed(process_and_ingest)(file_path + file, uri, checksum_dict, dict_type, cfg) for file in batch_files]
# Submit tasks and wait for completion
dask.compute(*tasks)
logging.info(f"Batch {i // batch_size + 1} completed.", flush=True)
# Process files in batches
else:
create_tiledb_schema(uri, cfg)
file_list = os.listdir(file_path)

for i in range(0, len(file_list), batch_size):
batch_files = file_list[i : i + batch_size]
tasks = [
dask.delayed(process_and_ingest)(file_path + file, uri, checksum_dict, dict_type, cfg)
for file in batch_files
]
# Submit tasks and wait for completion
dask.compute(*tasks)
logging.info(f"Batch {i // batch_size + 1} completed.", flush=True)
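
The ingestion loop above fans each batch of files out through dask.delayed and dask.compute. Note that the hunk still refers to names that are not defined or imported in it (file_path, dict_type, np, dask, logging). As a reference, here is a minimal, hedged sketch of the same batched pattern, assuming process_and_ingest is importable from gwasstudio.utils (as in the new import block) and using placeholders for the names the hunk leaves undefined:

import os

import dask

from gwasstudio.utils import process_and_ingest

# Placeholders -- the command's own parameters are input_path, uri, batch_size and cfg.
file_path = "data/gwas/"          # directory holding the summary-statistics files
uri = "path/to/tiledb_array"      # target TileDB array
checksum_dict = {}                # filename -> hash mapping parsed from checksum.txt
dict_type = {}                    # placeholder; not defined in the hunk above
cfg = {}                          # TileDB configuration taken from the Click context
batch_size = 5

file_list = os.listdir(file_path)
for i in range(0, len(file_list), batch_size):
    batch_files = file_list[i : i + batch_size]
    # One delayed task per file; dask.compute runs the whole batch in parallel.
    tasks = [
        dask.delayed(process_and_ingest)(file_path + name, uri, checksum_dict, dict_type, cfg)
        for name in batch_files
    ]
    dask.compute(*tasks)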
