# ============================================================================
# metaflow/metaflow_runner.py
# ============================================================================
import os
import sys
import shutil
import asyncio
import tempfile
from typing import Dict, Optional

import aiofiles

from metaflow import Run
from metaflow.cli import start
from metaflow.click_api import MetaflowAPI
from metaflow.subprocess_manager import SubprocessManager


async def read_from_file_when_ready(file_path: str) -> str:
    """Poll *file_path* until it contains data, then return its contents.

    The flow subprocess writes the run id / flow name asynchronously, so the
    file may already exist but still be empty when we first look at it.
    """
    async with aiofiles.open(file_path, "r") as file_pointer:
        content = await file_pointer.read()
        while not content:
            await asyncio.sleep(0.1)
            content = await file_pointer.read()
        return content


class Runner(object):
    """Programmatic entry point for launching a Metaflow flow file as a
    subprocess and retrieving the resulting ``metaflow.Run`` object."""

    def __init__(
        self,
        flow_file: str,
        env: Optional[Dict] = None,  # None (not {}) — avoid a mutable default
        **kwargs,
    ):
        self.flow_file = flow_file
        # BUG FIX: ``dict.update()`` returns None, so the original
        # ``os.environ.copy().update(env)`` always set env_vars to None and
        # silently discarded the caller-supplied *env*.
        self.env_vars = os.environ.copy()
        self.env_vars.update(env or {})
        self.spm = SubprocessManager(env=self.env_vars)
        self.api = MetaflowAPI.from_cli(self.flow_file, start)
        self.runner = self.api(**kwargs).run

    def __enter__(self):
        return self

    async def tail_logs(self, stream: str = "stdout"):
        """Stream the subprocess logs ('stdout' or 'stderr') to the console."""
        await self.spm.get_logs(stream)

    async def run(self, blocking: bool = False, **kwargs) -> Run:
        """Launch the flow; return the corresponding ``metaflow.Run`` object.

        When *blocking* is True, wait for the subprocess to finish before
        reading the run id back.
        """
        with tempfile.TemporaryDirectory() as temp_dir:
            tfp_flow = tempfile.NamedTemporaryFile(dir=temp_dir, delete=False)
            tfp_run_id = tempfile.NamedTemporaryFile(dir=temp_dir, delete=False)
            # BUG FIX: close our handles right away — only the file *paths*
            # are handed to the subprocess; the originals leaked the handles.
            tfp_flow.close()
            tfp_run_id.close()

            command = self.runner(
                run_id_file=tfp_run_id.name, flow_name_file=tfp_flow.name, **kwargs
            )

            process = await self.spm.run_command([sys.executable, *command.split()])
            if blocking:
                await process.wait()

            flow_name = await read_from_file_when_ready(tfp_flow.name)
            run_id = await read_from_file_when_ready(tfp_run_id.name)

            run_object = Run(
                "/".join((flow_name, run_id)), _namespace_check=False
            )
            # BUG FIX: the original did ``self.run = run_object``, clobbering
            # this very method and breaking any second call to run().
            self.run_object = run_object
            return run_object

    def __exit__(self, exc_type, exc_value, traceback):
        # ``temp_dir`` only exists once run_command() has been called; the
        # original raised AttributeError when run() was never invoked.
        temp_dir = getattr(self.spm, "temp_dir", None)
        if temp_dir is not None:
            shutil.rmtree(temp_dir, ignore_errors=True)


# ============================================================================
# metaflow/subprocess_manager.py  (new file)
# ============================================================================
import os
import sys
import shutil
import asyncio
import tempfile
from typing import List, Optional

import aiofiles
from asyncio.queues import Queue


