44"""
55
66# third party
7- import json
7+ from typing import Iterable , Sequence
88import mysql .connector
9- import numpy as np
109from math import ceil
1110
12- from queue import Queue , Empty
13- import threading
14- from multiprocessing import cpu_count
15-
16- # first party
1711import delphi .operations .secrets as secrets
1812
19- from delphi .epidata .acquisition .covidcast .logger import get_structured_logger
20-
21- class CovidcastRow ():
22- """A container for all the values of a single covidcast row."""
23-
24- @staticmethod
25- def fromCsvRowValue (row_value , source , signal , time_type , geo_type , time_value , issue , lag ):
26- if row_value is None : return None
27- return CovidcastRow (source , signal , time_type , geo_type , time_value ,
28- row_value .geo_value ,
29- row_value .value ,
30- row_value .stderr ,
31- row_value .sample_size ,
32- row_value .missing_value ,
33- row_value .missing_stderr ,
34- row_value .missing_sample_size ,
35- issue , lag )
36-
37- @staticmethod
38- def fromCsvRows (row_values , source , signal , time_type , geo_type , time_value , issue , lag ):
39- # NOTE: returns a generator, as row_values is expected to be a generator
40- return (CovidcastRow .fromCsvRowValue (row_value , source , signal , time_type , geo_type , time_value , issue , lag )
41- for row_value in row_values )
42-
43- def __init__ (self , source , signal , time_type , geo_type , time_value , geo_value , value , stderr ,
44- sample_size , missing_value , missing_stderr , missing_sample_size , issue , lag ):
45- self .id = None
46- self .source = source
47- self .signal = signal
48- self .time_type = time_type
49- self .geo_type = geo_type
50- self .time_value = time_value
51- self .geo_value = geo_value # from CSV row
52- self .value = value # ...
53- self .stderr = stderr # ...
54- self .sample_size = sample_size # ...
55- self .missing_value = missing_value # ...
56- self .missing_stderr = missing_stderr # ...
57- self .missing_sample_size = missing_sample_size # from CSV row
58- self .direction_updated_timestamp = 0
59- self .direction = None
60- self .issue = issue
61- self .lag = lag
13+ from .logger import get_structured_logger
14+ from .covidcast_row import CovidcastRow
6215
6316
6417# constants for the codes used in the `process_status` column of `signal_load`
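
For orientation: the `CovidcastRow` class deleted above now comes from the new `.covidcast_row` module imported in its place. A minimal sketch of what that module presumably exposes, reconstructed only from the fields of the deleted class (the real module in this PR may differ):

```python
# Hypothetical sketch of covidcast_row.py, inferred from the deleted class above;
# the actual module may carry extra fields or conversion helpers.
from dataclasses import dataclass
from typing import Optional

@dataclass
class CovidcastRow:
  """A container for all the values of a single covidcast row."""
  source: str
  signal: str
  time_type: str
  geo_type: str
  time_value: int
  geo_value: str
  value: Optional[float]
  stderr: Optional[float]
  sample_size: Optional[float]
  missing_value: int
  missing_stderr: int
  missing_sample_size: int
  issue: int
  lag: int
```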
@@ -72,25 +25,32 @@ class _PROCESS_STATUS(object):
 class Database:
   """A collection of covidcast database operations."""
 
-  DATABASE_NAME = 'covid'
-
-  load_table = "signal_load"
-  latest_table = "signal_latest"  # NOTE: careful! probably want to use variable `latest_view` instead for semantics purposes
-  latest_view = latest_table + "_v"
-  history_table = "signal_history"  # NOTE: careful! probably want to use variable `history_view` instead for semantics purposes
-  history_view = history_table + "_v"
+  def __init__(self):
+    self.load_table = "signal_load"
+    self.latest_table = "signal_latest"  # NOTE: careful! probably want to use variable `latest_view` instead for semantics purposes
+    self.latest_view = self.latest_table + "_v"
+    self.history_table = "signal_history"  # NOTE: careful! probably want to use variable `history_view` instead for semantics purposes
+    self.history_view = self.history_table + "_v"
 
+    self._connector_impl = mysql.connector
+    self._db_credential_user, self._db_credential_password = secrets.db.epi
+    self._db_host = secrets.db.host
+    self._db_database = 'covid'
 
-  def connect(self, connector_impl=mysql.connector):
+  def connect(self, connector_impl=None, host=None, user=None, password=None, database=None):
     """Establish a connection to the database."""
+    self._connector_impl = connector_impl if connector_impl is not None else self._connector_impl
+    self._db_host = host if host is not None else self._db_host
+    self._db_credential_user = user if user is not None else self._db_credential_user
+    self._db_credential_password = password if password is not None else self._db_credential_password
+    self._db_database = database if database is not None else self._db_database
 
-    u, p = secrets.db.epi
-    self._connector_impl = connector_impl
     self._connection = self._connector_impl.connect(
-      host=secrets.db.host,
-      user=u,
-      password=p,
-      database=Database.DATABASE_NAME)
+      host=self._db_host,
+      user=self._db_credential_user,
+      password=self._db_credential_password,
+      database=self._db_database
+    )
     self._cursor = self._connection.cursor()
 
   def commit(self):
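
The reworked `connect()` lets callers override any single connection parameter while the rest fall back to the defaults captured in `__init__`. A hedged usage sketch (the host, user, and database values here are illustrative, not from this PR):

```python
# Illustrative only: point a Database at a non-default server, e.g. for tests.
# Any argument left as None falls back to the value captured in __init__
# (secrets.db.epi credentials, secrets.db.host, and the 'covid' database).
db = Database()
db.connect(host="127.0.0.1", user="test_user", password="test_pw", database="covid_test")
print(db.count_all_rows())  # quick sanity check on the fresh connection
db.disconnect(False)
```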
@@ -110,7 +70,6 @@ def disconnect(self, commit):
       self._connection.commit()
     self._connection.close()
 
-
   def count_all_rows(self, tablename=None):
     """Return the total number of rows in table `covidcast`."""
 
@@ -134,11 +93,10 @@ def count_insertstatus_rows(self):
     for (num,) in self._cursor:
       return num
 
-
-  def insert_or_update_bulk(self, cc_rows):
+  def insert_or_update_bulk(self, cc_rows: Iterable[CovidcastRow]):
     return self.insert_or_update_batch(cc_rows)
 
-  def insert_or_update_batch(self, cc_rows, batch_size=2**20, commit_partial=False):
+  def insert_or_update_batch(self, cc_rows: Sequence[CovidcastRow], batch_size: int = 2**20, commit_partial: bool = False):
     """
     Insert new rows (or update existing) into the load table.
     Data inserted this way will not be available to clients until the appropriate steps from src/dbjobs/ have run
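
To make the new annotations concrete, here is a hedged example of feeding the batch writer; the `CovidcastRow` constructor arguments follow the field order of the class deleted above and are an assumption about the new module:

```python
# Sketch only: CovidcastRow argument order assumed from the deleted class above.
db = Database()
db.connect()
rows = [
    CovidcastRow("src-name", "sig-name", "day", "county", 20200601, "01001",
                 1.5, 0.1, 100.0,  # value, stderr, sample_size
                 0, 0, 0,          # missing_value, missing_stderr, missing_sample_size
                 20200601, 0),     # issue, lag
]
db.insert_or_update_batch(rows)  # lands in signal_load; the dbjobs steps promote it later
db.disconnect(True)
```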
@@ -476,131 +434,3 @@ def split_list(lst, n):
     finally:
       self._cursor.execute(drop_tmp_table_sql)
     return total
-
-
-  def compute_covidcast_meta(self, table_name=None):
-    """Compute and return metadata on all COVIDcast signals."""
-    logger = get_structured_logger("compute_covidcast_meta")
-
-    if table_name is None:
-      table_name = self.latest_view
-
-    n_threads = max(1, cpu_count()*9//10)  # aka number of concurrent db connections, which [sh|c]ould be ~<= 90% of the #cores available to SQL server
-    # NOTE: this may present a small problem if this job runs on different hardware than the db,
-    # but we should not run into that issue in prod.
-    logger.info(f"using {n_threads} workers")
-
-    srcsigs = Queue()  # multi-consumer threadsafe!
-    sql = f'SELECT `source`, `signal` FROM `{table_name}` GROUP BY `source`, `signal` ORDER BY `source` ASC, `signal` ASC;'
-    self._cursor.execute(sql)
-    for source, signal in self._cursor:
-      srcsigs.put((source, signal))
-
-    inner_sql = f'''
-      SELECT
-        `source` AS `data_source`,
-        `signal`,
-        `time_type`,
-        `geo_type`,
-        MIN(`time_value`) AS `min_time`,
-        MAX(`time_value`) AS `max_time`,
-        COUNT(DISTINCT `geo_value`) AS `num_locations`,
-        MIN(`value`) AS `min_value`,
-        MAX(`value`) AS `max_value`,
-        ROUND(AVG(`value`),7) AS `mean_value`,
-        ROUND(STD(`value`),7) AS `stdev_value`,
-        MAX(`value_updated_timestamp`) AS `last_update`,
-        MAX(`issue`) as `max_issue`,
-        MIN(`lag`) as `min_lag`,
-        MAX(`lag`) as `max_lag`
-      FROM
-        `{table_name}`
-      WHERE
-        `source` = %s AND
-        `signal` = %s
-      GROUP BY
-        `time_type`,
-        `geo_type`
-      ORDER BY
-        `time_type` ASC,
-        `geo_type` ASC
-    '''
-
-    meta = []
-    meta_lock = threading.Lock()
-
-    def worker():
-      name = threading.current_thread().name
-      logger.info("starting thread", thread=name)
-      # set up new db connection for thread
-      worker_dbc = Database()
-      worker_dbc.connect(connector_impl=self._connector_impl)
-      w_cursor = worker_dbc._cursor
-      try:
-        while True:
-          (source, signal) = srcsigs.get_nowait()  # this will throw the Empty caught below
-          logger.info("starting pair", thread=name, pair=f"({source}, {signal})")
-          w_cursor.execute(inner_sql, (source, signal))
-          with meta_lock:
-            meta.extend(list(
-              dict(zip(w_cursor.column_names, x)) for x in w_cursor
-            ))
-          srcsigs.task_done()
-      except Empty:
-        logger.info("no jobs left, thread terminating", thread=name)
-      finally:
-        worker_dbc.disconnect(False)  # cleanup
-
-    threads = []
-    for n in range(n_threads):
-      t = threading.Thread(target=worker, name='MetacacheThread-'+str(n))
-      t.start()
-      threads.append(t)
-
-    srcsigs.join()
-    logger.info("jobs complete")
-    for t in threads:
-      t.join()
-    logger.info("all threads terminated")
-
-    # sort the metadata because threaded workers dgaf
-    sorting_fields = "data_source signal time_type geo_type".split()
-    sortable_fields_fn = lambda x: [(field, x[field]) for field in sorting_fields]
-    prepended_sortables_fn = lambda x: sortable_fields_fn(x) + list(x.items())
-    tuple_representation = list(map(prepended_sortables_fn, meta))
-    tuple_representation.sort()
-    meta = list(map(dict, tuple_representation))  # back to dict form
-
-    return meta
-
-
-  def update_covidcast_meta_cache(self, metadata):
-    """Updates the `covidcast_meta_cache` table."""
-
-    sql = '''
-      UPDATE
-        `covidcast_meta_cache`
-      SET
-        `timestamp` = UNIX_TIMESTAMP(NOW()),
-        `epidata` = %s
-    '''
-    epidata_json = json.dumps(metadata)
-
-    self._cursor.execute(sql, (epidata_json,))
-
-  def retrieve_covidcast_meta_cache(self):
-    """Useful for viewing cache entries (was used in debugging)"""
-
-    sql = '''
-      SELECT `epidata`
-      FROM `covidcast_meta_cache`
-      ORDER BY `timestamp` DESC
-      LIMIT 1;
-    '''
-    self._cursor.execute(sql)
-    cache_json = self._cursor.fetchone()[0]
-    cache = json.loads(cache_json)
-    cache_hash = {}
-    for entry in cache:
-      cache_hash[(entry['data_source'], entry['signal'], entry['time_type'], entry['geo_type'])] = entry
-    return cache_hash
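
A side note on the sort inside the removed `compute_covidcast_meta`: since each `(data_source, signal, time_type, geo_type)` combination appears at most once in the grouped results, the hand-built "prepended sortables" amount to a keyed sort on those four fields. An equivalent, simpler formulation for reference (standalone sketch, not part of this diff):

```python
from operator import itemgetter

# Toy rows standing in for the worker output; real entries carry many more fields.
meta = [
    {"data_source": "b", "signal": "s", "time_type": "day", "geo_type": "state"},
    {"data_source": "a", "signal": "s", "time_type": "day", "geo_type": "county"},
]
# Orders by the same four identifying fields the removed code prepended as tuples.
meta.sort(key=itemgetter("data_source", "signal", "time_type", "geo_type"))
```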