Improve pipeline stop logic to ensure join is called exactly once for all stages (nv-morpheus#1479)

1. Removes the `_is_built`, `_is_started` and `_is_stopped` flags and replaces them with a single member that holds a state enum with the values: INITIALIZED, BUILT, STARTED, STOPPED, COMPLETED
1. Changes the meaning of `stop()` and the meaning of `join()` for stages (see the sketch after this list):
   1. `stop()` is called 0 or 1 times. The only way it can be called is if `pipeline.stop()` was called, indicating the pipeline should try to shut down gracefully.
      1. Users should only implement this method if they have a source stage (or sources in their stage).
   1. `join()` is called exactly once, and only when the pipeline is complete and all stages are shut down. This is where users should implement any cleanup code.
1. Adds tests covering all of these scenarios with the pipeline.
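
A minimal sketch (hypothetical class, flag, and resource names — not part of this diff) of where user code now belongs under the revised contract: `stop()` only reacts to a graceful-shutdown request, while `join()` is the single, guaranteed place for cleanup.

```python
# Illustrative only: a stand-in for a Morpheus source/stage subclass showing
# the revised lifecycle. The flag and resource attribute are hypothetical.


class MyCustomSource:

    def __init__(self):
        self._stop_requested = False  # hypothetical flag read by the emit loop
        self._resource_open = True    # hypothetical resource handle

    def stop(self):
        # Called 0 or 1 times, and only when pipeline.stop() requests a
        # graceful shutdown. Typically only sources need to implement this.
        self._stop_requested = True

    async def join(self):
        # Awaited exactly once, after the pipeline has completed and all
        # stages are shut down. All cleanup code belongs here.
        self._resource_open = False
```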

Closes nv-morpheus#1477

## By Submitting this PR I confirm:
- I am familiar with the [Contributing Guidelines](https://github.com/nv-morpheus/Morpheus/blob/main/docs/source/developer_guide/contributing.md).
- When the PR is ready for review, new or existing tests cover these changes.
- When the PR is ready for review, the documentation is up to date with these changes.

Authors:
  - Eli Fajardo (https://github.com/efajardo-nv)

Approvers:
  - Michael Demoret (https://github.com/mdemoret-nv)

URL: nv-morpheus#1479
efajardo-nv authored Feb 15, 2024
1 parent 9cd8863 commit 5fd661b
Showing 7 changed files with 389 additions and 69 deletions.
124 changes: 65 additions & 59 deletions morpheus/pipeline/pipeline.py
@@ -17,9 +17,11 @@
import os
import signal
import sys
import threading
import typing
from collections import OrderedDict
from collections import defaultdict
from enum import Enum
from functools import partial

import mrc
@@ -41,6 +43,14 @@
StageT = typing.TypeVar("StageT", bound=StageBase)


class PipelineState(Enum):
INITIALIZED = "initialized"
BUILT = "built"
STARTED = "started"
STOPPED = "stopped"
COMPLETED = "completed"


class Pipeline():
"""
Class for building your pipeline. A pipeline for your use case can be constructed by first adding a
@@ -56,16 +66,19 @@ class Pipeline():
"""

def __init__(self, config: Config):

self._mutex = threading.RLock()

self._source_count: int = None # Maximum number of iterations for progress reporting. None = Unknown/Unlimited

self._id_counter = 0
self._num_threads = config.num_threads

# Complete set of nodes across segments in this pipeline
self._stages: typing.Set[Stage] = set()
self._stages: typing.List[Stage] = []

# Complete set of sources across segments in this pipeline
self._sources: typing.Set[SourceStage] = set()
self._sources: typing.List[SourceStage] = []

# Dictionary containing segment information for this pipeline
self._segments: typing.Dict = defaultdict(lambda: {"nodes": set(), "ingress_ports": [], "egress_ports": []})
@@ -75,19 +88,21 @@ def __init__(self, config: Config):

self._segment_graphs = defaultdict(lambda: networkx.DiGraph())

self._is_built = False
self._is_started = False
self._state = PipelineState.INITIALIZED

self._mrc_executor: mrc.Executor = None

self._loop: asyncio.AbstractEventLoop = None

# Future that allows post_start to propagate exceptions back to pipeline
self._post_start_future: asyncio.Future = None

@property
def is_built(self) -> bool:
return self._is_built
def state(self) -> PipelineState:
return self._state

def _assert_not_built(self):
assert not self.is_built, "Pipeline has already been built. Cannot modify pipeline."
assert self._state == PipelineState.INITIALIZED, "Pipeline has already been built. Cannot modify pipeline."

def add_stage(self, stage: StageT, segment_id: str = "main") -> StageT:
"""
@@ -110,10 +125,10 @@ def add_stage(self, stage: StageT, segment_id: str = "main") -> StageT:
# Add to list of stages if it's a stage, not a source
if (isinstance(stage, Stage)):
segment_nodes.add(stage)
self._stages.add(stage)
self._stages.append(stage)
elif (isinstance(stage, SourceStage)):
segment_nodes.add(stage)
self._sources.add(stage)
self._sources.append(stage)
else:
raise NotImplementedError(f"add_stage() failed. Unknown node type: {type(stage)}")

@@ -279,7 +294,7 @@ def build(self):
Once the pipeline has been constructed, this will start the pipeline by calling `Source.start` on the source
object.
"""
assert not self._is_built, "Pipeline can only be built once!"
assert self._state == PipelineState.INITIALIZED, "Pipeline can only be built once!"
assert len(self._sources) > 0, "Pipeline must have a source stage"

self._pre_build()
@@ -341,19 +356,16 @@ def inner_build(builder: mrc.Builder, segment_id: str):

self._mrc_executor.register_pipeline(mrc_pipeline)

self._is_built = True
with self._mutex:
self._state = PipelineState.BUILT

logger.info("====Registering Pipeline Complete!====")

async def _start(self):
assert self._is_built, "Pipeline must be built before starting"

# Only execute this once
if (self._is_started):
return
assert self._state == PipelineState.BUILT, "Pipeline must be built before starting"

# Stop from running this twice
self._is_started = True
with self._mutex:
self._state = PipelineState.STARTED

# Save off the current loop so we can use it in async_start
self._loop = asyncio.get_running_loop()
@@ -392,63 +404,62 @@ def term_signal():

logger.info("====Pipeline Started====")

async def post_start(executor):

try:
# Make a local reference so the object doesn't go out of scope from a call to stop()
await executor.join_async()
except Exception:
logger.exception("Exception occurred in pipeline. Rethrowing")
raise
finally:
# Call join on all sources. This only occurs after all messages have been processed fully.
for source in list(self._sources):
await source.join()

# Now call join on all stages
for stage in list(self._stages):
await stage.join()

self._on_stop()

with self._mutex:
self._state = PipelineState.COMPLETED

self._post_start_future = asyncio.create_task(post_start(self._mrc_executor))

def stop(self):
"""
Stops all running stages and the underlying MRC pipeline.
"""
assert self._state == PipelineState.STARTED, "Pipeline must be running to stop it"

logger.info("====Stopping Pipeline====")
for stage in list(self._sources) + list(self._stages):
stage.stop()

self._mrc_executor.stop()

with self._mutex:
self._state = PipelineState.STOPPED

logger.info("====Pipeline Stopped====")
self._on_stop()

async def join(self):
"""
Suspend execution all currently running stages and the MRC pipeline.
Typically called after `stop`.
Wait until pipeline completes upon which join methods of sources and stages will be called.
"""
try:
# If the pipeline failed any pre-flight checks self._mrc_executor will be None
if self._mrc_executor is None:
raise RuntimeError("Pipeline failed pre-flight checks.")

# Make a local reference so the object doesnt go out of scope from a call to stop()
executor = self._mrc_executor

await executor.join_async()
except Exception:
logger.exception("Exception occurred in pipeline. Rethrowing")
raise
finally:
# Make sure these are always shut down even if there was an error
for source in list(self._sources):
source.stop()
assert self._post_start_future is not None, "Pipeline must be started before joining"

# First wait for all sources to stop. This only occurs after all messages have been processed fully
for source in list(self._sources):
await source.join()

# Now that there is no more data, call stop on all stages to ensure shutdown (i.e., for stages that have
# their own worker loop thread)
for stage in list(self._stages):
stage.stop()

# Now call join on all stages
for stage in list(self._stages):
await stage.join()

self._on_stop()
await self._post_start_future

def _on_stop(self):
self._mrc_executor = None

async def _build_and_start(self):
async def build_and_start(self):

if (not self.is_built):
if (self._state == PipelineState.INITIALIZED):
try:
self.build()
except Exception:
Expand All @@ -470,7 +481,7 @@ def visualize(self, filename: str = None, **graph_kwargs):
exists it will be overwritten. Requires the graphviz library.
"""

if not self._is_built:
if self._state == PipelineState.INITIALIZED:
raise RuntimeError("Pipeline.visualize() requires that the Pipeline has been started before generating "
"the visualization. Please call Pipeline.build() or Pipeline.run() before calling "
"Pipeline.visualize().")
@@ -624,9 +635,7 @@ async def run_async(self):
This function sets up the current asyncio loop, builds the pipeline, and awaits on it to complete.
"""
try:
await self._build_and_start()

# Wait for completion
await self.build_and_start()
await self.join()

except KeyboardInterrupt:
Expand All @@ -635,9 +644,6 @@ async def run_async(self):
# Stop the pipeline
self.stop()

# Wait again for nice completion
await self.join()

finally:
# Shutdown the async generator sources and exit
logger.info("====Pipeline Complete====")
2 changes: 1 addition & 1 deletion morpheus/stages/general/monitor_stage.py
@@ -111,7 +111,7 @@ async def start_async(self):
if (not self._mc.delayed_start):
self._mc.ensure_progress_bar()

def stop(self):
async def join(self):
"""
Clean up and close the progress bar.
"""
2 changes: 1 addition & 1 deletion morpheus/stages/postprocess/generate_viz_frames_stage.py
@@ -235,7 +235,7 @@ async def run_server():

return await super().start_async()

def stop(self):
async def join(self):
"""
Stages can implement this to perform cleanup steps when pipeline is stopped.
"""
3 changes: 2 additions & 1 deletion tests/pipeline/test_pipe_viz.py
@@ -27,6 +27,7 @@
from morpheus.cli.commands import RANKDIR_CHOICES
from morpheus.pipeline import LinearPipeline
from morpheus.pipeline.pipeline import Pipeline
from morpheus.pipeline.pipeline import PipelineState
from morpheus.stages.input.in_memory_source_stage import InMemorySourceStage
from morpheus.stages.output.in_memory_sink_stage import InMemorySinkStage
from morpheus.stages.postprocess.add_classifications_stage import AddClassificationsStage
@@ -76,7 +77,7 @@ def test_viz_without_run(viz_pipeline: Pipeline, tmp_path: str):
# Verify that the output file exists and is a valid png file
assert_path_exists(viz_file)
assert imghdr.what(viz_file) == 'png'
assert viz_pipeline.is_built
assert viz_pipeline.state != PipelineState.INITIALIZED


@pytest.mark.slow