Commit
Showing 2 changed files with 152 additions and 27 deletions.
@@ -1,67 +1,67 @@
 import os
 import sys
-import time
+import shutil
+import asyncio
 import tempfile
-import subprocess
+import aiofiles
 from typing import Dict
 from metaflow import Run
 from metaflow.cli import start
 from metaflow.click_api import MetaflowAPI
+from metaflow.subprocess_manager import SubprocessManager


-def cli_runner(command: str, env_vars: Dict):
-    process = subprocess.Popen(
-        [sys.executable, *command.split()],
-        stdout=subprocess.PIPE,
-        stderr=subprocess.PIPE,
-        env=env_vars,
-    )
-    return process
-
-
-def read_from_file_when_ready(file_pointer):
-    content = file_pointer.read().decode()
-    while not content:
-        time.sleep(0.1)
-        content = file_pointer.read().decode()
-    return content
+async def read_from_file_when_ready(file_path):
+    async with aiofiles.open(file_path, "r") as file_pointer:
+        content = await file_pointer.read()
+        while not content:
+            await asyncio.sleep(0.1)
+            content = await file_pointer.read()
+        return content


 class Runner(object):
     def __init__(
         self,
         flow_file: str,
+        env: Dict = {},
         **kwargs,
     ):
         self.flow_file = flow_file
+        self.env_vars = os.environ.copy().update(env)
+        self.spm = SubprocessManager(env=self.env_vars)
         self.api = MetaflowAPI.from_cli(self.flow_file, start)
         self.runner = self.api(**kwargs).run

     def __enter__(self):
         return self

-    def run(self, blocking: bool = False, **kwargs):
-        env_vars = os.environ.copy()
+    async def tail_logs(self, stream="stdout"):
+        await self.spm.get_logs(stream)
+
+    async def run(self, blocking: bool = False, **kwargs):
         with tempfile.TemporaryDirectory() as temp_dir:
-            tfp_flow_name = tempfile.NamedTemporaryFile(dir=temp_dir, delete=False)
+            tfp_flow = tempfile.NamedTemporaryFile(dir=temp_dir, delete=False)
             tfp_run_id = tempfile.NamedTemporaryFile(dir=temp_dir, delete=False)

             command = self.runner(
-                run_id_file=tfp_run_id.name, flow_name_file=tfp_flow_name.name, **kwargs
+                run_id_file=tfp_run_id.name, flow_name_file=tfp_flow.name, **kwargs
             )

-            process = cli_runner(command, env_vars)
+            process = await self.spm.run_command([sys.executable, *command.split()])

             if blocking:
-                process.wait()
+                await process.wait()

-            flow_name = read_from_file_when_ready(tfp_flow_name)
-            run_id = read_from_file_when_ready(tfp_run_id)
+            flow_name = await read_from_file_when_ready(tfp_flow.name)
+            run_id = await read_from_file_when_ready(tfp_run_id.name)

             pathspec_components = (flow_name, run_id)
             run_object = Run("/".join(pathspec_components), _namespace_check=False)

             self.run = run_object

             return run_object

     def __exit__(self, exc_type, exc_value, traceback):
-        pass
+        shutil.rmtree(self.spm.temp_dir, ignore_errors=True)
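The net effect of the first file is that Runner.run becomes a coroutine driven by the new SubprocessManager instead of a raw subprocess.Popen call. Before the second changed file below, a minimal usage sketch of the async Runner; this is not part of the commit, the import path is a placeholder, and "../try.py" with its alpha parameter is borrowed from the example driver in the second file:

# Hypothetical usage sketch, not part of the commit.
import asyncio

from runner import Runner  # placeholder import; the module name is not shown in this diff


async def main():
    with Runner("../try.py") as runner:
        # blocking=True awaits the flow subprocess before the flow name /
        # run id files are read and the metaflow.Run object is built
        run = await runner.run(alpha=5, blocking=True)
        print(run.pathspec)
        # drain the stdout lines captured by the underlying SubprocessManager
        await runner.tail_logs(stream="stdout")


asyncio.run(main())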
@@ -0,0 +1,125 @@
import os
import sys
import shutil
import asyncio
import tempfile
import aiofiles
from typing import List
from asyncio.queues import Queue


class SubprocessManager(object):
    def __init__(self, env=None, cwd=None):
        if env is None:
            env = os.environ.copy()
        self.env = env

        if cwd is None:
            cwd = os.getcwd()
        self.cwd = cwd

        self.process = None
        self.run_command_called = False

    async def get_logs(self, stream="stdout"):
        if self.run_command_called is False:
            raise ValueError("No command run yet to get the logs for...")
        if stream == "stdout":
            stdout_task = asyncio.create_task(self.consume_queue(self.stdout_queue))
            await stdout_task
        elif stream == "stderr":
            stderr_task = asyncio.create_task(self.consume_queue(self.stderr_queue))
            await stderr_task
        else:
            raise ValueError(
                f"Invalid value for `stream`: {stream}, valid values are: {['stdout', 'stderr']}"
            )

    async def stream_logs_to_queue(self, logfile, queue, process):
        async with aiofiles.open(logfile, "r") as f:
            while True:
                if process.returncode is None:
                    # process is still running
                    line = await f.readline()
                    if not line:
                        continue
                    await queue.put(line.strip())
                elif process.returncode == 0:
                    # insert an indicator that no more items
                    # will be inserted into the queue
                    await queue.put(None)
                    break
                elif process.returncode != 0:
                    # insert an indicator that no more items
                    # will be inserted into the queue
                    await queue.put(None)
                    raise Exception("Ran into an issue...")

    async def consume_queue(self, queue: Queue):
        while True:
            item = await queue.get()
            # break out of loop when we get the `indicator`
            if item is None:
                break
            print(item)
            queue.task_done()

    async def run_command(self, command: List[str]):
        self.temp_dir = tempfile.mkdtemp()
        stdout_logfile = os.path.join(self.temp_dir, "stdout.log")
        stderr_logfile = os.path.join(self.temp_dir, "stderr.log")

        self.stdout_queue = Queue()
        self.stderr_queue = Queue()

        try:
            # returns when subprocess has been started, not
            # when it is finished...
            self.process = await asyncio.create_subprocess_exec(
                *command,
                cwd=self.cwd,
                env=self.env,
                stdout=await aiofiles.open(stdout_logfile, "w"),
                stderr=await aiofiles.open(stderr_logfile, "w"),
            )

            self.stdout_task = asyncio.create_task(
                self.stream_logs_to_queue(
                    stdout_logfile, self.stdout_queue, self.process
                )
            )
            self.stderr_task = asyncio.create_task(
                self.stream_logs_to_queue(
                    stderr_logfile, self.stderr_queue, self.process
                )
            )

            self.run_command_called = True
            return self.process
        except Exception as e:
            print(f"Error starting subprocess: {e}")
            # Clean up temp files if process fails to start
            shutil.rmtree(self.temp_dir, ignore_errors=True)


async def main():
    flow_file = "../try.py"
    from metaflow.cli import start
    from metaflow.click_api import MetaflowAPI

    api = MetaflowAPI.from_cli(flow_file, start)
    command = api().run(alpha=5)
    cmd = [sys.executable, *command.split()]

    spm = SubprocessManager()
    process = await spm.run_command(cmd)
    # await process.wait()
    # print(process.returncode)
    print("will print logs after 15 secs, flow has ended by then...")
    await asyncio.sleep(15)
    print("done waiting...")
    await spm.get_logs(stream="stdout")


if __name__ == "__main__":
    asyncio.run(main())
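Since get_logs drains only one stream per call, a caller that wants both stdout and stderr can run two calls under asyncio.gather. A small sketch of that variation, assuming it is appended to the same module as the SubprocessManager class above (the file name is not shown in this commit) and reusing the same "../try.py" example flow; it is not part of the commit:

async def main_both_streams():
    # Hypothetical variation on main() above, not part of the commit.
    flow_file = "../try.py"
    from metaflow.cli import start
    from metaflow.click_api import MetaflowAPI

    api = MetaflowAPI.from_cli(flow_file, start)
    command = api().run(alpha=5)

    spm = SubprocessManager()
    process = await spm.run_command([sys.executable, *command.split()])

    # once the process exits, stream_logs_to_queue() enqueues a None
    # sentinel on each queue, so both get_logs() calls below return after
    # their queue has been drained
    await process.wait()
    await asyncio.gather(
        spm.get_logs(stream="stdout"),
        spm.get_logs(stream="stderr"),
    )


if __name__ == "__main__":
    asyncio.run(main_both_streams())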