Update init scripts to use new data refresh #4962

Merged · 11 commits · Oct 8, 2024
5 changes: 4 additions & 1 deletion .github/workflows/ci_cd.yml
@@ -457,7 +457,7 @@ jobs:
uses: ./.github/actions/load-img
with:
run_id: ${{ github.run_id }}
setup_images: upstream_db ingestion_server
setup_images: upstream_db ingestion_server catalog

# Sets build args specifying versions needed to build Docker image.
- name: Prepare build args
@@ -486,6 +486,9 @@ jobs:
API_PY_VERSION=${{ steps.prepare-build-args.outputs.api_py_version }}
PDM_INSTALL_ARGS=--dev

- name: Start Catalog
run: just catalog/up

- name: Start API, ingest and index test data
run: just api/init

29 changes: 21 additions & 8 deletions catalog/dags/common/sensors/single_run_external_dags_sensor.py
@@ -21,19 +21,25 @@ class SingleRunExternalDAGsSensor(BaseSensorOperator):
:param external_dag_ids: A list of dag_ids that you want to wait for
:param check_existence: Set to `True` to check if the external DAGs exist,
and immediately cease waiting if not (default value: False).
:param allow_concurrent_runs: Set to `True` to force the sensor to pass,
even if there are concurrent runs of the external DAGs (default value: False).
"""

def __init__(
self,
*,
external_dag_ids: Iterable[str],
check_existence: bool = False,
allow_concurrent_runs: bool = False,
**kwargs,
):
super().__init__(**kwargs)
self.external_dag_ids = external_dag_ids
self.check_existence = check_existence
self._has_checked_existence = False
self.allow_concurrent_runs = allow_concurrent_runs

# Used to ensure some checks are only evaluated on the first poke
self._has_checked_params = False

@provide_session
def poke(self, context, session=None):
@@ -42,19 +42,27 @@ def poke(self, context, session=None):
self.external_dag_ids,
)

if self.check_existence:
self._check_for_existence(session=session)
if not self._has_checked_params:
if self.allow_concurrent_runs:
self.log.info(
"`allow_concurrent_runs` is enabled. Returning without"
" checking for running DAGs."
)
return True

if self.check_existence:
self._check_for_existence(session=session)

# Only check DAG existence and `allow_concurrent_runs`
# on the first execution.
self._has_checked_params = True

count_running = self.get_count(session)

self.log.info("%s DAGs are in the running state", count_running)
return count_running == 0

def _check_for_existence(self, session) -> None:
# Check DAG existence only once, on the first execution.
if self._has_checked_existence:
return

for dag_id in self.external_dag_ids:
dag_to_wait = (
session.query(DagModel).filter(DagModel.dag_id == dag_id).first()
@@ -72,7 +86,6 @@ def _check_for_existence(self, session) -> None:
f"The external DAG {dag_id} does not have a task "
f"with id {self.task_id}."
)
self._has_checked_existence = True

def get_count(self, session) -> int:
# Get the count of running DAGs. A DAG is considered 'running' if
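For context, here is a hedged sketch of how the new `allow_concurrent_runs` flag might be used when instantiating the sensor in a DAG. The import path, task id, and DAG ids are assumptions for illustration, not taken from this PR:

```python
# Minimal usage sketch; "other_dag_a"/"other_dag_b" and the import path are hypothetical.
from common.sensors.single_run_external_dags_sensor import SingleRunExternalDAGsSensor

wait_for_external_dags = SingleRunExternalDAGsSensor(
    task_id="wait_for_external_dags",
    external_dag_ids=["other_dag_a", "other_dag_b"],  # hypothetical DAG ids
    check_existence=True,
    # With allow_concurrent_runs=True the sensor logs a notice and returns True
    # on the first poke, without counting running DAGs.
    allow_concurrent_runs=True,
)
```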
18 changes: 16 additions & 2 deletions catalog/dags/data_refresh/copy_data.py
@@ -40,28 +40,35 @@
def initialize_fdw(
upstream_conn_id: str,
downstream_conn_id: str,
media_type: str,
task: AbstractOperator = None,
):
"""Create the FDW and prepare it for copying."""
upstream_connection = Connection.get_connection_from_secrets(upstream_conn_id)
fdw_name = f"upstream_{media_type}"

run_sql.function(
postgres_conn_id=downstream_conn_id,
sql_template=queries.CREATE_FDW_QUERY,
task=task,
fdw_name=fdw_name,
host=upstream_connection.host,
port=upstream_connection.port,
dbname=upstream_connection.schema,
user=upstream_connection.login,
password=upstream_connection.password,
)

return fdw_name


@task(
max_active_tis_per_dagrun=1,
map_index_template="{{ task.op_kwargs['upstream_table_name'] }}",
)
def create_schema(downstream_conn_id: str, upstream_table_name: str) -> str:
def create_schema(
downstream_conn_id: str, upstream_table_name: str, fdw_name: str
) -> str:
"""
Create a new schema in the downstream DB through which the upstream table
can be accessed. Returns the schema name.
@@ -73,7 +80,9 @@ def create_schema(downstream_conn_id: str, upstream_table_name: str) -> str:
schema_name = f"upstream_{upstream_table_name}_schema"
downstream_pg.run(
queries.CREATE_SCHEMA_QUERY.format(
schema_name=schema_name, upstream_table_name=upstream_table_name
fdw_name=fdw_name,
schema_name=schema_name,
upstream_table_name=upstream_table_name,
)
)
return schema_name
@@ -183,6 +192,7 @@ def copy_data(
def copy_upstream_table(
upstream_conn_id: str,
downstream_conn_id: str,
fdw_name: str,
timeout: timedelta,
limit: int,
upstream_table_name: str,
@@ -206,6 +216,7 @@ def copy_upstream_table(
schema = create_schema(
downstream_conn_id=downstream_conn_id,
upstream_table_name=upstream_table_name,
fdw_name=fdw_name,
)

create_temp_table = run_sql.override(
@@ -286,6 +297,7 @@ def copy_upstream_tables(
init_fdw = initialize_fdw(
upstream_conn_id=upstream_conn_id,
downstream_conn_id=downstream_conn_id,
media_type=data_refresh_config.media_type,
)

limit = get_record_limit()
@@ -294,13 +306,15 @@
copy_tables = copy_upstream_table.partial(
upstream_conn_id=upstream_conn_id,
downstream_conn_id=downstream_conn_id,
fdw_name=init_fdw,
timeout=data_refresh_config.copy_data_timeout,
limit=limit,
).expand_kwargs([asdict(tm) for tm in data_refresh_config.table_mappings])

drop_fdw = run_sql.override(task_id="drop_fdw")(
postgres_conn_id=downstream_conn_id,
sql_template=queries.DROP_SERVER_QUERY,
fdw_name=init_fdw,
)

# Set up dependencies
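The thread running through this file is that `initialize_fdw` now returns a per-media-type FDW name, which downstream tasks receive as an XCom (`fdw_name=init_fdw`). Below is a minimal TaskFlow sketch of that pattern, with simplified task bodies and hypothetical values; only the naming scheme is taken from the diff above:

```python
from datetime import datetime

from airflow.decorators import dag, task


@task
def initialize_fdw(media_type: str) -> str:
    # Each media type gets its own FDW server name, so concurrent data
    # refreshes do not collide on a single shared "upstream" server.
    return f"upstream_{media_type}"


@task
def create_schema(upstream_table_name: str, fdw_name: str) -> str:
    # The real task interpolates fdw_name into the CREATE SCHEMA query;
    # here we only return the schema name used in the diff.
    return f"upstream_{upstream_table_name}_schema"


@dag(schedule=None, start_date=datetime(2024, 1, 1), catchup=False)
def fdw_name_flow():
    # Passing the task's return value threads the FDW name through XCom,
    # resolved at runtime, exactly like `fdw_name=init_fdw` in the diff.
    fdw_name = initialize_fdw(media_type="image")
    create_schema(upstream_table_name="image", fdw_name=fdw_name)


fdw_name_flow()
```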
13 changes: 3 additions & 10 deletions catalog/dags/data_refresh/create_and_populate_filtered_index.py
@@ -49,17 +49,13 @@ def create_and_populate_filtered_index(
es_host: str,
media_type: MediaType,
origin_index_name: str,
filtered_index_name: str,
timeout: timedelta,
destination_index_name: str | None = None,
):
"""
Create and populate a filtered index based on the given origin index, excluding
documents with sensitive terms.
"""
filtered_index_name = get_filtered_index_name(
media_type=media_type, destination_index_name=destination_index_name
)

create_filtered_index = es.create_index.override(
trigger_rule=TriggerRule.NONE_FAILED,
)(
@@ -76,7 +72,6 @@ def create_and_populate_filtered_index(
method="GET",
response_check=lambda response: response.status_code == 200,
response_filter=response_filter_sensitive_terms_endpoint,
trigger_rule=TriggerRule.NONE_FAILED,
)

populate_filtered_index = es.trigger_and_wait_for_reindex(
@@ -99,7 +94,5 @@

refresh_index = es.refresh_index(es_host=es_host, index_name=filtered_index_name)

sensitive_terms >> populate_filtered_index
create_filtered_index >> populate_filtered_index >> refresh_index

return filtered_index_name
# sensitive_terms >> populate_filtered_index
create_filtered_index >> sensitive_terms >> populate_filtered_index >> refresh_index
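The dependency change at the bottom serializes the sensitive-terms fetch after index creation instead of running it in parallel. A self-contained sketch of the new linear chain, with EmptyOperators standing in for the real tasks and a hypothetical DAG id:

```python
from datetime import datetime

from airflow.models import DAG
from airflow.operators.empty import EmptyOperator

with DAG(
    "filtered_index_chain_sketch",  # hypothetical DAG id, for illustration only
    schedule=None,
    start_date=datetime(2024, 1, 1),
    catchup=False,
):
    create_filtered_index = EmptyOperator(task_id="create_filtered_index")
    sensitive_terms = EmptyOperator(task_id="get_sensitive_terms")
    populate_filtered_index = EmptyOperator(task_id="populate_filtered_index")
    refresh_index = EmptyOperator(task_id="refresh_index")

    # The sensitive-terms fetch now runs only after the filtered index exists,
    # so a failed index creation short-circuits before the HTTP call is made.
    create_filtered_index >> sensitive_terms >> populate_filtered_index >> refresh_index
```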
44 changes: 0 additions & 44 deletions catalog/dags/data_refresh/create_index.py

This file was deleted.