class SubprocessManager(object):
    """Run a command as an asyncio subprocess, teeing its stdout/stderr into
    temp log files and exposing the lines through async queues."""

    def __init__(self, env: Optional[dict] = None, cwd: Optional[str] = None):
        # Fall back to the current environment / working directory.
        self.env = os.environ.copy() if env is None else env
        self.cwd = os.getcwd() if cwd is None else cwd
        self.process = None
        self.run_command_called = False

    async def get_logs(self, stream: str = "stdout"):
        """Print the buffered log lines for *stream* until the producer
        signals EOF (a ``None`` sentinel on the queue).

        Raises ValueError if no command has been run yet or *stream* is
        not one of 'stdout' / 'stderr'.
        """
        if not self.run_command_called:
            raise ValueError("No command run yet to get the logs for...")
        if stream == "stdout":
            await asyncio.create_task(self.consume_queue(self.stdout_queue))
        elif stream == "stderr":
            await asyncio.create_task(self.consume_queue(self.stderr_queue))
        else:
            raise ValueError(
                f"Invalid value for `stream`: {stream}, valid values are: {['stdout', 'stderr']}"
            )

    async def stream_logs_to_queue(self, logfile, queue, process):
        """Tail *logfile* into *queue* until *process* exits and the file is
        fully drained; put a ``None`` sentinel when done."""
        async with aiofiles.open(logfile, "r") as f:
            while True:
                line = await f.readline()
                if line:
                    await queue.put(line.strip())
                    continue
                if process.returncode is None:
                    # Process still running but nothing new to read: back off
                    # instead of busy-spinning (the original looped with a
                    # bare `continue`, pegging a CPU core).
                    await asyncio.sleep(0.1)
                    continue
                # BUG FIX: the process has exited AND the file is drained.
                # The original stopped the moment returncode was set, dropping
                # any log lines still unread in the file.
                await queue.put(None)  # EOF indicator for consume_queue
                if process.returncode != 0:
                    raise Exception("Ran into an issue...")
                break

    async def consume_queue(self, queue: Queue):
        """Print items from *queue* until the ``None`` EOF sentinel arrives."""
        while True:
            item = await queue.get()
            # break out of the loop when we get the EOF indicator
            if item is None:
                break
            print(item)
            queue.task_done()

    async def run_command(self, command: List[str]):
        """Start *command* as a subprocess; returns once it has STARTED (not
        finished) with the ``asyncio`` process handle. Log-tailing tasks are
        spawned as a side effect."""
        self.temp_dir = tempfile.mkdtemp()
        stdout_logfile = os.path.join(self.temp_dir, "stdout.log")
        stderr_logfile = os.path.join(self.temp_dir, "stderr.log")

        self.stdout_queue = Queue()
        self.stderr_queue = Queue()

        try:
            # BUG FIX: create_subprocess_exec needs real file objects (it
            # calls .fileno() synchronously); the ``await aiofiles.open(...)``
            # wrappers the original passed are not usable here. The child
            # inherits a duplicate of each fd, so we close ours after spawn.
            stdout_fp = open(stdout_logfile, "w")
            stderr_fp = open(stderr_logfile, "w")
            try:
                self.process = await asyncio.create_subprocess_exec(
                    *command,
                    cwd=self.cwd,
                    env=self.env,
                    stdout=stdout_fp,
                    stderr=stderr_fp,
                )
            finally:
                stdout_fp.close()
                stderr_fp.close()

            self.stdout_task = asyncio.create_task(
                self.stream_logs_to_queue(
                    stdout_logfile, self.stdout_queue, self.process
                )
            )
            self.stderr_task = asyncio.create_task(
                self.stream_logs_to_queue(
                    stderr_logfile, self.stderr_queue, self.process
                )
            )

            self.run_command_called = True
            return self.process
        except Exception as e:
            print(f"Error starting subprocess: {e}")
            # Clean up temp files if the process fails to start.
            shutil.rmtree(self.temp_dir, ignore_errors=True)
            # BUG FIX: re-raise instead of implicitly returning None — the
            # original deferred the failure to an opaque AttributeError at
            # the caller (`await None.wait()`).
            raise


async def main():
    """Ad-hoc demo driver: run a flow and dump its stdout after a delay."""
    flow_file = "../try.py"
    from metaflow.cli import start
    from metaflow.click_api import MetaflowAPI

    api = MetaflowAPI.from_cli(flow_file, start)
    command = api().run(alpha=5)
    cmd = [sys.executable, *command.split()]

    spm = SubprocessManager()
    process = await spm.run_command(cmd)
    # await process.wait()
    # print(process.returncode)
    print("will print logs after 15 secs, flow has ended by then...")
    await asyncio.sleep(15)
    print("done waiting...")
    await spm.get_logs(stream="stdout")


if __name__ == "__main__":
    asyncio.run(main())