
Commit 440fa05

Add JIT meta caching
* separate meta operations from database.py into database_meta.py
* add tests
1 parent 3ed17a8 commit 440fa05

File tree

11 files changed (+3779, −294 lines)


integrations/acquisition/covidcast/test_covidcast_meta_caching.py

Lines changed: 4 additions & 4 deletions
@@ -10,10 +10,10 @@

 # first party
 from delphi_utils import Nans
-from delphi.epidata.client.delphi_epidata import Epidata
 import delphi.operations.secrets as secrets
-import delphi.epidata.acquisition.covidcast.database as live
-from delphi.epidata.acquisition.covidcast.covidcast_meta_cache_updater import main
+from ....src.client.delphi_epidata import Epidata
+from ....src.acquisition.covidcast.database_meta import DatabaseMeta
+from ....src.acquisition.covidcast.covidcast_meta_cache_updater import main

 # py3tester coverage target (equivalent to `import *`)
 __test_target__ = (
@@ -92,7 +92,7 @@ def test_caching(self):
     self.cnx.commit()

     # make sure the live utility is serving something sensible
-    cvc_database = live.Database()
+    cvc_database = DatabaseMeta()
     cvc_database.connect()
     epidata1 = cvc_database.compute_covidcast_meta()
     cvc_database.disconnect(False)
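`database_meta.py` itself is among the 11 changed files but is not reproduced in this excerpt. Judging from the test above and the code removed from `src/acquisition/covidcast/database.py` below, `DatabaseMeta` presumably extends `Database` and takes over the metadata operations. A rough, hypothetical sketch of its shape:

# Hypothetical sketch only; the real database_meta.py is part of this commit
# but not shown in this excerpt.
from delphi.epidata.acquisition.covidcast.database import Database

class DatabaseMeta(Database):
  """Covidcast database operations for computing and caching metadata."""

  def compute_covidcast_meta(self, table_name=None):
    """Compute and return metadata on all COVIDcast signals."""
    raise NotImplementedError  # moved out of database.py by this commit

  def update_covidcast_meta_cache(self, metadata):
    """Update the `covidcast_meta_cache` table."""
    raise NotImplementedError  # moved out of database.py by this commit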

integrations/acquisition/covidcast/test_db.py

Lines changed: 2 additions & 1 deletion
@@ -1,9 +1,10 @@
 import unittest

 from delphi_utils import Nans
-from delphi.epidata.acquisition.covidcast.database import Database, CovidcastRow
 import delphi.operations.secrets as secrets

+from ....src.acquisition.covidcast.database import Database, CovidcastRow
+
 # all the Nans we use here are just one value, so this is a shortcut to it:
 nmv = Nans.NOT_MISSING.value
Lines changed: 3 additions & 0 deletions
@@ -0,0 +1,3 @@
+# TODO: Fill these in.
+GEO_TYPES = ["county", "state", "hhs", "msa", "nation", "hrr"]
+ALL_TIME = "19000101-20500101"
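The path of this new file is not captured in this excerpt. `ALL_TIME` uses the `YYYYMMDD-YYYYMMDD` range format that the Epidata client accepts for `time_values`, so these constants presumably enumerate the query space for just-in-time meta computation. As an illustration only (the source and signal names are made up):

# Illustration with made-up source/signal names; GEO_TYPES and ALL_TIME are
# the constants defined above.
from delphi.epidata.client.delphi_epidata import Epidata

for geo_type in GEO_TYPES:
  response = Epidata.covidcast(
    'example-source', 'example-signal',  # hypothetical names
    time_type='day',
    geo_type=geo_type,
    time_values=ALL_TIME,  # "19000101-20500101" spans all plausible dates
    geo_value='*',  # every location of this geo type
  )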

src/acquisition/covidcast/covidcast_meta_cache_updater.py

Lines changed: 2 additions & 2 deletions
@@ -6,7 +6,7 @@
 import time

 # first party
-from delphi.epidata.acquisition.covidcast.database import Database
+from .database_meta import DatabaseMeta
 from delphi.epidata.acquisition.covidcast.logger import get_structured_logger
 from delphi.epidata.client.delphi_epidata import Epidata

@@ -18,7 +18,7 @@ def get_argument_parser():
   return parser


-def main(args, epidata_impl=Epidata, database_impl=Database):
+def main(args, epidata_impl: Epidata = Epidata, database_impl: DatabaseMeta = DatabaseMeta):
   """Update the covidcast metadata cache.

   `args`: parsed command-line arguments
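The annotations on `main` document its dependency-injection seam: tests can pass stand-in classes for `Epidata` and `DatabaseMeta`. A minimal, hypothetical sketch of that pattern (not a test from this commit):

# Hypothetical test double; main() is expected to instantiate database_impl()
# itself, so injecting a mock class keeps the flow testable without MySQL.
from unittest.mock import MagicMock

fake_db_class = MagicMock()
fake_db_class.return_value.compute_covidcast_meta.return_value = []

# main(get_argument_parser().parse_args([]), database_impl=fake_db_class)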

src/acquisition/covidcast/database.py

Lines changed: 26 additions & 196 deletions
@@ -4,61 +4,14 @@
 """

 # third party
-import json
+from typing import Iterable, Sequence
 import mysql.connector
-import numpy as np
 from math import ceil

-from queue import Queue, Empty
-import threading
-from multiprocessing import cpu_count
-
-# first party
 import delphi.operations.secrets as secrets

-from delphi.epidata.acquisition.covidcast.logger import get_structured_logger
-
-class CovidcastRow():
-  """A container for all the values of a single covidcast row."""
-
-  @staticmethod
-  def fromCsvRowValue(row_value, source, signal, time_type, geo_type, time_value, issue, lag):
-    if row_value is None: return None
-    return CovidcastRow(source, signal, time_type, geo_type, time_value,
-                        row_value.geo_value,
-                        row_value.value,
-                        row_value.stderr,
-                        row_value.sample_size,
-                        row_value.missing_value,
-                        row_value.missing_stderr,
-                        row_value.missing_sample_size,
-                        issue, lag)
-
-  @staticmethod
-  def fromCsvRows(row_values, source, signal, time_type, geo_type, time_value, issue, lag):
-    # NOTE: returns a generator, as row_values is expected to be a generator
-    return (CovidcastRow.fromCsvRowValue(row_value, source, signal, time_type, geo_type, time_value, issue, lag)
-            for row_value in row_values)
-
-  def __init__(self, source, signal, time_type, geo_type, time_value, geo_value, value, stderr,
-               sample_size, missing_value, missing_stderr, missing_sample_size, issue, lag):
-    self.id = None
-    self.source = source
-    self.signal = signal
-    self.time_type = time_type
-    self.geo_type = geo_type
-    self.time_value = time_value
-    self.geo_value = geo_value  # from CSV row
-    self.value = value  # ...
-    self.stderr = stderr  # ...
-    self.sample_size = sample_size  # ...
-    self.missing_value = missing_value  # ...
-    self.missing_stderr = missing_stderr  # ...
-    self.missing_sample_size = missing_sample_size  # from CSV row
-    self.direction_updated_timestamp = 0
-    self.direction = None
-    self.issue = issue
-    self.lag = lag
+from .logger import get_structured_logger
+from .covidcast_row import CovidcastRow


 # constants for the codes used in the `process_status` column of `signal_load`
@@ -72,25 +25,32 @@ class _PROCESS_STATUS(object):
 class Database:
   """A collection of covidcast database operations."""

-  DATABASE_NAME = 'covid'
-
-  load_table = "signal_load"
-  latest_table = "signal_latest" # NOTE: careful! probably want to use variable `latest_view` instead for semantics purposes
-  latest_view = latest_table + "_v"
-  history_table = "signal_history" # NOTE: careful! probably want to use variable `history_view` instead for semantics purposes
-  history_view = history_table + "_v"
+  def __init__(self):
+    self.load_table = "signal_load"
+    self.latest_table = "signal_latest" # NOTE: careful! probably want to use variable `latest_view` instead for semantics purposes
+    self.latest_view = self.latest_table + "_v"
+    self.history_table = "signal_history" # NOTE: careful! probably want to use variable `history_view` instead for semantics purposes
+    self.history_view = self.history_table + "_v"

+    self._connector_impl = mysql.connector
+    self._db_credential_user, self._db_credential_password = secrets.db.epi
+    self._db_host = secrets.db.host
+    self._db_database = 'covid'

-  def connect(self, connector_impl=mysql.connector):
+  def connect(self, connector_impl=None, host=None, user=None, password=None, database=None):
     """Establish a connection to the database."""
+    self._connector_impl = connector_impl if connector_impl is not None else self._connector_impl
+    self._db_host = host if host is not None else self._db_host
+    self._db_credential_user = user if user is not None else self._db_credential_user
+    self._db_credential_password = password if password is not None else self._db_credential_password
+    self._db_database = database if database is not None else self._db_database

-    u, p = secrets.db.epi
-    self._connector_impl = connector_impl
     self._connection = self._connector_impl.connect(
-      host=secrets.db.host,
-      user=u,
-      password=p,
-      database=Database.DATABASE_NAME)
+      host=self._db_host,
+      user=self._db_credential_user,
+      password=self._db_credential_password,
+      database=self._db_database
+    )
     self._cursor = self._connection.cursor()

   def commit(self):
@@ -110,7 +70,6 @@ def disconnect(self, commit):
       self._connection.commit()
     self._connection.close()

-
   def count_all_rows(self, tablename=None):
     """Return the total number of rows in table `covidcast`."""

@@ -134,11 +93,10 @@ def count_insertstatus_rows(self):
     for (num,) in self._cursor:
       return num

-
-  def insert_or_update_bulk(self, cc_rows):
+  def insert_or_update_bulk(self, cc_rows: Iterable[CovidcastRow]):
     return self.insert_or_update_batch(cc_rows)

-  def insert_or_update_batch(self, cc_rows, batch_size=2**20, commit_partial=False):
+  def insert_or_update_batch(self, cc_rows: Sequence[CovidcastRow], batch_size: int = 2**20, commit_partial: bool = False):
     """
     Insert new rows (or update existing) into the load table.
     Data inserted this way will not be available to clients until the appropriate steps from src/dbjobs/ have run
@@ -476,131 +434,3 @@ def split_list(lst, n):
     finally:
       self._cursor.execute(drop_tmp_table_sql)
     return total
-
-
-  def compute_covidcast_meta(self, table_name=None):
-    """Compute and return metadata on all COVIDcast signals."""
-    logger = get_structured_logger("compute_covidcast_meta")
-
-    if table_name is None:
-      table_name = self.latest_view
-
-    n_threads = max(1, cpu_count()*9//10) # aka number of concurrent db connections, which [sh|c]ould be ~<= 90% of the #cores available to SQL server
-    # NOTE: this may present a small problem if this job runs on different hardware than the db,
-    # but we should not run into that issue in prod.
-    logger.info(f"using {n_threads} workers")
-
-    srcsigs = Queue() # multi-consumer threadsafe!
-    sql = f'SELECT `source`, `signal` FROM `{table_name}` GROUP BY `source`, `signal` ORDER BY `source` ASC, `signal` ASC;'
-    self._cursor.execute(sql)
-    for source, signal in self._cursor:
-      srcsigs.put((source, signal))
-
-    inner_sql = f'''
-      SELECT
-        `source` AS `data_source`,
-        `signal`,
-        `time_type`,
-        `geo_type`,
-        MIN(`time_value`) AS `min_time`,
-        MAX(`time_value`) AS `max_time`,
-        COUNT(DISTINCT `geo_value`) AS `num_locations`,
-        MIN(`value`) AS `min_value`,
-        MAX(`value`) AS `max_value`,
-        ROUND(AVG(`value`),7) AS `mean_value`,
-        ROUND(STD(`value`),7) AS `stdev_value`,
-        MAX(`value_updated_timestamp`) AS `last_update`,
-        MAX(`issue`) as `max_issue`,
-        MIN(`lag`) as `min_lag`,
-        MAX(`lag`) as `max_lag`
-      FROM
-        `{table_name}`
-      WHERE
-        `source` = %s AND
-        `signal` = %s
-      GROUP BY
-        `time_type`,
-        `geo_type`
-      ORDER BY
-        `time_type` ASC,
-        `geo_type` ASC
-      '''

-    meta = []
-    meta_lock = threading.Lock()
-
-    def worker():
-      name = threading.current_thread().name
-      logger.info("starting thread", thread=name)
-      # set up new db connection for thread
-      worker_dbc = Database()
-      worker_dbc.connect(connector_impl=self._connector_impl)
-      w_cursor = worker_dbc._cursor
-      try:
-        while True:
-          (source, signal) = srcsigs.get_nowait() # this will throw the Empty caught below
-          logger.info("starting pair", thread=name, pair=f"({source}, {signal})")
-          w_cursor.execute(inner_sql, (source, signal))
-          with meta_lock:
-            meta.extend(list(
-              dict(zip(w_cursor.column_names, x)) for x in w_cursor
-            ))
-          srcsigs.task_done()
-      except Empty:
-        logger.info("no jobs left, thread terminating", thread=name)
-      finally:
-        worker_dbc.disconnect(False) # cleanup
-
-    threads = []
-    for n in range(n_threads):
-      t = threading.Thread(target=worker, name='MetacacheThread-'+str(n))
-      t.start()
-      threads.append(t)
-
-    srcsigs.join()
-    logger.info("jobs complete")
-    for t in threads:
-      t.join()
-    logger.info("all threads terminated")
-
-    # sort the metadata because threaded workers dgaf
-    sorting_fields = "data_source signal time_type geo_type".split()
-    sortable_fields_fn = lambda x: [(field, x[field]) for field in sorting_fields]
-    prepended_sortables_fn = lambda x: sortable_fields_fn(x) + list(x.items())
-    tuple_representation = list(map(prepended_sortables_fn, meta))
-    tuple_representation.sort()
-    meta = list(map(dict, tuple_representation)) # back to dict form
-
-    return meta
-
-
-  def update_covidcast_meta_cache(self, metadata):
-    """Updates the `covidcast_meta_cache` table."""
-
-    sql = '''
-      UPDATE
-        `covidcast_meta_cache`
-      SET
-        `timestamp` = UNIX_TIMESTAMP(NOW()),
-        `epidata` = %s
-    '''
-    epidata_json = json.dumps(metadata)
-
-    self._cursor.execute(sql, (epidata_json,))
-
-  def retrieve_covidcast_meta_cache(self):
-    """Useful for viewing cache entries (was used in debugging)"""
-
-    sql = '''
-      SELECT `epidata`
-      FROM `covidcast_meta_cache`
-      ORDER BY `timestamp` DESC
-      LIMIT 1;
-    '''
-    self._cursor.execute(sql)
-    cache_json = self._cursor.fetchone()[0]
-    cache = json.loads(cache_json)
-    cache_hash = {}
-    for entry in cache:
-      cache_hash[(entry['data_source'], entry['signal'], entry['time_type'], entry['geo_type'])] = entry
-    return cache_hash
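With connection details now stored on the instance and overridable per call, `connect()` lets a test or ad-hoc script point a `Database` at a non-default server without patching `secrets`. A small usage sketch (the host and credentials here are placeholders):

# Placeholder connection details; any keyword argument omitted here falls back
# to the defaults captured from delphi.operations.secrets in __init__.
db = Database()
db.connect(host='127.0.0.1', user='test_user', password='test_pass', database='covid')
try:
  print(db.count_all_rows())  # total rows, per the method's docstring
finally:
  db.disconnect(False)  # False: close without committing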
