From c681f71336f196e6a48d090001c945fe9fbacf25 Mon Sep 17 00:00:00 2001 From: Madhur Tandon Date: Wed, 17 Apr 2024 03:00:26 +0530 Subject: [PATCH] suggested improvements --- metaflow/subprocess_manager.py | 120 ++++++++++++++++++++------------- 1 file changed, 75 insertions(+), 45 deletions(-) diff --git a/metaflow/subprocess_manager.py b/metaflow/subprocess_manager.py index b18ab9dca7a..6193131f0b7 100644 --- a/metaflow/subprocess_manager.py +++ b/metaflow/subprocess_manager.py @@ -1,94 +1,112 @@ import os import sys -import time import signal import shutil -import hashlib import asyncio import tempfile -from typing import List - - -def hash_command_invocation(command: List[str]): - concatenated_string = "".join(command) - current_time = str(time.time()) - concatenated_string += current_time - hash_object = hashlib.sha256(concatenated_string.encode()) - return hash_object.hexdigest() +from typing import List, Dict, Optional, Callable class LogReadTimeoutError(Exception): + """Exception raised when reading logs times out.""" + pass class SubprocessManager(object): + """A manager for subprocesses.""" + def __init__(self): - self.commands = {} + self.commands: Dict[int, CommandManager] = {} - async def __aenter__(self): + async def __aenter__(self) -> "SubprocessManager": return self async def __aexit__(self, exc_type, exc_value, traceback): await self.cleanup() - async def run_command(self, command: List[str], env=None, cwd=None): - command_id = hash_command_invocation(command) - self.commands[command_id] = CommandManager(command, env, cwd) - await self.commands[command_id].run() - return command_id + async def run_command( + self, + command: List[str], + env: Optional[Dict[str, str]] = None, + cwd: Optional[str] = None, + ) -> int: + """Run a command asynchronously and return its process ID.""" - def get(self, command_id: str) -> "CommandManager": - return self.commands.get(command_id, None) + command_obj = CommandManager(command, env, cwd) + process = await command_obj.run() + self.commands[process.pid] = command_obj + return process.pid - async def cleanup(self): - for _, v in self.commands.items(): + def get(self, pid: int) -> "CommandManager": + """Get the CommandManager object for a given process ID.""" + + return self.commands.get(pid, None) + + async def cleanup(self) -> None: + """Clean up log files for all running subprocesses.""" + + for v in self.commands.values(): await v.cleanup() class CommandManager(object): - def __init__(self, command: List[str], env=None, cwd=None): - self.command = command + """A manager for an individual subprocess.""" - if env is None: - env = os.environ.copy() - self.env = env + def __init__( + self, + command: List[str], + env: Optional[Dict[str, str]] = None, + cwd: Optional[str] = None, + ): + self.command = command - if cwd is None: - cwd = os.getcwd() - self.cwd = cwd + self.env = env if env is not None else os.environ.copy() + self.cwd = cwd if cwd is not None else os.getcwd() self.process = None - self.run_called = False - self.log_files = {} + self.run_called: bool = False + self.log_files: Dict[str, str] = {} signal.signal(signal.SIGINT, self.handle_sigint) - async def __aenter__(self): + async def __aenter__(self) -> "CommandManager": return self async def __aexit__(self, exc_type, exc_value, traceback): await self.cleanup() def handle_sigint(self, signum, frame): + """Handle the SIGINT signal.""" + print("SIGINT received.") asyncio.create_task(self.kill()) - async def wait(self, timeout=None, stream=None): + async def wait( + self, timeout: Optional[float] = None, stream: Optional[str] = None + ) -> None: + """Wait for the subprocess to finish, optionally with a timeout and optionally streaming its output.""" + if timeout is None: if stream is None: await self.process.wait() else: await self.emit_logs(stream) else: - tasks = [asyncio.create_task(asyncio.sleep(timeout))] - if stream is None: - tasks.append(asyncio.create_task(self.process.wait())) - else: - tasks.append(asyncio.create_task(self.emit_logs(stream))) - - await asyncio.wait(tasks, return_when="FIRST_COMPLETED") + try: + if stream is None: + await asyncio.wait_for(self.process.wait(), timeout) + else: + await asyncio.wait_for(self.emit_logs(stream), timeout) + except asyncio.TimeoutError: + command_string = " ".join(self.command) + print( + f"Timeout: The process: '{command_string}' didn't complete within {timeout} seconds." + ) async def run(self): + """Run the subprocess, streaming the logs to temporary files""" + self.temp_dir = tempfile.mkdtemp() stdout_logfile = os.path.join(self.temp_dir, "stdout.log") stderr_logfile = os.path.join(self.temp_dir, "stderr.log") @@ -114,14 +132,20 @@ async def run(self): await self.cleanup() async def stream_logs( - self, stream, position=None, timeout_per_line=None, log_write_delay=0.01 + self, + stream: str, + position: Optional[int] = None, + timeout_per_line: Optional[float] = None, + log_write_delay: float = 0.01, ): + """Stream logs from the subprocess using the log files""" + if self.run_called is False: raise ValueError("No command run yet to get the logs for...") if stream not in self.log_files: raise ValueError( - f"No log file found for {stream}, valid values are: {list(self.log_files.keys())}" + f"No log file found for '{stream}', valid values are: {list(self.log_files.keys())}" ) log_file = self.log_files[stream] @@ -161,15 +185,21 @@ async def stream_logs( position = f.tell() yield position, line.strip() - async def emit_logs(self, stream="stdout", custom_logger=print): + async def emit_logs(self, stream: str = "stdout", custom_logger: Callable = print): + """Helper function to iterate over stream_logs""" + async for _, line in self.stream_logs(stream): custom_logger(line) async def cleanup(self): + """Clean up log files for a running subprocesses.""" + if hasattr(self, "temp_dir"): shutil.rmtree(self.temp_dir, ignore_errors=True) - async def kill(self, termination_timeout=5): + async def kill(self, termination_timeout: float = 5): + """Kill the subprocess.""" + if self.process is not None: if self.process.returncode is None: self.process.terminate()