Fix cohort.get_sequencing_groups in cpg_workflows targets.py #899

Status: Draft. Wants to merge 5 commits into base: main.
2 changes: 1 addition & 1 deletion cpg_workflows/metamist.py
@@ -423,7 +423,7 @@ def get_analyses_by_sgid(
"""
dataset = dataset or self.default_dataset
metamist_proj = dataset or self.default_dataset
if get_config()['workflow']['access_level'] == 'test':
if get_config()['workflow']['access_level'] == 'test' and not metamist_proj.endswith('-test'):
metamist_proj += '-test'

analyses = query(
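A minimal standalone sketch of the guarded suffix logic above (illustrative only, not code from this repository; the function name is made up): without the endswith check, a project that is already a '-test' project would be rewritten to 'my-dataset-test-test'.

    def resolve_metamist_project(dataset: str, access_level: str) -> str:
        """Mirror the guarded '-test' suffixing from get_analyses_by_sgid (sketch only)."""
        metamist_proj = dataset
        if access_level == 'test' and not metamist_proj.endswith('-test'):
            metamist_proj += '-test'
        return metamist_proj

    assert resolve_metamist_project('my-dataset', 'test') == 'my-dataset-test'
    assert resolve_metamist_project('my-dataset-test', 'test') == 'my-dataset-test'  # no double suffix
    assert resolve_metamist_project('my-dataset', 'standard') == 'my-dataset'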
77 changes: 69 additions & 8 deletions cpg_workflows/targets.py
@@ -14,7 +14,7 @@
 from cpg_utils.config import dataset_path, get_config, reference_path, web_url
 
 from .filetypes import AlignmentInput, BamPath, CramPath, FastqPairs, GvcfPath
-from .metamist import Assay
+from .metamist import Assay, get_metamist
 
 
 class Target:
@@ -242,6 +242,7 @@ def __init__(self, name: str | None = None, multicohort: MultiCohort | None = No
         super().__init__()
         self.name = name or get_config()['workflow']['dataset']
         self.analysis_dataset = Dataset(name=get_config()['workflow']['dataset'], cohort=self)
+        self._sequencing_group_by_id: dict[str, SequencingGroup] = {}
         self._datasets_by_name: dict[str, Dataset] = {}
         self.multicohort = multicohort
 
@@ -291,15 +292,73 @@ def get_dataset_by_name(self, name: str, only_active: bool = True) -> Optional['
         ds_by_name = {d.name: d for d in self.get_datasets(only_active)}
         return ds_by_name.get(name)
 
+    def get_sequencing_group_ids(self, only_active: bool = True) -> list[str]:
+        """
+        Gets a flat list of all sequencing group IDs from all datasets in the cohort.
+        """
+        sequencing_group_ids = []
+        cohort_sgs_by_project = get_metamist().get_sgs_by_project_from_cohort(self.name)
+        for ds in cohort_sgs_by_project:
+            sequencing_group_ids.extend(cohort_sgs_by_project[ds])
+        return sequencing_group_ids
+
+    def add_sequencing_group(
+        self,
+        id: str,  # pylint: disable=redefined-builtin
+        *,
+        sequencing_type: str,
+        sequencing_technology: str,
+        sequencing_platform: str,
+        external_id: str | None = None,
+        participant_id: str | None = None,
+        meta: dict | None = None,
+        sex: Optional['Sex'] = None,
+        pedigree: Optional['PedigreeInfo'] = None,
+        alignment_input: AlignmentInput | None = None,
+    ) -> 'SequencingGroup':
+        """
+        Create a new sequencing group and add it to the cohort.
+        """
+        if id in self._sequencing_group_by_id:
+            logging.debug(f'SequencingGroup {id} already exists in the cohort {self.name}')
+            return self._sequencing_group_by_id[id]
+
+        force_sgs = get_config()['workflow'].get('force_sgs', set())
+        forced = id in force_sgs or external_id in force_sgs or participant_id in force_sgs
+
+        s = SequencingGroup(
+            id=id,
+            dataset=self,
+            external_id=external_id,
+            sequencing_type=sequencing_type,
+            sequencing_technology=sequencing_technology,
+            sequencing_platform=sequencing_platform,
+            participant_id=participant_id,
+            meta=meta,
+            sex=sex,
+            pedigree=pedigree,
+            alignment_input=alignment_input,
+            forced=forced,
+        )
+        self._sequencing_group_by_id[id] = s
+        return s
+
+    def add_sequencing_group_object(self, s: 'SequencingGroup'):
+        """
+        Add a sequencing group object to the cohort.
+
+        Args:
+            s: SequencingGroup object
+        """
+        if s.id in self._sequencing_group_by_id:
+            logging.debug(f'SequencingGroup {s.id} already exists in the cohort {self.name}')
+            return self._sequencing_group_by_id[s.id]
+        self._sequencing_group_by_id[s.id] = s
+
     def get_sequencing_groups(self, only_active: bool = True) -> list['SequencingGroup']:
         """
-        Gets a flat list of all sequencing groups from all datasets.
-        Include only "active" sequencing groups (unless only_active is False)
+        Get cohort's sequencing groups. Include only "active" sequencing groups, unless only_active=False.
         """
-        all_sequencing_groups = []
-        for ds in self.get_datasets(only_active=False):
-            all_sequencing_groups.extend(ds.get_sequencing_groups(only_active=only_active))
-        return all_sequencing_groups
+        return [s for sid, s in self._sequencing_group_by_id.items() if (s.active or not only_active)]
 
     def add_dataset(self, dataset: 'Dataset') -> 'Dataset':
         """
@@ -528,7 +587,9 @@ def get_sequencing_groups(self, only_active: bool = True) -> list['SequencingGro
"""
Get dataset's sequencing groups. Include only "active" sequencing groups, unless only_active=False
"""
return [s for sid, s in self._sequencing_group_by_id.items() if (s.active or not only_active)]
if not self.cohort:
return [s for sid, s in self._sequencing_group_by_id.items() if (s.active or not only_active)]
return [s for sid, s in self._sequencing_group_by_id.items() if (s.active or not only_active) and sid in self.cohort.get_sequencing_group_ids()]

def get_job_attrs(self) -> dict:
"""
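One note on the Dataset.get_sequencing_groups change above: the comprehension's condition calls self.cohort.get_sequencing_group_ids() once per sequencing group, and each of those calls queries Metamist via get_sgs_by_project_from_cohort. A possible variant (a sketch only, not part of this PR) fetches the cohort's IDs once and keeps them in a set:

    def get_sequencing_groups(self, only_active: bool = True) -> list['SequencingGroup']:
        """
        Get dataset's sequencing groups. Include only "active" sequencing groups, unless only_active=False
        """
        if not self.cohort:
            return [s for sid, s in self._sequencing_group_by_id.items() if (s.active or not only_active)]
        # One Metamist round trip; membership tests against a set are O(1) and the
        # comprehension no longer re-fetches the cohort for every sequencing group.
        cohort_sg_ids = set(self.cohort.get_sequencing_group_ids())
        return [
            s
            for sid, s in self._sequencing_group_by_id.items()
            if (s.active or not only_active) and sid in cohort_sg_ids
        ]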