Skip to content

Commit a41ce70

Browse files
committed
move save data logic to shared helper
1 parent f670037 commit a41ce70

File tree

3 files changed

+48
-48
lines changed

3 files changed

+48
-48
lines changed

seqr/views/apis/data_manager_api.py

Lines changed: 3 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
import base64
22
from collections import defaultdict
33
from datetime import datetime
4-
import gzip
54
import json
65
import os
76
import requests
@@ -26,7 +25,7 @@
2625
from seqr.views.utils.airtable_utils import AirtableSession, LOADABLE_PDO_STATUSES, AVAILABLE_PDO_STATUS
2726
from seqr.views.utils.dataset_utils import load_rna_seq, load_phenotype_prioritization_data_file, RNA_DATA_TYPE_CONFIGS, \
2827
post_process_rna_data, convert_django_meta_to_http_headers
29-
from seqr.views.utils.file_utils import get_temp_file_path, load_uploaded_file, persist_temp_file
28+
from seqr.views.utils.file_utils import get_temp_file_path, load_uploaded_file
3029
from seqr.views.utils.json_utils import create_json_response
3130
from seqr.views.utils.json_to_orm_utils import update_model_from_json
3231
from seqr.views.utils.pedigree_info_utils import get_validated_related_individuals, JsonConstants
@@ -72,37 +71,15 @@ def update_rna_seq(request):
7271
if uploaded_mapping_file_id:
7372
mapping_file = load_uploaded_file(uploaded_mapping_file_id)
7473

