Add new search scheduler implementation for improved simplicity and robustness. (#240)
gibber9809 authored Jan 25, 2024
1 parent 5d6ff54 commit 284a558
Showing 25 changed files with 787 additions and 366 deletions.
128 changes: 16 additions & 112 deletions components/clp-package-utils/clp_package_utils/scripts/native/search.py
@@ -2,32 +2,29 @@

 import argparse
 import asyncio
-import datetime
 import logging
 import multiprocessing
 import pathlib
-import socket
 import sys
 import time
 from contextlib import closing

 import msgpack
 import pymongo
-import zstandard

 from clp_package_utils.general import (
     CLP_DEFAULT_CONFIG_FILE_RELATIVE_PATH,
     validate_and_load_config_file,
     get_clp_home
 )
 from clp_py_utils.clp_config import (
-    CLP_METADATA_TABLE_PREFIX,
+    SEARCH_JOBS_TABLE_NAME,
     Database,
     ResultsCache
 )
 from clp_py_utils.sql_adapter import SQL_Adapter
 from job_orchestration.job_config import SearchConfig
-from job_orchestration.scheduler.constants import JobStatus
+from job_orchestration.search_scheduler.common import JobStatus

 # Setup logging
 # Create logger
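
One import change worth calling out: JobStatus now comes from the new search scheduler's common module rather than the old scheduler's constants, and the polling loop later in this diff checks it against SUCCESS, FAILED, and CANCELLED. A hypothetical sketch of what such an enum might look like follows; the real definition lives in job_orchestration/search_scheduler/common.py and may differ, and the non-terminal members here are pure assumptions:

from enum import IntEnum, auto

class JobStatus(IntEnum):
    # Non-terminal states are assumptions; only SUCCESS, FAILED, and CANCELLED
    # are actually referenced by the client code in this diff.
    PENDING = 0
    RUNNING = auto()
    SUCCESS = auto()
    FAILED = auto()
    CANCELLED = auto()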
@@ -75,81 +72,31 @@ def process_error_callback(err):

 def create_and_monitor_job_in_db(db_config: Database, results_cache: ResultsCache,
                                  wildcard_query: str, begin_timestamp: int | None,
-                                 end_timestamp: int | None, path_filter: str,
-                                 search_controller_host: str, search_controller_port: int):
+                                 end_timestamp: int | None, path_filter: str):
     search_config = SearchConfig(
-        search_controller_host=search_controller_host,
-        search_controller_port=search_controller_port,
-        wildcard_query=wildcard_query,
+        query_string=wildcard_query,
         begin_timestamp=begin_timestamp,
         end_timestamp=end_timestamp,
         path_filter=path_filter
     )

     sql_adapter = SQL_Adapter(db_config)
-    zstd_cctx = zstandard.ZstdCompressor(level=3)
     with closing(sql_adapter.create_connection(True)) as \
             db_conn, closing(db_conn.cursor(dictionary=True)) as db_cursor:
         # Create job
-        db_cursor.execute(f"INSERT INTO `search_jobs` (`search_config`) VALUES (%s)",
-                          (zstd_cctx.compress(msgpack.packb(search_config.dict())),))
+        db_cursor.execute(f"INSERT INTO `{SEARCH_JOBS_TABLE_NAME}` (`search_config`) VALUES (%s)",
+                          (msgpack.packb(search_config.dict()),))
         db_conn.commit()
         job_id = db_cursor.lastrowid

-        next_pagination_id = 0
-        pagination_limit = 64
-        num_tasks_added = 0
-        query_base_conditions = []
-        if begin_timestamp is not None:
-            query_base_conditions.append(f"`end_timestamp` >= {begin_timestamp}")
-        if end_timestamp is not None:
-            query_base_conditions.append(f"`begin_timestamp` <= {end_timestamp}")
-        while True:
-            # Get next `limit` rows
-            query_conditions = query_base_conditions + [f"`pagination_id` >= {next_pagination_id}"]
-            query = f"""
-            SELECT `id` FROM {CLP_METADATA_TABLE_PREFIX}archives
-            WHERE {" AND ".join(query_conditions)}
-            LIMIT {pagination_limit}
-            """
-            db_cursor.execute(query)
-            rows = db_cursor.fetchall()
-            if len(rows) == 0:
-                break
-
-            # Insert tasks
-            db_cursor.execute(f"""
-            INSERT INTO `search_tasks` (`job_id`, `archive_id`, `scheduled_time`)
-            VALUES ({"), (".join(f"{job_id}, '{row['id']}', '{datetime.datetime.utcnow()}'" for row in rows)})
-            """)
-            db_conn.commit()
-            num_tasks_added += len(rows)
-
-            if len(rows) < pagination_limit:
-                # Less than limit rows returned, so there are no more rows
-                break
-            next_pagination_id += pagination_limit
-
-        # Mark job as scheduled
-        db_cursor.execute(f"""
-        UPDATE `search_jobs`
-        SET num_tasks={num_tasks_added}, status = '{JobStatus.SCHEDULED}'
-        WHERE id = {job_id}
-        """)
-        db_conn.commit()
-
         # Wait for the job to be marked complete
-        job_complete = False
-        while not job_complete:
-            db_cursor.execute(f"SELECT `status`, `status_msg` FROM `search_jobs` WHERE `id` = {job_id}")
+        while True:
+            db_cursor.execute(f"SELECT `status` FROM `{SEARCH_JOBS_TABLE_NAME}` WHERE `id` = {job_id}")
             # There will only ever be one row since it's impossible to have more than one job with the same ID
-            row = db_cursor.fetchall()[0]
-            if JobStatus.SUCCEEDED == row['status']:
-                job_complete = True
-            elif JobStatus.FAILED == row['status']:
-                logger.error(row['status_msg'])
-                job_complete = True
+            new_status = db_cursor.fetchall()[0]['status']
+            db_conn.commit()
+            if new_status in (JobStatus.SUCCESS, JobStatus.FAILED, JobStatus.CANCELLED):
+                break

             time.sleep(0.5)

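Note the db_conn.commit() the new polling loop issues after each SELECT: under MySQL's default REPEATABLE READ isolation, every SELECT inside one transaction reads from the same snapshot, so without ending the transaction the loop could keep seeing the job's status as it was on the first read. A minimal standalone sketch of the pattern; wait_for_terminal_status is a hypothetical helper name and the status strings are assumed stand-ins for the JobStatus values:

import time

TERMINAL_STATUSES = ("SUCCESS", "FAILED", "CANCELLED")  # assumed JobStatus values

def wait_for_terminal_status(db_conn, db_cursor, jobs_table: str, job_id: int) -> str:
    while True:
        db_cursor.execute(f"SELECT `status` FROM `{jobs_table}` WHERE `id` = %s", (job_id,))
        status = db_cursor.fetchall()[0]['status']
        # End the read transaction so the next iteration sees a fresh snapshot.
        db_conn.commit()
        if status in TERMINAL_STATUSES:
            return status
        time.sleep(0.5)
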
@@ -159,50 +106,17 @@ def create_and_monitor_job_in_db(db_config: Database, results_cache: ResultsCache,
         print(f"{document['original_path']}: {document['message']}", end='')


-async def worker_connection_handler(reader: asyncio.StreamReader, writer: asyncio.StreamWriter):
-    try:
-        buf = await reader.read(1024)
-        if b'' == buf:
-            # Worker closed
-            return
-    except asyncio.CancelledError:
-        return
-    finally:
-        writer.close()
-
-
 async def do_search(db_config: Database, results_cache: ResultsCache, wildcard_query: str,
-                    begin_timestamp: int | None, end_timestamp: int | None, path_filter: str, host: str):
-    # Start a server
-    try:
-        server = await asyncio.start_server(client_connected_cb=worker_connection_handler, host=host, port=0,
-                                            family=socket.AF_INET)
-    except asyncio.CancelledError:
-        # Search cancelled
-        return
-    port = server.sockets[0].getsockname()[1]
-
-    server_task = asyncio.ensure_future(server.serve_forever())
-
+                    begin_timestamp: int | None, end_timestamp: int | None, path_filter: str):
     db_monitor_task = asyncio.ensure_future(
         run_function_in_process(create_and_monitor_job_in_db, db_config, results_cache, wildcard_query,
-                                begin_timestamp, end_timestamp, path_filter, host, port))
+                                begin_timestamp, end_timestamp, path_filter))

     # Wait for the job to complete or an error to occur
-    pending = [server_task, db_monitor_task]
     try:
-        done, pending = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)
-        if db_monitor_task in done:
-            server.close()
-            await server.wait_closed()
-        else:
-            logger.error("server task unexpectedly returned")
-            db_monitor_task.cancel()
-            await db_monitor_task
-    except asyncio.CancelledError:
-        server.close()
-        await server.wait_closed()
         await db_monitor_task
+    except asyncio.CancelledError:
+        pass


 def main(argv):
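
With the worker-facing TCP server gone, do_search reduces to awaiting run_function_in_process, a helper defined earlier in this file but outside this hunk. As a rough sketch of what such a helper could look like, assuming it bridges a synchronous function into asyncio via a one-off worker process (the real implementation may instead use multiprocessing with the process_error_callback shown in an earlier hunk header):

import asyncio
from concurrent.futures import ProcessPoolExecutor

async def run_function_in_process(func, *args):
    # Run func(*args) in a separate process and await the result without
    # blocking the event loop.
    loop = asyncio.get_running_loop()
    with ProcessPoolExecutor(max_workers=1) as pool:
        return await loop.run_in_executor(pool, func, *args)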
@@ -237,18 +151,8 @@ def main(argv):
         logger.exception("Failed to load config.")
         return -1

-    # Get IP of local machine
-    host_ip = None
-    for ip in set(socket.gethostbyname_ex(socket.gethostname())[2]):
-        host_ip = ip
-        break
-    if host_ip is None:
-        logger.error("Could not determine IP of local machine.")
-        return -1
-
     asyncio.run(do_search(clp_config.database, clp_config.results_cache, parsed_args.wildcard_query,
-                          parsed_args.begin_time, parsed_args.end_time, parsed_args.file_path,
-                          host_ip))
+                          parsed_args.begin_time, parsed_args.end_time, parsed_args.file_path))

     return 0

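For context, the search results themselves never travel through MySQL: the print statement surviving at the top of the do_search hunk iterates documents from the results cache, which the pymongo import suggests is MongoDB. A hedged sketch of that read path; the collection-per-job-id layout and the URI handling are assumptions, while the 'original_path' and 'message' fields appear verbatim in the diff:

from pymongo import MongoClient

def print_search_results(results_cache_uri: str, job_id: int):
    # Read matched messages for one job from the results cache and print them
    # in the same format as the script above.
    client = MongoClient(results_cache_uri)
    collection = client.get_default_database()[str(job_id)]
    for document in collection.find():
        print(f"{document['original_path']}: {document['message']}", end='')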
