166 changes: 111 additions & 55 deletions gnomad_qc/v5/annotations/generate_frequency.py
@@ -53,7 +53,7 @@

from gnomad_qc.resource_utils import check_resource_existence
from gnomad_qc.v3.utils import hom_alt_depletion_fix
from gnomad_qc.v4.resources.annotations import get_freq as get_v4_freq
from gnomad_qc.v4.resources.release import release_sites
from gnomad_qc.v5.annotations.annotation_utils import annotate_adj_no_dp
from gnomad_qc.v5.resources.annotations import (
coverage_and_an_path,
@@ -97,7 +97,7 @@ def mt_hist_fields(mt: hl.MatrixTable) -> hl.StructExpression:


def _prepare_aou_vds(
aou_vds: hl.vds.VariantDataset, test: bool = False
aou_vds: hl.vds.VariantDataset, test: bool = False, environment: str = "rwb"
) -> hl.vds.VariantDataset:
"""
Prepare AoU VDS for frequency calculations.
@@ -106,52 +106,67 @@
:param test: Whether running in test mode.
:return: Prepared AoU VariantDataset.
"""
logger.info(f"Using test mode: {test}")
logger.info("Splitting multiallelics in AoU VDS...")
aou_vds = hl.vds.split_multi(aou_vds, filter_changed_loci=True)

aou_vmt = aou_vds.variant_data
# Use existing AoU group membership table and filter to variant samples.
logger.info(
"Loading AoU group membership table for variant frequency stratification..."
)
group_membership_ht = group_membership(test=test, data_set="aou").ht()
group_membership_globals = group_membership_ht.index_globals()
# Ploidy is already adjusted in the AoU VDS because of DRAGEN, do not need
# to adjust it here.
aou_vmt = aou_vmt.annotate_cols(
group_membership_ht = group_membership(
test=test, data_set="aou", environment=environment
).ht()

logger.info("Selecting cols for frequency stratification...")
aou_vmt = aou_vmt.select_cols(
sex_karyotype=aou_vmt.meta.sex_karyotype,
gen_anc=aou_vmt.meta.genetic_ancestry_inference.gen_anc,
age=aou_vmt.meta.project_meta.age,
group_membership=group_membership_ht[aou_vmt.col_key].group_membership,
)
aou_vmt = aou_vmt.annotate_globals(
logger.info("Adjusting sex ploidy...")
aou_vmt = aou_vmt.select_entries(
GT=adjusted_sex_ploidy_expr(aou_vmt.locus, aou_vmt.GT, aou_vmt.sex_karyotype),
GQ=aou_vmt.GQ,
AD=aou_vmt.AD,
)

group_membership_globals = group_membership_ht.index_globals()
aou_vmt = aou_vmt.select_globals(
freq_meta=group_membership_globals.freq_meta,
freq_meta_sample_count=group_membership_globals.freq_meta_sample_count,
age_distribution=aou_vmt.aggregate_cols(hl.agg.hist(aou_vmt.age, 30, 80, 10)),
downsamplings=group_membership_globals.downsamplings,
)

# Add adj annotation required by annotate_freq.
logger.info("Annotating adj...")
Contributor:
Maybe we should add a note here that this follows the same order as previous versions (annotating adj after splitting multi), but we'll need to move this annotation if we don't want to densify for freq calculations. Or maybe we should add the use-all-sites-ans arg to this function to toggle behavior as needed?

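A minimal sketch of what the suggested toggle might look like, assuming a `use_all_sites_ans` argument is added to `_prepare_aou_vds`; the argument, the sketch function name, and the placement are assumptions, not part of this PR:

```python
import hail as hl

from gnomad_qc.v5.annotations.annotation_utils import annotate_adj_no_dp


def _prepare_aou_vds_sketch(
    aou_vds: hl.vds.VariantDataset,
    use_all_sites_ans: bool = False,  # Hypothetical toggle suggested above.
) -> hl.vds.VariantDataset:
    """Sketch: keep the v3/v4 adj ordering only when densifying."""
    # Same order as previous versions: split multiallelics first.
    aou_vds = hl.vds.split_multi(aou_vds, filter_changed_loci=True)
    aou_vmt = aou_vds.variant_data

    if not use_all_sites_ans:
        # Densify path: annotate adj after splitting, as in v3/v4 genomes.
        aou_vmt = annotate_adj_no_dp(aou_vmt)
    # All-sites-AN path: skip adj here so it can be annotated later without
    # requiring a densify for the frequency calculations.

    return hl.vds.VariantDataset(aou_vds.reference_data, aou_vmt)
```

Alternatively, a short note in the docstring recording the ordering (adj after split_multi) would cover the first option in the comment.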
aou_vmt = annotate_adj_no_dp(aou_vmt)
aou_vds = hl.vds.split_multi(hl.vds.VariantDataset(aou_vds.reference_data, aou_vmt))
aou_vds = hl.vds.VariantDataset(aou_vds.reference_data, aou_vmt)

return aou_vds


def _calculate_aou_frequencies_and_hists_using_all_sites_ans(
aou_variant_mt: hl.MatrixTable, test: bool = False
aou_variant_mt: hl.MatrixTable, test: bool = False, environment: str = "rwb"
) -> hl.Table:
"""
Calculate frequencies and age histograms for AoU variant data using all sites ANs.

:param aou_variant_mt: Prepared variant MatrixTable.
:param test: Whether to use test resources.
:param environment: Environment to use. Default is "rwb". Must be one of "rwb", "batch", or "dataproc".
:return: Table with freq and age_hists annotations.
"""
logger.info("Annotating quality metrics histograms and age histograms...")
all_sites_an_ht = coverage_and_an_path(test=test).ht()
all_sites_an_ht = coverage_and_an_path(test=test, environment=environment).ht()
aou_variant_mt = aou_variant_mt.annotate_rows(
hist_fields=mt_hist_fields(aou_variant_mt)
)

logger.info("Annotating frequencies with all sites ANs...")
group_membership_ht = group_membership(test=test, data_set="aou").ht()
group_membership_ht = group_membership(
test=test, data_set="aou", environment=environment
).ht()
aou_variant_freq_ht = agg_by_strata(
aou_variant_mt.select_entries(
"GT",
@@ -237,7 +252,7 @@ def _calculate_aou_frequencies_and_hists_using_densify(


def process_aou_dataset(
test: bool = False, use_all_sites_ans: bool = False
test: bool = False, use_all_sites_ans: bool = False, environment: str = "rwb"
) -> hl.Table:
"""
Process All of Us dataset for frequency calculations and age histograms.
@@ -248,36 +263,38 @@

:param test: Whether to run in test mode.
:param use_all_sites_ans: Whether to use all sites ANs for frequency calculations.
:param environment: Environment to use. Default is "rwb". Must be one of "rwb", "batch", or "dataproc".
:return: Table with freq and age_hists annotations for AoU dataset.
"""
aou_vds = get_aou_vds(annotate_meta=True, release_only=True, test=test)
aou_vds = _prepare_aou_vds(aou_vds, test=test)
aou_vds = _prepare_aou_vds(aou_vds, test=test, environment=environment)

# Calculate frequencies and age histograms together
logger.info("Calculating AoU frequencies and age histograms...")
if use_all_sites_ans:
logger.info("Using all sites ANs for frequency calculations...")
aou_freq_ht = _calculate_aou_frequencies_and_hists_using_all_sites_ans(
aou_vds.variant_data, test=test
aou_vds.variant_data, test=test, environment=environment
)
else:
logger.info("Using densify for frequency calculations...")
aou_freq_ht = _calculate_aou_frequencies_and_hists_using_densify(
aou_vds, test=test
)
aou_freq_ht = select_final_dataset_fields(aou_freq_ht, dataset="aou")

return aou_freq_ht


def _prepare_consent_vds(
v4_freq_ht: hl.Table,
v4_ht: hl.Table,
test: bool = False,
test_partitions: int = 2,
) -> hl.vds.VariantDataset:
"""
Load and prepare VDS for consent withdrawal sample processing.

:param v4_freq_ht: v4 frequency table for AF annotation.
:param v4_ht: v4 release table for AF annotation.
:param test: Whether running in test mode.
:param test_partitions: Number of partitions to use in test mode. Default is 2.
:return: Prepared VDS with consent samples, split multiallelics, and annotations.
@@ -316,7 +333,7 @@

# Annotate with v4 frequencies for hom alt depletion fix
vmt = vds.variant_data
vmt = vmt.annotate_rows(v4_af=v4_freq_ht[vmt.row_key].freq[0].AF)
vmt = vmt.annotate_rows(v4_af=v4_ht[vmt.row_key].freq[0].AF)

# This follows the v3/v4 genomes workflow for adj and sex adjusted genotypes which
# were added before the hom alt depletion fix.
@@ -371,7 +388,9 @@ def _calculate_consent_frequencies_and_age_histograms(
logger.info("Densifying VDS for frequency calculations...")
mt = hl.vds.to_dense_mt(vds)
# Group membership table is already filtered to consent withdrawal samples.
group_membership_ht = group_membership(test=test, data_set="gnomad").ht()
group_membership_ht = group_membership(
test=test, data_set="gnomad", environment="dataproc"
).ht()

mt = mt.annotate_cols(
group_membership=group_membership_ht[mt.col_key].group_membership,
@@ -410,23 +429,23 @@


def _subtract_consent_frequencies_and_age_histograms(
v4_freq_ht: hl.Table,
v4_ht: hl.Table,
consent_freq_ht: hl.Table,
) -> hl.Table:
"""
Subtract consent withdrawal frequencies and age histograms from v4 frequency table.

:param v4_freq_ht: v4 frequency table (contains both freq and histograms.age_hists).
:param v4_ht: v4 release table (contains both freq and histograms.age_hists).
:param consent_freq_ht: Consent withdrawal table with freq and age_hists annotations.
:return: Updated frequency table with consent frequencies and age histograms subtracted.
"""
logger.info(
"Subtracting consent withdrawal frequencies and age histograms from v4 frequency table..."
"Subtracting consent withdrawal frequencies and age histograms from v4 release table..."
)

joined_freq_ht = v4_freq_ht.annotate(
consent_freq=consent_freq_ht[v4_freq_ht.key].freq,
consent_age_hists=consent_freq_ht[v4_freq_ht.key].age_hists,
joined_freq_ht = v4_ht.annotate(
consent_freq=consent_freq_ht[v4_ht.key].freq,
consent_age_hists=consent_freq_ht[v4_ht.key].age_hists,
)

joined_freq_ht = joined_freq_ht.annotate_globals(
@@ -485,6 +504,36 @@ def _subtract_consent_frequencies_and_age_histograms(
return joined_freq_ht.checkpoint(new_temp_file("merged_freq_and_hists", "ht"))

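The body of the subtraction is collapsed in this diff. As a rough illustration only, a per-stratum subtraction of call stats might look like the sketch below; the handling beyond AC/AN/homozygote_count/AF is an assumption, not the PR's implementation:

```python
import hail as hl


def _subtract_call_stats(
    v4_freq: hl.expr.StructExpression, consent_freq: hl.expr.StructExpression
) -> hl.expr.StructExpression:
    """Subtract consent-withdrawal call stats from a v4 freq struct and recompute AF."""
    # Treat a missing consent struct (variant absent in withdrawn samples) as zero.
    ac = v4_freq.AC - hl.coalesce(consent_freq.AC, 0)
    an = v4_freq.AN - hl.coalesce(consent_freq.AN, 0)
    return hl.struct(
        AC=ac,
        AN=an,
        AF=hl.if_else(an > 0, ac / an, hl.missing(hl.tfloat64)),
        homozygote_count=v4_freq.homozygote_count
        - hl.coalesce(consent_freq.homozygote_count, 0),
    )
```

In the actual table this would be applied element-wise across the freq arrays, with freq_meta entries aligned between the two tables.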

def select_final_dataset_fields(ht: hl.Table, dataset: str = "gnomad") -> hl.Table:
"""
Create final freq Table with only desired annotations.

:param ht: Hail Table containing all annotations.
:param dataset: Dataset to select final fields for. Must be one of "gnomad" or "aou". Default is "gnomad".
:return: Hail Table with final annotations.
"""
if dataset not in ["gnomad", "aou"]:
raise ValueError(f"Invalid dataset: {dataset}")

final_globals = ["freq_meta", "freq_meta_sample_count", "age_distribution"]
final_fields = ["freq", "histograms"]

if dataset == "aou":
# AoU has one extra 'downsamplings' global field that is not present in gnomAD.
final_globals.append("downsamplings")

# Convert all int64 annotations in the freq struct to int32s for merging type
# compatibility.
ht = ht.annotate(
Contributor:
this conversion also occurs in _merge_updated_frequency_fields; maybe it only needs to happen in this function, since this function gets run for both gnomad and aou?

Contributor Author:
Correct! Thank you!

freq=ht.freq.map(
lambda x: x.annotate(
**{k: hl.int32(v) for k, v in x.items() if v.dtype == hl.tint64}
)
)
)

return ht.select(*final_fields).select_globals(*final_globals)

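A sketch of keeping the int64-to-int32 coercion in one place, per the thread above. `_freq_int64_to_int32` is a hypothetical helper name, and calling it only from `select_final_dataset_fields` (so the copy in `_merge_updated_frequency_fields` can be dropped) is the reviewer's suggestion, not necessarily what this diff ends up doing:

```python
import hail as hl


def _freq_int64_to_int32(freq_expr: hl.expr.ArrayExpression) -> hl.expr.ArrayExpression:
    """Coerce any int64 fields in an array of freq structs to int32."""
    return freq_expr.map(
        lambda x: x.annotate(
            **{k: hl.int32(v) for k, v in x.items() if v.dtype == hl.tint64}
        )
    )


# Inside select_final_dataset_fields, which runs for both gnomad and aou:
# ht = ht.annotate(freq=_freq_int64_to_int32(ht.freq))
```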

def process_gnomad_dataset(
test: bool = False,
test_partitions: int = 2,
Expand All @@ -504,10 +553,10 @@ def process_gnomad_dataset(
:param test_partitions: Number of partitions to use in test mode. Default is 2.
:return: Updated frequency HT with updated frequencies and age histograms for gnomAD dataset.
"""
v4_freq_ht = get_v4_freq(data_type="genomes").ht()
v4_ht = release_sites(data_type="genomes").ht()

vds = _prepare_consent_vds(
v4_freq_ht,
v4_ht,
test=test,
test_partitions=test_partitions,
)
@@ -516,22 +565,25 @@
consent_freq_ht = _calculate_consent_frequencies_and_age_histograms(vds, test)

if test:
v4_freq_ht = v4_freq_ht.filter(hl.is_defined(consent_freq_ht[v4_freq_ht.key]))
v4_ht = v4_ht.filter(hl.is_defined(consent_freq_ht[v4_ht.key]))
v4_ht = v4_ht.naive_coalesce(100).checkpoint(
new_temp_file("v4_ht_filtered_test", "ht")
)

logger.info("Subtracting consent frequencies and age histograms from v4...")
updated_freq_ht = _subtract_consent_frequencies_and_age_histograms(
v4_freq_ht, consent_freq_ht
v4_ht, consent_freq_ht
)

# Only overwrite fields that were actually updated (merge back with original table)
# Note: FAF/grpmax annotations will be calculated on the final merged dataset.
final_freq_ht = _merge_updated_frequency_fields(v4_freq_ht, updated_freq_ht)
# Select only the fields that were updated as FAF/grpmax/inbreeding_coeff annotations
# will be calculated on the final merged dataset.
freq_ht = _merge_updated_frequency_fields(v4_ht, updated_freq_ht)
final_freq_ht = select_final_dataset_fields(freq_ht, dataset="gnomad")

return final_freq_ht


def _merge_updated_frequency_fields(
original_freq_ht: hl.Table, updated_freq_ht: hl.Table
v4_release_ht: hl.Table, updated_freq_ht: hl.Table
) -> hl.Table:
"""
Merge frequency tables, only overwriting fields that were actually updated.
@@ -542,36 +594,28 @@
Note: FAF/grpmax/inbreeding_coeff annotations are not calculated during consent
withdrawal processing and will be calculated later on the final merged dataset.

:param original_freq_ht: Original v4 frequency table.
:param v4_release_ht: Original v4 release table.
:param updated_freq_ht: Updated frequency table with consent withdrawals subtracted.
:return: Final frequency table with selective field updates.
"""
logger.info("Merging frequency tables with selective field updates...")

# Bring in updated values with a single lookup.
updated_row = updated_freq_ht[original_freq_ht.key]
updated_row = updated_freq_ht[v4_release_ht.key]

# Update freq and age_hists in a single annotate to avoid source mismatch:
# - freq: use updated if present, otherwise keep original
# - histograms.age_hists: update only age_hists, preserving qual_hists and raw_qual_hists
final_freq_ht = original_freq_ht.annotate(
freq=hl.coalesce(updated_row.freq, original_freq_ht.freq),
histograms=original_freq_ht.histograms.annotate(
final_freq_ht = v4_release_ht.annotate(
freq=hl.coalesce(updated_row.freq, v4_release_ht.freq),
histograms=v4_release_ht.histograms.annotate(
age_hists=hl.coalesce(
updated_row.histograms.age_hists,
original_freq_ht.histograms.age_hists,
v4_release_ht.histograms.age_hists,
)
),
)
# Convert all int64 annotations in the freq struct to int32s for merging type
# compatibility.
final_freq_ht = final_freq_ht.annotate(
freq=final_freq_ht.freq.map(
lambda x: x.annotate(
**{k: hl.int32(v) for k, v in x.items() if v.dtype == hl.tint64}
)
)
)

# Update globals from updated table.
updated_globals = {}
Contributor:
Sorry, another small thing I just remembered for the v4 HT: the age global annotation is incorrect (https://the-tgg.slack.com/archives/C06H5KM9W64/p1761665354397189), so we will want to fix that here.

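A hedged sketch of one way the incorrect v4 `age_distribution` global could be replaced while updating globals here; the `meta_ht` source and its `age` field are assumptions about where corrected ages would come from, and the binning mirrors the AoU path above:

```python
import hail as hl


def _recomputed_age_distribution(meta_ht: hl.Table) -> hl.expr.StructExpression:
    """Aggregate an age histogram (30 to 80, 10 bins) over release samples."""
    # _localize=False returns an expression that can be used directly in
    # annotate_globals on another table.
    return meta_ht.aggregate(hl.agg.hist(meta_ht.age, 30, 80, 10), _localize=False)


# e.g., when updating globals on the merged table:
# final_freq_ht = final_freq_ht.annotate_globals(
#     age_distribution=_recomputed_age_distribution(meta_ht)
# )
```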
for global_field in ["freq_meta", "freq_meta_sample_count"]:
@@ -647,7 +691,12 @@ def main(args):
if args.process_gnomad:
logger.info("Processing gnomAD dataset...")

gnomad_freq = get_freq(test=test, data_type="genomes", data_set="gnomad")
gnomad_freq = get_freq(
test=test,
data_type="genomes",
data_set="gnomad",
environment=environment,
)

check_resource_existence(
output_step_resources={"process-gnomad": [gnomad_freq]},
@@ -667,15 +716,22 @@

if args.process_aou:
logger.info("Processing All of Us dataset...")
aou_freq = get_freq(test=test, data_type="genomes", data_set="aou")
aou_freq = get_freq(
test=test,
data_type="genomes",
data_set="aou",
environment=environment,
)

check_resource_existence(
output_step_resources={"process-aou": [aou_freq]},
overwrite=overwrite,
)

aou_freq_ht = process_aou_dataset(
test=test, use_all_sites_ans=use_all_sites_ans
test=test,
use_all_sites_ans=use_all_sites_ans,
environment=environment,
)

logger.info(f"Writing AoU frequency HT to {aou_freq.path}...")