Skip to content

Commit

Permalink
simpler subprocess manager
Browse files Browse the repository at this point in the history
  • Loading branch information
madhur-ob committed Mar 28, 2024
1 parent 63acedd commit bafe765
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 67 deletions.
1 change: 0 additions & 1 deletion metaflow/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -870,7 +870,6 @@ def write_file(file_path, content):
if file_path is not None:
with open(file_path, "w") as f:
f.write(str(content))
f.close()


def before_run(obj, tags, decospecs):
Expand Down
135 changes: 69 additions & 66 deletions metaflow/subprocess_manager.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
import os
import sys
import time
import shutil
import asyncio
import tempfile
import aiofiles
from typing import List
from asyncio.queues import Queue


class SubprocessManager(object):
Expand All @@ -20,87 +19,77 @@ def __init__(self, env=None, cwd=None):

self.process = None
self.run_command_called = False

async def get_logs(self, stream="stdout"):
    """Drain the queue of the requested stream through `consume_queue`.

    Args:
        stream: Which queue to drain, either "stdout" or "stderr".

    Raises:
        ValueError: If no command has been run yet, or if `stream` is not
            one of the two supported names.
    """
    if self.run_command_called is False:
        raise ValueError("No command run yet to get the logs for...")

    # Reject unknown stream names before touching either queue.
    if stream not in ("stdout", "stderr"):
        raise ValueError(
            f"Invalid value for `stream`: {stream}, valid values are: {['stdout', 'stderr']}"
        )

    selected_queue = self.stdout_queue if stream == "stdout" else self.stderr_queue
    drain_task = asyncio.create_task(self.consume_queue(selected_queue))
    await drain_task

async def stream_logs_to_queue(self, logfile, queue, process, poll_interval=0.1):
    """Tail `logfile` into `queue` until `process` has exited.

    Every line read is stripped and put on the queue. Once the process has
    exited and the file has been read to the end, a `None` sentinel is put
    on the queue to signal that no more items will follow.

    Args:
        logfile: Path of the log file the subprocess writes to.
        queue: Queue receiving the stripped lines, then the `None` sentinel.
        process: Object whose `returncode` is `None` while it is running.
        poll_interval: Seconds to sleep when no new line is available yet.

    Raises:
        Exception: If the process exited with a non-zero return code (the
            sentinel is still put on the queue first).
    """
    async with aiofiles.open(logfile, "r") as f:
        while True:
            # Sample the exit state *before* reading: if the process was
            # already gone and the read still hits EOF, nothing more can
            # arrive, so no trailing output is lost.
            exited = process.returncode is not None
            line = await f.readline()
            if line:
                await queue.put(line.strip())
                continue
            if exited:
                break
            # No data yet; sleep instead of spinning on readline().
            await asyncio.sleep(poll_interval)

        # Indicator that no more items will be inserted into the queue.
        await queue.put(None)

    if process.returncode != 0:
        raise Exception("Ran into an issue...")

async def consume_queue(self, queue: Queue):
    """Print items from `queue` until the `None` sentinel is received.

    Every item taken from the queue, including the sentinel, is marked as
    done so that a caller awaiting `queue.join()` is released once the
    stream has been fully consumed.

    Args:
        queue: Queue of items to print, terminated by a `None` item.
    """
    while True:
        item = await queue.get()
        try:
            # Break out of the loop when we get the end-of-stream indicator.
            if item is None:
                break
            print(item)
        finally:
            # Runs for the sentinel as well; otherwise join() never returns.
            queue.task_done()
self.log_files = {}
self.process_dict = {}

async def run_command(self, command: List[str]):
    """Start `command` as a subprocess with its output redirected to log files.

    A fresh temporary directory is created and the child's stdout and
    stderr are written to `stdout.log` and `stderr.log` inside it. The two
    paths are recorded in `self.log_files` under "stdout" and "stderr".
    This returns as soon as the subprocess has been started, not when it
    has finished.

    Args:
        command: Program and arguments to execute.

    Returns:
        The started process, or `None` if the subprocess could not be
        started (the error is printed and the temporary directory removed).
    """
    self.temp_dir = tempfile.mkdtemp()
    stdout_logfile = os.path.join(self.temp_dir, "stdout.log")
    stderr_logfile = os.path.join(self.temp_dir, "stderr.log")

    try:
        # The child receives its own copies of these descriptors, so the
        # parent's handles are closed right after the spawn instead of
        # being leaked for the lifetime of the manager.
        with open(stdout_logfile, "w") as stdout_f, open(
            stderr_logfile, "w"
        ) as stderr_f:
            self.process = await asyncio.create_subprocess_exec(
                *command,
                cwd=self.cwd,
                env=self.env,
                stdout=stdout_f,
                stderr=stderr_f,
            )
    except Exception as e:
        print(f"Error starting subprocess: {e}")
        # Clean up temp files if process fails to start
        await self.cleanup()
        return None

    self.log_files["stdout"] = stdout_logfile
    self.log_files["stderr"] = stderr_logfile

    self.run_command_called = True
    return self.process

