Commit d5ad3c4

fix multiple errors, some code was only commented

gmauro committed Nov 19, 2024
1 parent e140ea8 commit d5ad3c4
Showing 9 changed files with 130 additions and 108 deletions.
18 changes: 9 additions & 9 deletions .pre-commit-config.yaml
@@ -5,12 +5,12 @@ repos:
- id: check-yaml
- id: end-of-file-fixer
- id: trailing-whitespace
# - repo: https://github.com/astral-sh/ruff-pre-commit
# # Ruff version.
# rev: v0.6.3
# hooks:
# # Run the linter.
# - id: ruff
# args: [ --fix ]
# # Run the formatter.
# - id: ruff-format
- repo: https://github.com/astral-sh/ruff-pre-commit
# Ruff version.
rev: v0.6.3
hooks:
# Run the linter.
- id: ruff
args: [ --fix ]
# Run the formatter.
- id: ruff-format
4 changes: 2 additions & 2 deletions gwasstudio/cli/export.py
@@ -36,7 +36,7 @@
)
@click.pass_context
def export(ctx, uri, trait_id_file, output_path, pvalue_sig, pvalue_limit, hole_size, snp_list, locusbreaker):
cfg = ctx.obj["cfg"]
# cfg = ctx.obj["cfg"]
tiledb_unified = tiledb.open(uri, mode="r")
logger.info("TileDB dataset loaded")
trait_id_list = open(trait_id_file, "r").read().rstrip().split("\n")
@@ -49,7 +49,7 @@ def export(ctx, uri, trait_id_file, output_path, pvalue_sig, pvalue_limit, hole_
).df[:, :, trait_id_list]
results_lb = locus_breaker(subset_SNPs_pd)
logger.info(f"Saving locus-breaker output in {output_path}")
results_lb.to_csv(f, "{output_path}_{trait}.csv", index=False)
results_lb.to_csv(f"{output_path}_{trait}.csv", index=False)
return

# If snp_list is selected, run extract_snp
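The second hunk in this file repairs a misplaced f-string prefix: `to_csv(f, "{output_path}_{trait}.csv", ...)` passes `f` as a separate (undefined) argument and never interpolates the braces, while `to_csv(f"{output_path}_{trait}.csv", ...)` builds the path in one string. A minimal sketch, with hypothetical values standing in for export()'s variables:

```python
import pandas as pd

# Hypothetical values standing in for export()'s variables.
output_path = "results"
trait = "height"
results_lb = pd.DataFrame({"CHR": ["1"], "POS": [12345], "MLOG10P": [8.2]})

# Buggy form: `f` is a separate (undefined) argument, and the braces in
# the plain string are never interpolated.
# results_lb.to_csv(f, "{output_path}_{trait}.csv", index=False)

# Fixed form: a single f-string builds the destination path.
results_lb.to_csv(f"{output_path}_{trait}.csv", index=False)
```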
28 changes: 16 additions & 12 deletions gwasstudio/cli/ingest.py
@@ -2,11 +2,11 @@

import click
import cloup
import numpy as np
import pandas as pd
import tiledb

from gwasstudio.utils import create_tiledb_schema
from gwasstudio.utils import process_and_ingest

help_doc = """
Ingest data in a TileDB-unified dataset.
@@ -41,7 +41,7 @@
@click.pass_context
def ingest(ctx, input_path, checksum_path, attrs, uri, mem_budget_mb, threads, batch_size, restart):
# Parse checksum for mapping ids to files
checksum = pd.read_csv(file_path + "checksum.txt", sep="\t", header=None)
checksum = pd.read_csv(input_path + "checksum.txt", sep="\t", header=None)
checksum.columns = ["hash", "filename"]
checksum_dict = pd.Series(checksum.hash.values, index=checksum.filename).to_dict()

@@ -61,14 +61,18 @@ def ingest(ctx, input_path, checksum_path, attrs, uri, mem_budget_mb, threads, b
# Process files in batches
else:
create_tiledb_schema(uri, cfg)
file_list = os.listdir(file_path)
file_list = os.listdir(input_path)

for i in range(0, len(file_list), batch_size):
batch_files = file_list[i : i + batch_size]
tasks = [
dask.delayed(process_and_ingest)(file_path + file, uri, checksum_dict, dict_type, cfg)
for file in batch_files
]
# Submit tasks and wait for completion
dask.compute(*tasks)
logging.info(f"Batch {i // batch_size + 1} completed.", flush=True)

# what is dict_type?
# for i in range(0, len(file_list), batch_size):
# batch_files = file_list[i : i + batch_size]
#
# tasks = [
# dask.delayed(process_and_ingest)(input_path + file, uri, checksum_dict, dict_type, cfg)
# for file in batch_files
# ]
# # Submit tasks and wait for completion
# dask.compute(*tasks)
# logger.info(f"Batch {i // batch_size + 1} completed.", flush=True)
#
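The replaced loop (now kept only as a comment, pending the `dict_type` question noted above) batched ingestion through Dask delayed tasks. A minimal sketch of that pattern, using a stand-in `process_one` in place of `process_and_ingest` and hypothetical file names:

```python
import dask

def process_one(path: str) -> str:
    # Stand-in for process_and_ingest: parse one file, append it to TileDB.
    return path

file_list = [f"sumstats_{i}.tsv.gz" for i in range(10)]  # hypothetical inputs
batch_size = 4

for i in range(0, len(file_list), batch_size):
    batch_files = file_list[i : i + batch_size]
    tasks = [dask.delayed(process_one)(path) for path in batch_files]
    dask.compute(*tasks)  # blocks until the whole batch has finished
    print(f"Batch {i // batch_size + 1} completed.")
```

Batching keeps the number of in-flight tasks, and the memory held by their pending results, bounded instead of submitting every file at once.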
41 changes: 29 additions & 12 deletions gwasstudio/functions/create_tiledb_schema.py
@@ -1,23 +1,40 @@
import tiledb
import numpy as np
def create_tiledb_schema(uri,cfg):


def create_tiledb_schema(uri, cfg):
pos_domain = (1, 3000000000)
dom = tiledb.Domain(
tiledb.Dim(name="CHR", dtype=np.bytes_, var=False),
tiledb.Dim(name="POS", domain = pos_domain, dtype=np.uint32, var=False),
tiledb.Dim(name="TRAITID", dtype=np.bytes_, var=True))
tiledb.Dim(name="POS", domain=pos_domain, dtype=np.uint32, var=False),
tiledb.Dim(name="TRAITID", dtype=np.bytes_, var=True),
)
schema = tiledb.ArraySchema(
domain=dom,
sparse=True,
allows_duplicates=True,
attrs=[
tiledb.Attr(name="BETA", dtype=np.float64, var=False,filters=tiledb.FilterList([tiledb.ZstdFilter(level=5)]) ),
tiledb.Attr(name="SE", dtype=np.float64, var=False,filters=tiledb.FilterList([tiledb.ZstdFilter(level=5)])),
tiledb.Attr(name="EAF", dtype=np.float64, var=False,filters=tiledb.FilterList([tiledb.ZstdFilter(level=5)])),
tiledb.Attr(name="MLOG10P", dtype=np.float64, var=False,filters=tiledb.FilterList([tiledb.ZstdFilter(level=5)])),
tiledb.Attr(name="ALLELE0", dtype=np.bytes_, var=True,filters=tiledb.FilterList([tiledb.ZstdFilter(level=5)])),
tiledb.Attr(name="ALLELE1", dtype=np.bytes_, var=True,filters=tiledb.FilterList([tiledb.ZstdFilter(level=5)])),
tiledb.Attr(name="SNPID", dtype=np.bytes_, var=True,filters=tiledb.FilterList([tiledb.ZstdFilter(level=5)]))
]
tiledb.Attr(
name="BETA", dtype=np.float64, var=False, filters=tiledb.FilterList([tiledb.ZstdFilter(level=5)])
),
tiledb.Attr(
name="SE", dtype=np.float64, var=False, filters=tiledb.FilterList([tiledb.ZstdFilter(level=5)])
),
tiledb.Attr(
name="EAF", dtype=np.float64, var=False, filters=tiledb.FilterList([tiledb.ZstdFilter(level=5)])
),
tiledb.Attr(
name="MLOG10P", dtype=np.float64, var=False, filters=tiledb.FilterList([tiledb.ZstdFilter(level=5)])
),
tiledb.Attr(
name="ALLELE0", dtype=np.bytes_, var=True, filters=tiledb.FilterList([tiledb.ZstdFilter(level=5)])
),
tiledb.Attr(
name="ALLELE1", dtype=np.bytes_, var=True, filters=tiledb.FilterList([tiledb.ZstdFilter(level=5)])
),
tiledb.Attr(
name="SNPID", dtype=np.bytes_, var=True, filters=tiledb.FilterList([tiledb.ZstdFilter(level=5)])
),
],
)
tiledb.Array.create(uri, schema, ctx=cfg)
tiledb.Array.create(uri, schema, ctx=cfg)
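For context, `create_tiledb_schema` defines a sparse three-dimensional array keyed by chromosome, position, and trait id, with Zstd-compressed attributes. A reduced sketch of the construction plus a round-trip write and read, assuming a local URI, a single attribute, and tiledb-py's standard sparse write API:

```python
import numpy as np
import tiledb

uri = "gwas_demo_array"  # hypothetical local array URI

# Sparse domain: chromosome bytes, 1-based position, variable-length trait id.
dom = tiledb.Domain(
    tiledb.Dim(name="CHR", dtype=np.bytes_),
    tiledb.Dim(name="POS", domain=(1, 3_000_000_000), dtype=np.uint32),
    tiledb.Dim(name="TRAITID", dtype=np.bytes_),
)
schema = tiledb.ArraySchema(
    domain=dom,
    sparse=True,
    allows_duplicates=True,
    attrs=[tiledb.Attr(name="BETA", dtype=np.float64)],
)
tiledb.Array.create(uri, schema)

# In a sparse array, coordinates are written alongside attribute values.
with tiledb.open(uri, mode="w") as arr:
    arr[[b"1"], [12345], [b"trait-a"]] = {"BETA": np.array([0.12])}

with tiledb.open(uri, mode="r") as arr:
    print(arr.df[:])  # one row: CHR=1, POS=12345, TRAITID=trait-a, BETA=0.12
```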
30 changes: 15 additions & 15 deletions gwasstudio/functions/process_and_ingest.py
@@ -1,29 +1,29 @@
from utils import compute_sha256
import pandas as pd
import tiledb

from gwasstudio.utils import compute_sha256


def process_and_ingest(file_path, uri, checksum_dict, dict_type, renaming_columns, attributes_columns, cfg):
# Read file with Dask
df = pd.read_csv(
file_path,
compression="gzip",
sep="\t",
usecols = attributes_columns
#usecols=["Chrom", "Pos", "Name", "effectAllele", "Beta", "SE", "ImpMAF"]
usecols=attributes_columns,
# usecols=["Chrom", "Pos", "Name", "effectAllele", "Beta", "SE", "ImpMAF"]
)
sha256 = compute_sha256(file_path)
# Rename columns and modify 'chrom' field
df = df.rename(columns = renaming_columns)
df["chrom"] = df["chrom"].str.replace('chr', '')
df["chrom"] = df["chrom"].str.replace('X', '23')
df["chrom"] = df["chrom"].str.replace('Y', '24')
df = df.rename(columns=renaming_columns)
df["chrom"] = df["chrom"].str.replace("chr", "")
df["chrom"] = df["chrom"].str.replace("X", "23")
df["chrom"] = df["chrom"].str.replace("Y", "24")
# Add trait_id based on the checksum_dict
file_name = file_path.split('/')[-1]
# file_name = file_path.split("/")[-1]
df["trait_id"] = sha256

# Store the processed data in TileDB
tiledb.from_pandas(
uri=uri,
dataframe=df,
index_dims=["chrom", "pos", "trait_id"],
mode="append",
column_types=dict_type,
ctx = ctx
)
uri=uri, dataframe=df, index_dims=["chrom", "pos", "trait_id"], mode="append", column_types=dict_type, ctx=cfg
)
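The chromosome handling above strips the `chr` prefix and recodes X/Y as 23/24 so values fit a numeric dimension. An equivalent pandas sketch on hypothetical input, using a dict mapping for the sex chromosomes rather than chained `str.replace`, which would also rewrite an X or Y embedded inside longer labels:

```python
import pandas as pd

# Hypothetical input mirroring the ingested summary-statistics format.
df = pd.DataFrame({"chrom": ["chr1", "chr22", "chrX", "chrY"]})

# Strip the prefix, then map the sex chromosomes onto numeric codes that
# fit a small unsigned-integer dimension such as uint8 over (1, 24).
df["chrom"] = (
    df["chrom"]
    .str.replace("chr", "", regex=False)
    .replace({"X": "23", "Y": "24"})
    .astype("uint8")
)
print(df["chrom"].tolist())  # [1, 22, 23, 24]
```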
5 changes: 3 additions & 2 deletions gwasstudio/main.py
@@ -8,7 +8,8 @@
from gwasstudio.cli.metadata.ingest import meta_ingest
from gwasstudio.cli.metadata.query import meta_query
from gwasstudio.cli.metadata.view import meta_view
from gwasstudio.cli.query import query

# from gwasstudio.cli.query import query
from gwasstudio.dask_client import DaskClient as Client


@@ -87,7 +88,7 @@ def cli_init(

def main():
cli_init.add_command(info)
cli_init.add_command(query)
# cli_init.add_command(query)
cli_init.add_command(export)
cli_init.add_command(ingest)
cli_init.add_command(meta_ingest)
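Disabling the `query` subcommand happens purely at registration time: the import and the `add_command` call are commented out, so the rest of the CLI is untouched. A minimal click sketch of this registration pattern, with a hypothetical command body:

```python
import click

@click.group()
def cli_init():
    """Root command group, as in gwasstudio.main."""

@click.command()
def info():
    click.echo("gwasstudio info")  # hypothetical command body

def main():
    cli_init.add_command(info)
    # cli_init.add_command(query)  # disabled commands are simply not attached
    cli_init()

if __name__ == "__main__":
    main()
```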
28 changes: 12 additions & 16 deletions gwasstudio/methods/locus_breaker.py
@@ -1,12 +1,8 @@
import pandas as pd
from typing import List


def locus_breaker(
tiledb_results_pd,
pvalue_limit: float = 5,
pvalue_sig: float = 5,
hole_size: int = 250000
tiledb_results_pd, pvalue_limit: float = 5, pvalue_sig: float = 5, hole_size: int = 250000
) -> pd.DataFrame:
"""
Breaking genome in locus
@@ -18,21 +14,21 @@ def locus_breaker(
:return: DataFrame with the loci information
"""
expected_schema = {
'CHR': pd.Series(dtype='object'),
'POS': pd.Series(dtype='int64'),
'ALLELE0': pd.Series(dtype='object'),
'ALLELE1': pd.Series(dtype='object'),
'MLOG10P': pd.Series(dtype='float64'),
'BETA': pd.Series(dtype='object'),
'SE': pd.Series(dtype='object'),
'TRAITID': pd.Series(dtype='object')
}
"CHR": pd.Series(dtype="object"),
"POS": pd.Series(dtype="int64"),
"ALLELE0": pd.Series(dtype="object"),
"ALLELE1": pd.Series(dtype="object"),
"MLOG10P": pd.Series(dtype="float64"),
"BETA": pd.Series(dtype="object"),
"SE": pd.Series(dtype="object"),
"TRAITID": pd.Series(dtype="object"),
}

# Convert fmt_LP from list to float
if tiledb_results_pd.empty:
print("this region is empty")
return pd.DataFrame(expected_schema)

# Filter rows based on the p_limit threshold
tiledb_results_pd = tiledb_results_pd[tiledb_results_pd["MLOG10P"] > pvalue_limit]

@@ -74,4 +70,4 @@ def locus_breaker(
# Remove one of the duplicate 'contig' columns if present
trait_res_df = trait_res_df.loc[:, ~trait_res_df.columns.duplicated()]

return trait_res_df
return trait_res_df
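A small point worth keeping from this refactor: building `pd.DataFrame(expected_schema)` from empty typed Series yields a zero-row frame whose columns and dtypes match the normal output, so callers can concatenate or iterate without special-casing empty regions. A minimal sketch:

```python
import pandas as pd

expected_schema = {
    "CHR": pd.Series(dtype="object"),
    "POS": pd.Series(dtype="int64"),
    "MLOG10P": pd.Series(dtype="float64"),
}

empty = pd.DataFrame(expected_schema)
print(empty.dtypes)  # object / int64 / float64
print(len(empty))    # 0 -- safe to concat with non-empty results
```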
82 changes: 43 additions & 39 deletions gwasstudio/utils/__init__.py
@@ -6,12 +6,14 @@

import hashlib
import pathlib
import string
import random
import tiledb
import string

import numpy as np
import pandas as pd
import polars as pl
import tiledb

DEFAULT_BUFSIZE = 4096


@@ -99,7 +101,7 @@ def generate_random_word(length: int) -> str:


# Define the TileDB array schema with SNP, gene, and population dimensions
def create_tiledb_schema(uri: str,cfg: dict):
def create_tiledb_schema(uri: str, cfg: dict):
"""
Create an empty schema for TileDB.
@@ -111,25 +113,38 @@ def create_tiledb_schema(uri: str, cfg: dict):
chrom_domain = (1, 24)
pos_domain = (1, 3000000000)
dom = tiledb.Domain(
tiledb.Dim(name="chrom", domain = chrom_domain, dtype=np.uint8, var=False),
tiledb.Dim(name="pos", domain = pos_domain, dtype=np.uint32, var=False),
tiledb.Dim(name="trait_id", dtype=np.dtype('S64'), var=True)
tiledb.Dim(name="chrom", domain=chrom_domain, dtype=np.uint8, var=False),
tiledb.Dim(name="pos", domain=pos_domain, dtype=np.uint32, var=False),
tiledb.Dim(name="trait_id", dtype=np.dtype("S64"), var=True),
)
schema = tiledb.ArraySchema(
domain=dom,
sparse=True,
allows_duplicates=True,
attrs=[
tiledb.Attr(name="beta", dtype=np.float32, var=False,filters=tiledb.FilterList([tiledb.ZstdFilter(level=5)]) ),
tiledb.Attr(name="se", dtype=np.float32, var=False,filters=tiledb.FilterList([tiledb.ZstdFilter(level=5)])),
tiledb.Attr(name="freq", dtype=np.float32, var=False,filters=tiledb.FilterList([tiledb.ZstdFilter(level=5)])),
tiledb.Attr(name="alt", dtype=np.dtype('S5'), var=True,filters=tiledb.FilterList([tiledb.ZstdFilter(level=5)])),
tiledb.Attr(name="SNP", dtype=np.dtype('S20'), var=True,filters=tiledb.FilterList([tiledb.ZstdFilter(level=5)]))
]
tiledb.Attr(
name="beta", dtype=np.float32, var=False, filters=tiledb.FilterList([tiledb.ZstdFilter(level=5)])
),
tiledb.Attr(
name="se", dtype=np.float32, var=False, filters=tiledb.FilterList([tiledb.ZstdFilter(level=5)])
),
tiledb.Attr(
name="freq", dtype=np.float32, var=False, filters=tiledb.FilterList([tiledb.ZstdFilter(level=5)])
),
tiledb.Attr(
name="alt", dtype=np.dtype("S5"), var=True, filters=tiledb.FilterList([tiledb.ZstdFilter(level=5)])
),
tiledb.Attr(
name="SNP", dtype=np.dtype("S20"), var=True, filters=tiledb.FilterList([tiledb.ZstdFilter(level=5)])
),
],
)
tiledb.Array.create(uri, schema, ctx=cfg)

def process_and_ingest(file_path: str, uri, dict_type: dict, renaming_columns: dict, attributes_columns: list, cfg: dict):

def process_and_ingest(
file_path: str, uri, dict_type: dict, renaming_columns: dict, attributes_columns: list, cfg: dict
):
"""
Process a single file and ingest it in a TileDB
@@ -145,44 +160,33 @@ def process_and_ingest(file_path: str, uri, dict_type: dict, renaming_columns: d
file_path,
compression="gzip",
sep="\t",
usecols = attributes_columns
#usecols=["Chrom", "Pos", "Name", "effectAllele", "Beta", "SE", "ImpMAF"]
usecols=attributes_columns,
# usecols=["Chrom", "Pos", "Name", "effectAllele", "Beta", "SE", "ImpMAF"]
)
sha256 = compute_sha256(file_path)
# Rename columns and modify 'chrom' field
df = df.rename(columns = renaming_columns)
df["chrom"] = df["chrom"].str.replace('chr', '')
df["chrom"] = df["chrom"].str.replace('X', '23')
df["chrom"] = df["chrom"].str.replace('Y', '24')
df = df.rename(columns=renaming_columns)
df["chrom"] = df["chrom"].str.replace("chr", "")
df["chrom"] = df["chrom"].str.replace("X", "23")
df["chrom"] = df["chrom"].str.replace("Y", "24")
# Add trait_id based on the checksum_dict
file_name = file_path.split('/')[-1]
# file_name = file_path.split("/")[-1]
df["trait_id"] = sha256

# Store the processed data in TileDB
tiledb.from_pandas(
uri=uri,
dataframe=df,
index_dims=["chrom", "pos", "trait_id"],
mode="append",
column_types=dict_type,
ctx = ctx
uri=uri, dataframe=df, index_dims=["chrom", "pos", "trait_id"], mode="append", column_types=dict_type, ctx=cfg
)


def process_write_chunk(chunk, SNP_list, file_stream):
SNP_list_polars = pl.DataFrame(SNP_list)
chunk_polars = pl.DataFrame(chunk)
SNP_list_polars = SNP_list_polars.with_columns([
pl.col("POS").cast(pl.UInt32)])
chunk_polars = chunk_polars.with_columns([
pl.col("ALLELE0").cast(pl.Utf8),
pl.col("ALLELE1").cast(pl.Utf8),
pl.col("SNPID").cast(pl.Utf8)
])
# Perform the join operation with Polars
subset_SNPs_merge = chunk_polars.join(
SNP_list_polars,
on=['CHR', 'POS', 'ALLELE0', 'ALLELE1'],
how="inner"
SNP_list_polars = SNP_list_polars.with_columns([pl.col("POS").cast(pl.UInt32)])
chunk_polars = chunk_polars.with_columns(
[pl.col("ALLELE0").cast(pl.Utf8), pl.col("ALLELE1").cast(pl.Utf8), pl.col("SNPID").cast(pl.Utf8)]
)
#Append the merged chunk to CSV
# Perform the join operation with Polars
subset_SNPs_merge = chunk_polars.join(SNP_list_polars, on=["CHR", "POS", "ALLELE0", "ALLELE1"], how="inner")
# Append the merged chunk to CSV
subset_SNPs_merge.write_csv(file_stream)
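`process_write_chunk` casts the key columns before the join because Polars is strict about key dtypes; joining, say, `Int64` against `UInt32` positions can fail outright. A self-contained sketch of the pattern with hypothetical data:

```python
import polars as pl

# Hypothetical chunk of summary statistics and a target SNP list.
chunk = pl.DataFrame(
    {"CHR": ["1", "1"], "POS": [100, 200], "ALLELE0": ["A", "C"], "ALLELE1": ["G", "T"]}
)
snps = pl.DataFrame({"CHR": ["1"], "POS": [200], "ALLELE0": ["C"], "ALLELE1": ["T"]})

# Align the key dtypes on both sides before joining.
chunk = chunk.with_columns(pl.col("POS").cast(pl.UInt32))
snps = snps.with_columns(pl.col("POS").cast(pl.UInt32))

matched = chunk.join(snps, on=["CHR", "POS", "ALLELE0", "ALLELE1"], how="inner")
print(matched)  # the single (1, 200, C, T) row
```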