Skip to content

Commit

Permalink
subprocess manager
Browse files Browse the repository at this point in the history
  • Loading branch information
madhur-ob committed Mar 12, 2024
1 parent a3919a0 commit 6f725aa
Show file tree
Hide file tree
Showing 2 changed files with 152 additions and 27 deletions.
54 changes: 27 additions & 27 deletions metaflow/metaflow_runner.py
Original file line number Diff line number Diff line change
@@ -1,67 +1,67 @@
import os
import sys
import time
import shutil
import asyncio
import tempfile
import subprocess
import aiofiles
from typing import Dict
from metaflow import Run
from metaflow.cli import start
from metaflow.click_api import MetaflowAPI
from metaflow.subprocess_manager import SubprocessManager


def cli_runner(command: str, env_vars: Dict):
process = subprocess.Popen(
[sys.executable, *command.split()],
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
env=env_vars,
)
return process


def read_from_file_when_ready(file_pointer):
content = file_pointer.read().decode()
while not content:
time.sleep(0.1)
content = file_pointer.read().decode()
return content
async def read_from_file_when_ready(file_path):
async with aiofiles.open(file_path, "r") as file_pointer:
content = await file_pointer.read()
while not content:
await asyncio.sleep(0.1)
content = await file_pointer.read()
return content


class Runner(object):
def __init__(
self,
flow_file: str,
env: Dict = {},
**kwargs,
):
self.flow_file = flow_file
self.env_vars = os.environ.copy().update(env)
self.spm = SubprocessManager(env=self.env_vars)
self.api = MetaflowAPI.from_cli(self.flow_file, start)
self.runner = self.api(**kwargs).run

def __enter__(self):
return self

def run(self, blocking: bool = False, **kwargs):
env_vars = os.environ.copy()
async def tail_logs(self, stream="stdout"):
await self.spm.get_logs(stream)

async def run(self, blocking: bool = False, **kwargs):
with tempfile.TemporaryDirectory() as temp_dir:
tfp_flow_name = tempfile.NamedTemporaryFile(dir=temp_dir, delete=False)
tfp_flow = tempfile.NamedTemporaryFile(dir=temp_dir, delete=False)
tfp_run_id = tempfile.NamedTemporaryFile(dir=temp_dir, delete=False)

command = self.runner(
run_id_file=tfp_run_id.name, flow_name_file=tfp_flow_name.name, **kwargs
run_id_file=tfp_run_id.name, flow_name_file=tfp_flow.name, **kwargs
)

process = cli_runner(command, env_vars)
process = await self.spm.run_command([sys.executable, *command.split()])

if blocking:
process.wait()
await process.wait()

flow_name = read_from_file_when_ready(tfp_flow_name)
run_id = read_from_file_when_ready(tfp_run_id)
flow_name = await read_from_file_when_ready(tfp_flow.name)
run_id = await read_from_file_when_ready(tfp_run_id.name)

pathspec_components = (flow_name, run_id)
run_object = Run("/".join(pathspec_components), _namespace_check=False)

self.run = run_object

return run_object

def __exit__(self, exc_type, exc_value, traceback):
pass
shutil.rmtree(self.spm.temp_dir, ignore_errors=True)
125 changes: 125 additions & 0 deletions metaflow/subprocess_manager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
import os
import sys
import shutil
import asyncio
import tempfile
import aiofiles
from typing import List
from asyncio.queues import Queue


class SubprocessManager(object):
def __init__(self, env=None, cwd=None):
if env is None:
env = os.environ.copy()
self.env = env

if cwd is None:
cwd = os.getcwd()
self.cwd = cwd

self.process = None
self.run_command_called = False

async def get_logs(self, stream="stdout"):
if self.run_command_called is False:
raise ValueError("No command run yet to get the logs for...")
if stream == "stdout":
stdout_task = asyncio.create_task(self.consume_queue(self.stdout_queue))
await stdout_task
elif stream == "stderr":
stderr_task = asyncio.create_task(self.consume_queue(self.stderr_queue))
await stderr_task
else:
raise ValueError(
f"Invalid value for `stream`: {stream}, valid values are: {['stdout', 'stderr']}"
)

async def stream_logs_to_queue(self, logfile, queue, process):
async with aiofiles.open(logfile, "r") as f:
while True:
if process.returncode is None:
# process is still running
line = await f.readline()
if not line:
continue
await queue.put(line.strip())
elif process.returncode == 0:
# insert an indicator that no more items
# will be inserted into the queue
await queue.put(None)
break
elif process.returncode != 0:
# insert an indicator that no more items
# will be inserted into the queue
await queue.put(None)
raise Exception("Ran into an issue...")

async def consume_queue(self, queue: Queue):
while True:
item = await queue.get()
# break out of loop when we get the `indicator`
if item is None:
break
print(item)
queue.task_done()

async def run_command(self, command: List[str]):
self.temp_dir = tempfile.mkdtemp()
stdout_logfile = os.path.join(self.temp_dir, "stdout.log")
stderr_logfile = os.path.join(self.temp_dir, "stderr.log")

self.stdout_queue = Queue()
self.stderr_queue = Queue()

try:
# returns when subprocess has been started, not
# when it is finished...
self.process = await asyncio.create_subprocess_exec(
*command,
cwd=self.cwd,
env=self.env,
stdout=await aiofiles.open(stdout_logfile, "w"),
stderr=await aiofiles.open(stderr_logfile, "w"),
)

self.stdout_task = asyncio.create_task(
self.stream_logs_to_queue(
stdout_logfile, self.stdout_queue, self.process
)
)
self.stderr_task = asyncio.create_task(
self.stream_logs_to_queue(
stderr_logfile, self.stderr_queue, self.process
)
)

self.run_command_called = True
return self.process
except Exception as e:
print(f"Error starting subprocess: {e}")
# Clean up temp files if process fails to start
shutil.rmtree(self.temp_dir, ignore_errors=True)


async def main():
flow_file = "../try.py"
from metaflow.cli import start
from metaflow.click_api import MetaflowAPI

api = MetaflowAPI.from_cli(flow_file, start)
command = api().run(alpha=5)
cmd = [sys.executable, *command.split()]

spm = SubprocessManager()
process = await spm.run_command(cmd)
# await process.wait()
# print(process.returncode)
print("will print logs after 15 secs, flow has ended by then...")
await asyncio.sleep(15)
print("done waiting...")
await spm.get_logs(stream="stdout")


if __name__ == "__main__":
asyncio.run(main())

0 comments on commit 6f725aa

Please sign in to comment.