75-
file_name_prefix = f'rna_sample_data__{data_type}__{datetime.now().isoformat()}'
76-
file_dir = get_temp_file_path(file_name_prefix, is_local=True)
77-
os.mkdir(file_dir)
78-
79-
sample_files = {}
80-
81-
def _save_sample_data(sample_key, sample_data):
82-
if sample_key not in sample_files:
83-
file_name = _get_sample_file_path(file_dir, '_'.join(sample_key))
84-
sample_files[sample_key] = gzip.open(file_name, 'at')
85-
sample_files[sample_key].write(f'{json.dumps(sample_data)}\n')
86-
8774
try:
88-
sample_guids_to_keys, info, warnings = load_rna_seq(
89-
data_type, file_path, _save_sample_data,
75+
sample_guids_to_keys, file_name_prefix, info, warnings = load_rna_seq(
76+
data_type, file_path,
9077
user=request.user, mapping_file=mapping_file, ignore_extra_samples=request_json.get('ignoreExtraSamples'))
9178
except FileNotFoundError:
9279
return create_json_response({'error': 'File not found: {}'.format(file_path)}, status=400)
9380
except ValueError as e:
9481
return create_json_response({'error': str(e)}, status=400)
9582

96-
for sample_guid, sample_key in sample_guids_to_keys.items():
97-
sample_files[sample_key].close() # Required to ensure gzipped files are properly terminated
98-
os.rename(
99-
_get_sample_file_path(file_dir, '_'.join(sample_key)),
100-
_get_sample_file_path(file_dir, sample_guid),
101-
)
102-
103-
if sample_guids_to_keys:
104-
persist_temp_file(file_name_prefix, request.user)
105-
10683
return create_json_response({
10784
'info': info,
10885
'warnings': warnings,

seqr/views/apis/data_manager_api_tests.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -793,12 +793,12 @@ def test_update_rna_splice_outlier(self, *args, **kwargs):
793793
@mock.patch('seqr.views.utils.file_utils.tempfile.gettempdir', lambda: 'tmp/')
794794
@mock.patch('seqr.utils.communication_utils.send_html_email')
795795
@mock.patch('seqr.utils.communication_utils.safe_post_to_slack')
796-
@mock.patch('seqr.views.apis.data_manager_api.datetime')
797-
@mock.patch('seqr.views.apis.data_manager_api.os.mkdir')
798-
@mock.patch('seqr.views.apis.data_manager_api.os.rename')
796+
@mock.patch('seqr.views.utils.dataset_utils.datetime')
797+
@mock.patch('seqr.views.utils.dataset_utils.os.mkdir')
798+
@mock.patch('seqr.views.utils.dataset_utils.os.rename')
799799
@mock.patch('seqr.views.apis.data_manager_api.load_uploaded_file')
800800
@mock.patch('seqr.utils.file_utils.subprocess.Popen')
801-
@mock.patch('seqr.views.apis.data_manager_api.gzip.open')
801+
@mock.patch('seqr.views.utils.dataset_utils.gzip.open')
802802
def _test_update_rna_seq(self, data_type, mock_open, mock_subprocess, mock_load_uploaded_file,
803803
mock_rename, mock_mkdir, mock_datetime, mock_send_slack, mock_send_email):
804804
url = reverse(update_rna_seq)
@@ -906,7 +906,7 @@ def _test_basic_data_loading(data, num_parsed_samples, num_loaded_samples, new_s
906906
f'Attempted data loading for {num_loaded_samples} RNA-seq samples in the following {num_projects}'
907907
f' projects: {project_names}'
908908
]
909-
file_name = RNA_FILENAME_TEMPLATE.format(data_type)
909+
file_name = RNA_FILENAME_TEMPLATE.format(params['data_type'])
910910
response_json = response.json()
911911
self.assertDictEqual(response_json, {'info': info, 'warnings': warnings or [], 'sampleGuids': mock.ANY,
912912
'fileName': file_name})
@@ -974,7 +974,7 @@ def _test_basic_data_loading(data, num_parsed_samples, num_loaded_samples, new_s
974974
self.assertSetEqual(set(response_json['sampleGuids']), {sample_guid, new_sample_guid})
975975

976976
# test correct file interactions
977-
file_path = RNA_FILENAME_TEMPLATE.format(data_type)
977+
file_path = RNA_FILENAME_TEMPLATE.format(params['data_type'])
978978
expected_subprocess_calls = [
979979
f'gsutil ls {RNA_FILE_ID}',
980980
f'gsutil cat {RNA_FILE_ID} | gunzip -c -q - ',
@@ -1027,7 +1027,7 @@ def _test_basic_data_loading(data, num_parsed_samples, num_loaded_samples, new_s
10271027
self.assertTrue(second_tissue_sample_guid != new_sample_guid)
10281028
self.assertTrue(second_tissue_sample_guid in response_json['sampleGuids'])
10291029
self._assert_expected_file_open(mock_rename, mock_open, [
1030-
f'tmp/temp_uploads/{RNA_FILENAME_TEMPLATE.format(data_type)}/{sample_guid}.json.gz'
1030+
f'tmp/temp_uploads/{RNA_FILENAME_TEMPLATE.format(params["data_type"])}/{sample_guid}.json.gz'
10311031
for sample_guid in response_json['sampleGuids']
10321032
])
10331033
self.assertSetEqual(

seqr/views/utils/dataset_utils.py

Lines changed: 38 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,11 @@
11
from collections import defaultdict
2+
from datetime import datetime
23
from django.contrib.postgres.aggregates import ArrayAgg
34
from django.db.models import F, Q
45
from django.utils import timezone
6+
import gzip
7+
import json
8+
import os
59
from tqdm import tqdm
610

711
from seqr.models import Sample, Individual, Family, Project, RnaSample, RnaSeqOutlier, RnaSeqTpm, RnaSeqSpliceOutlier
@@ -10,7 +14,7 @@
1014
from seqr.utils.logging_utils import SeqrLogger
1115
from seqr.utils.middleware import ErrorsWarningsException
1216
from seqr.utils.xpos_utils import format_chrom
13-
from seqr.views.utils.file_utils import parse_file
17+
from seqr.views.utils.file_utils import parse_file, get_temp_file_path, persist_temp_file
1418
from seqr.views.utils.permissions_utils import get_internal_projects
1519
from seqr.views.utils.json_utils import _to_snake_case, _to_camel_case
1620
from reference_data.models import GeneInfo
@@ -321,12 +325,6 @@ def _get_splice_id(row):
321325
}
322326

323327

324-
# TODO
325-
def load_rna_seq(data_type, *args, **kwargs):
326-
config = RNA_DATA_TYPE_CONFIGS[data_type]
327-
return _load_rna_seq(config['model_class'], config['data_type'], *args, config['columns'], **config['additional_kwargs'], **kwargs)
328-
329-
330328
def _validate_rna_header(header, column_map):
331329
required_column_map = {
332330
column_map.get(col, col): col for col in [SAMPLE_ID_COL, PROJECT_COL, GENE_ID_COL, TISSUE_COL]
@@ -340,7 +338,7 @@ def _validate_rna_header(header, column_map):
340338

341339

342340
def _load_rna_seq_file(
343-
file_path, data_source, user, data_type, model_cls, potential_samples, save_data, individual_data_by_key,
341+
file_path, data_source, user, data_type, model_cls, potential_samples, sample_files, file_dir, individual_data_by_key,
344342
column_map, mapping_file=None, allow_missing_gene=False, ignore_extra_samples=False,
345343
):
346344
sample_id_to_individual_id_mapping = {}
@@ -364,7 +362,7 @@ def _load_rna_seq_file(
364362
_parse_rna_row(
365363
dict(zip(header, line)), column_map, required_column_map, missing_required_fields,
366364
sample_id_to_individual_id_mapping, potential_samples, loaded_samples, gene_ids, sample_guid_keys_to_load,
367-
samples_to_create, unmatched_samples, individual_data_by_key, save_data, ignore_extra_samples,
365+
samples_to_create, unmatched_samples, individual_data_by_key, sample_files, file_dir, ignore_extra_samples,
368366
)
369367

370368
errors, warnings = _process_rna_errors(
@@ -384,7 +382,7 @@ def _load_rna_seq_file(
384382

385383
def _parse_rna_row(row, column_map, required_column_map, missing_required_fields, sample_id_to_individual_id_mapping,
386384
potential_samples, loaded_samples, gene_ids, sample_guid_keys_to_load, samples_to_create,
387-
unmatched_samples, individual_data_by_key, save_data, ignore_extra_samples):
385+
unmatched_samples, individual_data_by_key, sample_files, file_dir, ignore_extra_samples):
388386
row_dict = {mapped_key: row[col] for mapped_key, col in column_map.items()}
389387

390388
missing_cols = {col_id for col, col_id in required_column_map.items() if not row.get(col)}
@@ -424,7 +422,14 @@ def _parse_rna_row(row, column_map, required_column_map, missing_required_fields
424422

425423
for gene_id in row_gene_ids:
426424
row_dict = {**row_dict, GENE_ID_COL: gene_id}
427-
save_data(sample_key, row_dict)
425+
if sample_key not in sample_files:
426+
file_name = _get_sample_file_path(file_dir, '_'.join(sample_key))
427+
sample_files[sample_key] = gzip.open(file_name, 'at')
428+
sample_files[sample_key].write(f'{json.dumps(row_dict)}\n')
429+
430+
431+
def _get_sample_file_path(file_dir, sample_guid):
432+
return os.path.join(file_dir, f'{sample_guid}.json.gz')
428433

429434

430435
def _process_rna_errors(gene_ids, missing_required_fields, unmatched_samples, ignore_extra_samples, loaded_samples):
@@ -492,7 +497,10 @@ def _match_new_sample(sample_key, samples_to_create, unmatched_samples, individu
492497
unmatched_samples.add(sample_key)
493498

494499

495-
def _load_rna_seq(model_cls, data_type, file_path, save_data, *args, user=None, **kwargs):
500+
def load_rna_seq(data_type, file_path, user, **kwargs):
501+
config = RNA_DATA_TYPE_CONFIGS[data_type]
502+
data_type = config['data_type']
503+
model_cls = config['model_class']
496504
projects = get_internal_projects()
497505
data_source = file_path.split('/')[-1].split('_-_')[-1]
498506

@@ -503,8 +511,14 @@ def _load_rna_seq(model_cls, data_type, file_path, save_data, *args, user=None,
503511
)
504512
individual_data_by_key = _get_individuals_by_key(projects)
505513

514+
sample_files = {}
515+
file_name_prefix = f'rna_sample_data__{data_type}__{datetime.now().isoformat()}'
516+
file_dir = get_temp_file_path(file_name_prefix, is_local=True)
517+
os.mkdir(file_dir)
518+
506519
warnings, not_loaded_count, sample_guid_keys_to_load, prev_loaded_individual_ids = _load_rna_seq_file(
507-
file_path, data_source, user, data_type, model_cls, potential_samples, save_data, individual_data_by_key, *args, **kwargs)
520+
file_path, data_source, user, data_type, model_cls, potential_samples, sample_files, file_dir, individual_data_by_key,
521+
config['columns'], **config['additional_kwargs'], **kwargs)
508522
message = f'Parsed {len(sample_guid_keys_to_load) + not_loaded_count} RNA-seq samples'
509523
info = [message]
510524
logger.info(message, user)
@@ -524,10 +538,19 @@ def _load_rna_seq(model_cls, data_type, file_path, save_data, *args, user=None,
524538
for warning in warnings:
525539
logger.warning(warning, user)
526540

527-
return sample_guid_keys_to_load, info, warnings
541+
for sample_guid, sample_key in sample_guid_keys_to_load.items():
542+
sample_files[sample_key].close() # Required to ensure gzipped files are properly terminated
543+
os.rename(
544+
_get_sample_file_path(file_dir, '_'.join(sample_key)),
545+
_get_sample_file_path(file_dir, sample_guid),
546+
)
547+
548+
if sample_guid_keys_to_load:
549+
persist_temp_file(file_name_prefix, user)
550+
551+
return sample_guid_keys_to_load, file_name_prefix, info, warnings
528552

529553

530-
# TODO
531554
def post_process_rna_data(sample_guid, data, get_unique_key=None, format_fields=None):
532555
mismatches = set()
533556
invalid_format_fields = defaultdict(set)

0 commit comments

Comments
 (0)