async def stream_logs(self, stream):
    """Print the lines of a stream's log file that have not been printed yet.

    Reading resumes from the offset stored in `self.process_dict` for this
    stream (0 on first use) and the offset reached is stored back, so
    repeated calls only emit new content.

    Args:
        stream: Key into `self.log_files` identifying the log to read.

    Raises:
        ValueError: If no command has been run yet, or if there is no log
            file registered for `stream`.
    """
    if self.run_command_called is False:
        raise ValueError("No command run yet to get the logs for...")

    if stream not in self.log_files:
        raise ValueError(f"No log file found for {stream}")

    with open(self.log_files[stream], mode="r") as log_handle:
        log_handle.seek(self.process_dict.get(stream, 0))

        # readline() yields "" at end of file, which ends the iteration;
        # explicit readline() calls keep tell() usable afterwards.
        for raw_line in iter(log_handle.readline, ""):
            print(raw_line.strip())

        self.process_dict[stream] = log_handle.tell()

async def get_logs(self, stream="stdout", delay=0.1):
    """Print a stream's log output until the subprocess has exited.

    While the process is running, new log lines are printed via
    `stream_logs` every `delay` seconds. After the process has exited, one
    final pass is made so that output written just before exit is printed;
    this also covers a process that finished before the call was made.

    Args:
        stream: Name of the stream to print, passed through to `stream_logs`.
        delay: Seconds to sleep between polls while the process is running.

    Raises:
        ValueError: If there is no process to read logs for.
    """
    if self.process is None:
        raise ValueError("No command run yet to get the logs for...")

    while self.process.returncode is None:
        await self.stream_logs(stream)
        await asyncio.sleep(delay)

    # Final flush: without it, nothing is printed for an already-finished
    # process and the tail written after the last poll is dropped.
    await self.stream_logs(stream)

async def cleanup(self):
    """Remove the temporary log directory, if one has been created.

    Does nothing when `self.temp_dir` has never been set. Errors raised
    while deleting the directory tree are ignored.
    """
    try:
        scratch_dir = self.temp_dir
    except AttributeError:
        # No temporary directory was ever created.
        return
    shutil.rmtree(scratch_dir, ignore_errors=True)

async def kill_process(self, timeout=5):
    """Stop the managed subprocess, escalating from terminate to kill.

    The process is first asked to terminate. If it has not exited within
    `timeout` seconds it is killed, and the exit is then awaited so that
    the process is reaped and its `returncode` is populated on return.

    Args:
        timeout: Seconds to wait after `terminate()` before calling `kill()`.
    """
    if self.process is None:
        print("No process to kill.")
        return

    if self.process.returncode is not None:
        print("Process has already terminated.")
        return

    self.process.terminate()
    try:
        await asyncio.wait_for(self.process.wait(), timeout)
    except asyncio.TimeoutError:
        self.process.kill()
        # Wait for the killed process to actually exit.
        await self.process.wait()


async def main():
flow_file = "../try.py"
Expand All @@ -114,11 +103,25 @@ async def main():
spm = SubprocessManager()
process = await spm.run_command(cmd)
# await process.wait()
# print(process.returncode)
print("will print logs after 15 secs, flow has ended by then...")
await asyncio.sleep(15)
print("done waiting...")
await spm.get_logs(stream="stdout")
print(process.returncode)
print(process)

# print("kill after 2 seconds...get logs upto the point of killing...")
# await asyncio.wait([
# asyncio.create_task(spm.get_logs(stream="stdout")),
# asyncio.create_task(asyncio.sleep(2)),
# ], return_when="FIRST_COMPLETED")
# await spm.kill_process()
# print("done...")

# print("will print logs after 15 secs, flow has ended by then...")
# time.sleep(15)
# print("done waiting...")
# await spm.get_logs(stream="stdout")
# await spm.cleanup()

# await spm.get_logs(stream="stdout")
# await spm.cleanup()


if __name__ == "__main__":
Expand Down

0 comments on commit bafe765

Please sign in to comment.