From 9c95a4325787e7bef6cc3be0c916382b07bd137b Mon Sep 17 00:00:00 2001
From: David Gardner
Date: Thu, 26 Oct 2023 08:56:09 -0700
Subject: [PATCH 01/26] Configure the logger named 'llm' not 'llm.cli',
 allowing the logs of the pipelines to be seen

---
 examples/llm/cli.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/examples/llm/cli.py b/examples/llm/cli.py
index b73a1f4fc3..5d67a9db55 100644
--- a/examples/llm/cli.py
+++ b/examples/llm/cli.py
@@ -53,7 +53,7 @@ def cli(ctx: click.Context, log_level: int, use_cpp: bool):

     morpheus_logger = logging.getLogger("morpheus")

-    logger = logging.getLogger(__name__)
+    logger = logging.getLogger('.'.join(__name__.split('.')[:-1]))

     # Set the parent logger for all of the llm examples to use morpheus so we can take advantage of configure_logging
     logger.parent = morpheus_logger

From 13dec682bb1c0d60108dc41b1cfb9c0f17e4a1c9 Mon Sep 17 00:00:00 2001
From: David Gardner
Date: Thu, 26 Oct 2023 08:57:40 -0700
Subject: [PATCH 02/26] Fix responses to actually be the responses, not a copy
 of the input; log the responses at the end of the run

---
 examples/llm/completion/pipeline.py |  7 +++++--
 tests/_utils/__init__.py            | 12 ++++++++++++
 2 files changed, 17 insertions(+), 2 deletions(-)

diff --git a/examples/llm/completion/pipeline.py b/examples/llm/completion/pipeline.py
index 2f2784e7a8..71ef81aa7c 100644
--- a/examples/llm/completion/pipeline.py
+++ b/examples/llm/completion/pipeline.py
@@ -32,6 +32,7 @@
 from morpheus.stages.input.in_memory_source_stage import InMemorySourceStage
 from morpheus.stages.output.in_memory_sink_stage import InMemorySinkStage
 from morpheus.stages.preprocess.deserialize_stage import DeserializeStage
+from morpheus.utils.concat_df import concat_dataframes

 logger = logging.getLogger(__name__)

@@ -52,7 +53,7 @@ def _build_engine():

     engine.add_node("completion", inputs=["/prompts"], node=LLMGenerateNode(llm_client=llm_clinet))

-    engine.add_task_handler(inputs=["/extracter"], handler=SimpleTaskHandler())
+    engine.add_task_handler(inputs=["/completion"], handler=SimpleTaskHandler())

     return engine

@@ -107,6 +108,8 @@ def pipeline(num_threads, pipeline_batch_size, model_max_batch_size, repeat_coun

     pipe.run()

-    logger.info("Pipeline complete. Received %s responses", len(sink.get_messages()))
+    messages = sink.get_messages()
+    responses = concat_dataframes(messages)
+    logger.info("Pipeline complete. Received %s responses\n%s", len(messages), responses['response'])

     return start_time

diff --git a/tests/_utils/__init__.py b/tests/_utils/__init__.py
index 5a8daf296d..ce4e532e3e 100644
--- a/tests/_utils/__init__.py
+++ b/tests/_utils/__init__.py
@@ -122,6 +122,18 @@ def import_or_skip(modname: str,
             raise ImportError(e) from e
         raise

+def require_env_variable(varname: str, reason: str, fail_missing: bool = False) -> str:
+    """
+    Checks if the given environment variable is set, and returns its value if it is. If the variable is not set, and
+    `fail_missing` is False, the test will be skipped; otherwise a `RuntimeError` will be raised.
+    """
+    try:
+        return os.environ[varname]
+    except KeyError as e:
+        if fail_missing:
+            raise RuntimeError(reason) from e
+
+        pytest.skip(reason=reason)

 def make_url(port: int, endpoint: str) -> str:
     if not endpoint.startswith("/"):

From 63df23883192f8e185966e8b98d10be9ec311f6b Mon Sep 17 00:00:00 2001
From: David Gardner
Date: Thu, 26 Oct 2023 09:09:24 -0700
Subject: [PATCH 03/26] Tests for nemo llm

---
 tests/llm/services/nemo_llm/conftest.py       | 40 +++++++++++++
 .../services/nemo_llm/test_nemo_llm_client.py | 56 +++++++++++++++++++
 .../nemo_llm/test_nemo_llm_service.py         | 49 ++++++++++++++++
 3 files changed, 145 insertions(+)
 create mode 100644 tests/llm/services/nemo_llm/conftest.py
 create mode 100644 tests/llm/services/nemo_llm/test_nemo_llm_client.py
 create mode 100644 tests/llm/services/nemo_llm/test_nemo_llm_service.py

diff --git a/tests/llm/services/nemo_llm/conftest.py b/tests/llm/services/nemo_llm/conftest.py
new file mode 100644
index 0000000000..65160fe4ba
--- /dev/null
+++ b/tests/llm/services/nemo_llm/conftest.py
@@ -0,0 +1,40 @@
+# SPDX-FileCopyrightText: Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import pytest
+
+from _utils import import_or_skip
+from _utils import require_env_variable
+
+
+@pytest.fixture(name="nemollm", autouse=True, scope='session')
+def nemollm_fixture(fail_missing: bool):
+    """
+    All of the tests in this subdir require nemollm
+    """
+    skip_reason = ("Tests for the NeMoLLMService require the nemollm package to be installed; to install it, run:\n"
+                   "`mamba env update -n ${CONDA_DEFAULT_ENV} --file docker/conda/environments/cuda11.8_examples.yml`")
+    yield import_or_skip("nemollm", reason=skip_reason, fail_missing=fail_missing)
+
+
+@pytest.fixture(name="ngc_api_key", scope='session')
+def ngc_api_key_fixture(fail_missing: bool):
+    """
+    Integration tests require an NGC API key.
+    """
+    yield require_env_variable(
+        varname="NGC_API_KEY",
+        reason="nemo integration tests require the `NGC_API_KEY` environment variable to be defined.",
+        fail_missing=fail_missing)
\ No newline at end of file
diff --git a/tests/llm/services/nemo_llm/test_nemo_llm_client.py b/tests/llm/services/nemo_llm/test_nemo_llm_client.py
new file mode 100644
index 0000000000..3f21f42bbb
--- /dev/null
+++ b/tests/llm/services/nemo_llm/test_nemo_llm_client.py
@@ -0,0 +1,56 @@
+# SPDX-FileCopyrightText: Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and +# limitations under the License. + +from unittest import mock + +from morpheus.llm.services.nemo_llm_service import NeMoLLMClient + + +def _make_mock_nemo_llm(): + mock_nemo_llm = mock.MagicMock() + mock_nemo_llm.return_value = mock_nemo_llm + mock_nemo_llm.generate_multiple.return_value = ["test_output"] + return mock_nemo_llm + + +def _make_mock_nemo_service(): + mock_nemo_llm = _make_mock_nemo_llm() + mock_nemo_service = mock.MagicMock() + mock_nemo_service.return_value = mock_nemo_service + mock_nemo_service._conn = mock_nemo_llm + return (mock_nemo_service, mock_nemo_llm) + + +def test_generate(): + (mock_nemo_service, mock_nemo_llm) = _make_mock_nemo_service() + + client = NeMoLLMClient(mock_nemo_service, "test_model", additional_arg="test_arg") + assert client.generate("test_prompt") == "test_output" + mock_nemo_llm.generate_multiple.assert_called_once_with(model="test_model", + prompts=["test_prompt"], + return_type="text", + additional_arg="test_arg") + + +def test_generate_batch(): + (mock_nemo_service, mock_nemo_llm) = _make_mock_nemo_service() + mock_nemo_llm.generate_multiple.return_value = ["output1", "output2"] + + client = NeMoLLMClient(mock_nemo_service, "test_model", additional_arg="test_arg") + assert client.generate_batch(["prompt1", "prompt2"]) == ["output1", "output2"] + mock_nemo_llm.generate_multiple.assert_called_once_with(model="test_model", + prompts=["prompt1", "prompt2"], + return_type="text", + additional_arg="test_arg") diff --git a/tests/llm/services/nemo_llm/test_nemo_llm_service.py b/tests/llm/services/nemo_llm/test_nemo_llm_service.py new file mode 100644 index 0000000000..0a5ca89f14 --- /dev/null +++ b/tests/llm/services/nemo_llm/test_nemo_llm_service.py @@ -0,0 +1,49 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from unittest import mock + +import pytest + +from morpheus.llm.services.nemo_llm_service import NeMoLLMService +from morpheus.llm.services.nemo_llm_service import NeMoLLMClient + + +@pytest.mark.usefixtures("restore_environ") +@pytest.mark.parametrize("api_key", [None, "test_api_key"]) +@pytest.mark.parametrize("org_id", [None, "test_org_id"]) +@mock.patch("morpheus.llm.services.nemo_llm_service.NemoLLM") +def test_constructor(mock_nemo: mock.MagicMock, api_key: str, org_id: str): + """ + Test that the constructor prefers explicit arguments over environment variables. 
+ """ + env_api_key = "test_env_api_key" + env_org_id = "test_env_org_id" + os.environ["NGC_API_KEY"] = env_api_key + os.environ["NGC_ORG_ID"] = env_org_id + + expected_api_key = api_key or env_api_key + expected_org_id = org_id or env_org_id + + NeMoLLMService(api_key=api_key, org_id=org_id) + mock_nemo.assert_called_once_with(api_key=expected_api_key, org_id=expected_org_id) + + +def test_get_client(): + service = NeMoLLMService(api_key="test_api_key") + client = service.get_client("test_model") + + assert isinstance(client, NeMoLLMClient) From 678fd60a6a531d0ed773e542baa13b68a9d2ad47 Mon Sep 17 00:00:00 2001 From: David Gardner Date: Thu, 26 Oct 2023 09:43:55 -0700 Subject: [PATCH 04/26] Remove setting of pipeline mode to other, as this is changed to NLP a few lines later --- examples/llm/completion/pipeline.py | 1 - 1 file changed, 1 deletion(-) diff --git a/examples/llm/completion/pipeline.py b/examples/llm/completion/pipeline.py index 71ef81aa7c..5528f778c4 100644 --- a/examples/llm/completion/pipeline.py +++ b/examples/llm/completion/pipeline.py @@ -61,7 +61,6 @@ def _build_engine(): def pipeline(num_threads, pipeline_batch_size, model_max_batch_size, repeat_count: int): config = Config() - config.mode = PipelineModes.OTHER # Below properties are specified by the command line config.num_threads = num_threads From 4c2da64599a3469a5fb3fc7adf4dae12c306f8af Mon Sep 17 00:00:00 2001 From: David Gardner Date: Thu, 26 Oct 2023 10:12:06 -0700 Subject: [PATCH 05/26] End to end test using NeMoLLMService in a pipeline --- .../nemo_llm/test_nemo_llm_service_pipe.py | 92 +++++++++++++++++++ 1 file changed, 92 insertions(+) create mode 100644 tests/llm/services/nemo_llm/test_nemo_llm_service_pipe.py diff --git a/tests/llm/services/nemo_llm/test_nemo_llm_service_pipe.py b/tests/llm/services/nemo_llm/test_nemo_llm_service_pipe.py new file mode 100644 index 0000000000..86fc9c5406 --- /dev/null +++ b/tests/llm/services/nemo_llm/test_nemo_llm_service_pipe.py @@ -0,0 +1,92 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import os +from unittest import mock + +import pytest + +import cudf + +from morpheus.config import Config +from morpheus.llm import LLMEngine +from morpheus.llm.llm_engine_stage import LLMEngineStage +from morpheus.llm.nodes.extracter_node import ExtracterNode +from morpheus.llm.nodes.llm_generate_node import LLMGenerateNode +from morpheus.llm.services.nemo_llm_service import NeMoLLMService +from morpheus.llm.task_handlers.simple_task_handler import SimpleTaskHandler +from morpheus.messages import ControlMessage +from morpheus.pipeline.linear_pipeline import LinearPipeline +from morpheus.stages.input.in_memory_source_stage import InMemorySourceStage +from morpheus.stages.output.in_memory_sink_stage import InMemorySinkStage +from morpheus.stages.preprocess.deserialize_stage import DeserializeStage +from morpheus.utils.concat_df import concat_dataframes + + +def _build_engine(model_name: str): + llm_service = NeMoLLMService() + llm_clinet = llm_service.get_client(model_name=model_name) + + engine = LLMEngine() + engine.add_node("extracter", node=ExtracterNode()) + engine.add_node("completion", inputs=["/extracter"], node=LLMGenerateNode(llm_client=llm_clinet)) + engine.add_task_handler(inputs=["/completion"], handler=SimpleTaskHandler()) + + return engine + + +@pytest.mark.slow +@pytest.mark.use_python +@pytest.mark.usefixtures("ngc_api_key") +@pytest.mark.parametrize("model_name", ["gpt-43b-002"]) +def test_completion_pipe(config: Config, model_name: str): + """ + Loosely patterned after `examples/llm/completion` + """ + + source_df = cudf.DataFrame({ + "prompt": [ + "What is the capital of France?", + "What is the capital of Spain?", + "What is the capital of Italy?", + "What is the capital of Germany?", + "What is the capital of United Kingdom?", + "What is the capital of China?", + "What is the capital of Japan?", + "What is the capital of India?", + "What is the capital of Brazil?", + "What is the capital of United States?", + ] + }) + + completion_task = {"task_type": "completion", "task_dict": {"input_keys": ["prompt"], }} + + pipe = LinearPipeline(config) + + pipe.set_source(InMemorySourceStage(config, dataframes=[source_df])) + + pipe.add_stage( + DeserializeStage(config, message_type=ControlMessage, task_type="llm_engine", task_payload=completion_task)) + + pipe.add_stage(LLMEngineStage(config, engine=_build_engine(model_name=model_name))) + sink = pipe.add_stage(InMemorySinkStage(config)) + + pipe.run() + + messages = sink.get_messages() + result_df = concat_dataframes(messages) + + # We don't want to check for specific responses from Nemo, we just want to ensure we received non-empty responses + assert (result_df['response'].str.strip() != '').all() From 6c0f0449d69aef6dd4a52b86dec439bebfa61dca Mon Sep 17 00:00:00 2001 From: David Gardner Date: Thu, 26 Oct 2023 10:19:21 -0700 Subject: [PATCH 06/26] Remove unused imports --- tests/llm/services/nemo_llm/test_nemo_llm_service_pipe.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/tests/llm/services/nemo_llm/test_nemo_llm_service_pipe.py b/tests/llm/services/nemo_llm/test_nemo_llm_service_pipe.py index 86fc9c5406..82584825d9 100644 --- a/tests/llm/services/nemo_llm/test_nemo_llm_service_pipe.py +++ b/tests/llm/services/nemo_llm/test_nemo_llm_service_pipe.py @@ -13,9 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. 
 
-import os
-from unittest import mock
-
 import pytest

 import cudf

From e0ed8a472813e2047120670119ec1042d60f281d Mon Sep 17 00:00:00 2001
From: David Gardner
Date: Thu, 26 Oct 2023 10:39:12 -0700
Subject: [PATCH 07/26] Formatting fixes

---
 tests/_utils/__init__.py                             | 8 +++++++-
 tests/llm/services/nemo_llm/conftest.py              | 2 +-
 tests/llm/services/nemo_llm/test_nemo_llm_service.py | 2 +-
 3 files changed, 9 insertions(+), 3 deletions(-)

diff --git a/tests/_utils/__init__.py b/tests/_utils/__init__.py
index ce4e532e3e..dc9aa515a0 100644
--- a/tests/_utils/__init__.py
+++ b/tests/_utils/__init__.py
@@ -122,6 +122,8 @@ def import_or_skip(modname: str,
             raise ImportError(e) from e
         raise

+
+# pylint: disable=inconsistent-return-statements
 def require_env_variable(varname: str, reason: str, fail_missing: bool = False) -> str:
     """
     Checks if the given environment variable is set, and returns its value if it is. If the variable is not set, and
@@ -132,9 +134,13 @@ def require_env_variable(varname: str, reason: str, fail_missing: bool = False)
     except KeyError as e:
         if fail_missing:
             raise RuntimeError(reason) from e
-
+
     pytest.skip(reason=reason)

+
+# pylint: enable=inconsistent-return-statements
+
+
 def make_url(port: int, endpoint: str) -> str:
     if not endpoint.startswith("/"):
         endpoint = "/" + endpoint

diff --git a/tests/llm/services/nemo_llm/conftest.py b/tests/llm/services/nemo_llm/conftest.py
index 65160fe4ba..3098cd2223 100644
--- a/tests/llm/services/nemo_llm/conftest.py
+++ b/tests/llm/services/nemo_llm/conftest.py
@@ -37,4 +37,4 @@ def ngc_api_key_fixture(fail_missing: bool):
     yield require_env_variable(
         varname="NGC_API_KEY",
         reason="nemo integration tests require the `NGC_API_KEY` environment variable to be defined.",
-        fail_missing=fail_missing)
\ No newline at end of file
+        fail_missing=fail_missing)

diff --git a/tests/llm/services/nemo_llm/test_nemo_llm_service.py b/tests/llm/services/nemo_llm/test_nemo_llm_service.py
index 0a5ca89f14..858295ca93 100644
--- a/tests/llm/services/nemo_llm/test_nemo_llm_service.py
+++ b/tests/llm/services/nemo_llm/test_nemo_llm_service.py
@@ -18,8 +18,8 @@

 import pytest

-from morpheus.llm.services.nemo_llm_service import NeMoLLMService
 from morpheus.llm.services.nemo_llm_service import NeMoLLMClient
+from morpheus.llm.services.nemo_llm_service import NeMoLLMService


From b56e70215e67b1d0b7df3a9352bcd8c5d75b3365 Mon Sep 17 00:00:00 2001
From: David Gardner
Date: Thu, 26 Oct 2023 11:34:26 -0700
Subject: [PATCH 08/26] Fix typo

---
 morpheus/stages/input/arxiv_source.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/morpheus/stages/input/arxiv_source.py b/morpheus/stages/input/arxiv_source.py
index e47464590f..1be27fc9d7 100644
--- a/morpheus/stages/input/arxiv_source.py
+++ b/morpheus/stages/input/arxiv_source.py
@@ -35,7 +35,7 @@
     from langchain.schema import Document

 IMPORT_ERROR_MESSAGE = (
-    "ArxivSource requires additional dependencies to be installed. Install them by runnign the following command: "
+    "ArxivSource requires additional dependencies to be installed. Install them by running the following command: "
     "`mamba env update -n ${CONDA_DEFAULT_ENV} --file docker/conda/environments/cuda11.8_examples.yml`")

From d2fe7484dcc6f04593dcc7222de990c02afbe Mon Sep 17 00:00:00 2001
From: David Gardner
Date: Thu, 26 Oct 2023 11:35:33 -0700
Subject: [PATCH 09/26] Better import error handling, add docstring for
 NeMoLLMService constructor, NeMoLLMService doesn't need to hold a reference
 to the api key or the org id

---
 morpheus/llm/services/nemo_llm_service.py | 39 ++++++++++++++++++-----
 1 file changed, 31 insertions(+), 8 deletions(-)

diff --git a/morpheus/llm/services/nemo_llm_service.py b/morpheus/llm/services/nemo_llm_service.py
index 839ab5f918..8c2d844311 100644
--- a/morpheus/llm/services/nemo_llm_service.py
+++ b/morpheus/llm/services/nemo_llm_service.py
@@ -22,16 +22,29 @@

 logger = logging.getLogger(__name__)

+IMPORT_ERROR_MESSAGE = (
+    "NemoLLM not found. Install it and other additional dependencies by running the following command:\n"
+    "`mamba env update -n ${CONDA_DEFAULT_ENV} --file docker/conda/environments/cuda11.8_examples.yml`")
+
 try:
     from nemollm.api import NemoLLM
 except ImportError:
-    logger.error("NemoLLM not found. Please install NemoLLM to use this service.")
+    logger.error(IMPORT_ERROR_MESSAGE)
+
+
+def _verify_nemo_llm():
+    """
+    When NemoLLM is not installed, raise an ImportError with a helpful message, rather than an attribute error.
+    """
+    if 'NemoLLM' not in globals():
+        raise ImportError(IMPORT_ERROR_MESSAGE)


 class NeMoLLMClient(LLMClient):

     def __init__(self, parent: "NeMoLLMService", model_name: str, **model_kwargs) -> None:
         super().__init__()
+        _verify_nemo_llm()

         self._parent = parent
         self._model_name = model_name
@@ -69,29 +82,39 @@ async def generate_batch_async(self, prompts: list[str]) -> list[str]:


 class NeMoLLMService(LLMService):
+    """
+    A service for interacting with NeMo LLM models; this class should be used to create a client for a specific model.
+
+    Parameters
+    ----------
+    api_key : str, optional
+        The API key for the LLM service, by default None. If `None` the API key will be read from the `NGC_API_KEY`
+        environment variable. If neither is present, an error will be raised.
+
+    org_id : str, optional
+        The organization ID for the LLM service, by default None. If `None` the organization ID will be read from the
+        `NGC_ORG_ID` environment variable. This value is only required if the account associated with the `api_key` is
+        a member of multiple NGC organizations.
+    """

     def __init__(self, *, api_key: str = None, org_id: str = None) -> None:
         super().__init__()
+        _verify_nemo_llm()

         api_key = api_key if api_key is not None else os.environ.get("NGC_API_KEY", None)
         org_id = org_id if org_id is not None else os.environ.get("NGC_ORG_ID", None)

-        self._api_key = api_key
-        self._org_id = org_id
-
-        # Do checking on api key
-
         # Class variables
         self._conn: NemoLLM = NemoLLM(
             # The client must configure the authentication and authorization parameters
             # in accordance with the API server security policy.
             # Configure Bearer authorization
-            api_key=self._api_key,
+            api_key=api_key,
             # If you are in more than one LLM-enabled organization, you must
             # specify your org ID in the form of a header. This is optional
             # if you are only in one LLM-enabled org.
- org_id=self._org_id, + org_id=org_id, ) def get_client(self, model_name: str, **model_kwargs) -> NeMoLLMClient: From ca2376bdf659bf3153f23ae93d69b66d02f16827 Mon Sep 17 00:00:00 2001 From: David Gardner Date: Thu, 26 Oct 2023 12:00:19 -0700 Subject: [PATCH 10/26] Add docstrings [no ci] --- morpheus/llm/services/nemo_llm_service.py | 64 +++++++++++++++++++++-- 1 file changed, 59 insertions(+), 5 deletions(-) diff --git a/morpheus/llm/services/nemo_llm_service.py b/morpheus/llm/services/nemo_llm_service.py index 8c2d844311..9aa5649d25 100644 --- a/morpheus/llm/services/nemo_llm_service.py +++ b/morpheus/llm/services/nemo_llm_service.py @@ -41,8 +41,22 @@ def _verify_nemo_llm(): class NeMoLLMClient(LLMClient): + """ + Client for interacting with a specific model in Nemo. This class should be constructed with the + `NeMoLLMService.get_client` method. + + Parameters + ---------- + parent : NeMoLLMService + The parent service for this client. + model_name : str + The name of the model to interact with. + + model_kwargs : dict[str, typing.Any] + Additional keyword arguments to pass to the model when generating text. + """ - def __init__(self, parent: "NeMoLLMService", model_name: str, **model_kwargs) -> None: + def __init__(self, parent: "NeMoLLMService", model_name: str, **model_kwargs: dict[str, typing.Any]) -> None: super().__init__() _verify_nemo_llm() @@ -51,13 +65,36 @@ def __init__(self, parent: "NeMoLLMService", model_name: str, **model_kwargs) -> self._model_kwargs = model_kwargs def generate(self, prompt: str) -> str: + """ + Issue a request to generate a response based on a given prompt. + + Parameters + ---------- + prompt : str + The prompt to generate a response for. + """ return self.generate_batch([prompt])[0] async def generate_async(self, prompt: str) -> str: + """ + Issue an asynchronous request to generate a response based on a given prompt. + + Parameters + ---------- + prompt : str + The prompt to generate a response for. + """ return (await self.generate_batch_async([prompt]))[0] def generate_batch(self, prompts: list[str]) -> list[str]: - + """ + Issue a request to generate a list of responses based on a list of prompts. + + Parameters + ---------- + prompts : list[str] + The prompts to generate responses for. + """ return typing.cast( list[str], self._parent._conn.generate_multiple(model=self._model_name, @@ -66,7 +103,14 @@ def generate_batch(self, prompts: list[str]) -> list[str]: **self._model_kwargs)) async def generate_batch_async(self, prompts: list[str]) -> list[str]: - + """ + Issue an asynchronous request to generate a list of responses based on a list of prompts. + + Parameters + ---------- + prompts : list[str] + The prompts to generate responses for. + """ futures = [ asyncio.wrap_future( self._parent._conn.generate(self._model_name, p, return_type="async", **self._model_kwargs)) @@ -104,7 +148,6 @@ def __init__(self, *, api_key: str = None, org_id: str = None) -> None: api_key = api_key if api_key is not None else os.environ.get("NGC_API_KEY", None) org_id = org_id if org_id is not None else os.environ.get("NGC_ORG_ID", None) - # Class variables self._conn: NemoLLM = NemoLLM( # The client must configure the authentication and authorization parameters # in accordance with the API server security policy. 
@@ -117,6 +160,17 @@ def __init__(self, *, api_key: str = None, org_id: str = None) -> None: org_id=org_id, ) - def get_client(self, model_name: str, **model_kwargs) -> NeMoLLMClient: + def get_client(self, model_name: str, **model_kwargs: dict[str, typing.Any]) -> NeMoLLMClient: + """ + Returns a client for interacting with a specific model. This method is the preferred way to create a client. + + Parameters + ---------- + model_name : str + The name of the model to create a client for. + + model_kwargs : dict[str, typing.Any] + Additional keyword arguments to pass to the model when generating text. + """ return NeMoLLMClient(self, model_name, **model_kwargs) From daddaad87bab6f157a04313adf3d283d12515481 Mon Sep 17 00:00:00 2001 From: David Gardner Date: Thu, 26 Oct 2023 12:17:40 -0700 Subject: [PATCH 11/26] Add NGC_API_KEY to env for tests, explictly eclude this value from being logged --- .github/workflows/ci_pipe.yml | 1 + ci/scripts/github/common.sh | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/.github/workflows/ci_pipe.yml b/.github/workflows/ci_pipe.yml index c1a92cdbd5..23a73111c1 100644 --- a/.github/workflows/ci_pipe.yml +++ b/.github/workflows/ci_pipe.yml @@ -133,6 +133,7 @@ jobs: NVIDIA_VISIBLE_DEVICES: ${{ env.NVIDIA_VISIBLE_DEVICES }} PARALLEL_LEVEL: '10' MERGE_EXAMPLES_YAML: '1' + NGC_API_KEY: ${{ secrets.NGC_API_KEY }} strategy: fail-fast: true diff --git a/ci/scripts/github/common.sh b/ci/scripts/github/common.sh index dd73852e4d..8756ea09bd 100644 --- a/ci/scripts/github/common.sh +++ b/ci/scripts/github/common.sh @@ -16,7 +16,7 @@ function print_env_vars() { rapids-logger "Environ:" - env | grep -v -E "AWS_ACCESS_KEY_ID|AWS_SECRET_ACCESS_KEY|GH_TOKEN" | sort + env | grep -v -E "AWS_ACCESS_KEY_ID|AWS_SECRET_ACCESS_KEY|GH_TOKEN|NGC_API_KEY" | sort } rapids-logger "Env Setup" From b3528c1df8f9587b9ebdadb864f3c51d542c9f77 Mon Sep 17 00:00:00 2001 From: David Gardner Date: Thu, 26 Oct 2023 13:39:09 -0700 Subject: [PATCH 12/26] Check response status for errors [no ci] --- morpheus/llm/services/nemo_llm_service.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/morpheus/llm/services/nemo_llm_service.py b/morpheus/llm/services/nemo_llm_service.py index 9aa5649d25..2012ac9faa 100644 --- a/morpheus/llm/services/nemo_llm_service.py +++ b/morpheus/llm/services/nemo_llm_service.py @@ -119,10 +119,16 @@ async def generate_batch_async(self, prompts: list[str]) -> list[str]: results = await asyncio.gather(*futures) - return [ - typing.cast(str, NemoLLM.post_process_generate_response(r, return_text_completion_only=True)) - for r in results - ] + responses = [] + + for result in results: + result = NemoLLM.post_process_generate_response(result, return_text_completion_only=False) + if result.get('status', None) == 'fail': + raise RuntimeError(result.get('msg', 'Unknown error')) + + responses.append(result['text']) + + return responses class NeMoLLMService(LLMService): From 3925f87e009564279acde78bf46901a8af110b7f Mon Sep 17 00:00:00 2001 From: David Gardner Date: Thu, 26 Oct 2023 13:39:09 -0700 Subject: [PATCH 13/26] Check response status for errors --- morpheus/llm/services/nemo_llm_service.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/morpheus/llm/services/nemo_llm_service.py b/morpheus/llm/services/nemo_llm_service.py index 9aa5649d25..2012ac9faa 100644 --- a/morpheus/llm/services/nemo_llm_service.py +++ b/morpheus/llm/services/nemo_llm_service.py @@ -119,10 +119,16 @@ async def 
From ba0ce54a5292a921da25b8604d3f12cb2ee4fa59 Mon Sep 17 00:00:00 2001
From: David Gardner
Date: Thu, 26 Oct 2023 14:43:35 -0700
Subject: [PATCH 14/26] Disable end-to-end test

---
 .github/workflows/ci_pipe.yml                             | 1 -
 tests/llm/services/nemo_llm/test_nemo_llm_service_pipe.py | 1 +
 2 files changed, 1 insertion(+), 1 deletion(-)

diff --git a/.github/workflows/ci_pipe.yml b/.github/workflows/ci_pipe.yml
index 23a73111c1..c1a92cdbd5 100644
--- a/.github/workflows/ci_pipe.yml
+++ b/.github/workflows/ci_pipe.yml
@@ -133,7 +133,6 @@ jobs:
       NVIDIA_VISIBLE_DEVICES: ${{ env.NVIDIA_VISIBLE_DEVICES }}
       PARALLEL_LEVEL: '10'
       MERGE_EXAMPLES_YAML: '1'
-      NGC_API_KEY: ${{ secrets.NGC_API_KEY }}

     strategy:
       fail-fast: true

diff --git a/tests/llm/services/nemo_llm/test_nemo_llm_service_pipe.py b/tests/llm/services/nemo_llm/test_nemo_llm_service_pipe.py
index 82584825d9..767c7e413e 100644
--- a/tests/llm/services/nemo_llm/test_nemo_llm_service_pipe.py
+++ b/tests/llm/services/nemo_llm/test_nemo_llm_service_pipe.py
@@ -44,6 +44,7 @@ def _build_engine(model_name: str):
     return engine


+@pytest.mark.skip(reason="Skipping until we can generate a new API key for the test account")
 @pytest.mark.slow
 @pytest.mark.use_python
 @pytest.mark.usefixtures("ngc_api_key")
From b23f231a938b44d540ad8ac1978f078d1890e043 Mon Sep 17 00:00:00 2001
From: David Gardner
Date: Thu, 26 Oct 2023 16:01:14 -0700
Subject: [PATCH 15/26] Tests and docstrings

---
 morpheus/llm/services/llm_service.py   | 46 +++++++++++++++++++++++++-
 tests/llm/services/test_llm_service.py | 27 +++++++++++++++
 2 files changed, 72 insertions(+), 1 deletion(-)
 create mode 100644 tests/llm/services/test_llm_service.py

diff --git a/morpheus/llm/services/llm_service.py b/morpheus/llm/services/llm_service.py
index 4e651195e2..4229955a2a 100644
--- a/morpheus/llm/services/llm_service.py
+++ b/morpheus/llm/services/llm_service.py
@@ -13,6 +13,7 @@
 # limitations under the License.

 import logging
+import typing
 from abc import ABC
 from abc import abstractmethod

@@ -23,23 +24,66 @@ class LLMClient(ABC):

     @abstractmethod
     def generate(self, prompt: str) -> str:
+        """
+        Issue a request to generate a response based on a given prompt.
+
+        Parameters
+        ----------
+        prompt : str
+            The prompt to generate a response for.
+        """
         pass

     @abstractmethod
     async def generate_async(self, prompt: str) -> str:
+        """
+        Issue an asynchronous request to generate a response based on a given prompt.
+
+        Parameters
+        ----------
+        prompt : str
+            The prompt to generate a response for.
+        """
         pass

     @abstractmethod
     def generate_batch(self, prompts: list[str]) -> list[str]:
+        """
+        Issue a request to generate a list of responses based on a list of prompts.
+
+        Parameters
+        ----------
+        prompts : list[str]
+            The prompts to generate responses for.
+        """
         pass

     @abstractmethod
     async def generate_batch_async(self, prompts: list[str]) -> list[str]:
+        """
+        Issue an asynchronous request to generate a list of responses based on a list of prompts.
+
+        Parameters
+        ----------
+        prompts : list[str]
+            The prompts to generate responses for.
+        """
         pass


 class LLMService(ABC):

     @abstractmethod
-    def get_client(self, model_name: str, **model_kwargs) -> LLMClient:
+    def get_client(self, model_name: str, **model_kwargs: dict[str, typing.Any]) -> LLMClient:
+        """
+        Returns a client for interacting with a specific model.
+
+        Parameters
+        ----------
+        model_name : str
+            The name of the model to create a client for.
+
+        model_kwargs : dict[str, typing.Any]
+            Additional keyword arguments to pass to the model when generating text.
+        """
         pass

diff --git a/tests/llm/services/test_llm_service.py b/tests/llm/services/test_llm_service.py
new file mode 100644
index 0000000000..9060963899
--- /dev/null
+++ b/tests/llm/services/test_llm_service.py
@@ -0,0 +1,27 @@
+# SPDX-FileCopyrightText: Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+from abc import ABC
+
+import pytest
+
+from morpheus.llm.services.llm_service import LLMClient
+from morpheus.llm.services.llm_service import LLMService
+
+
+@pytest.mark.parametrize("cls", [LLMClient, LLMService])
+def test_is_abstract(cls: ABC):
+    assert inspect.isabstract(cls)

From 2564a0a51e44aa69bca42fc2b65cd65d792e93bc Mon Sep 17 00:00:00 2001
From: David Gardner
Date: Thu, 26 Oct 2023 16:03:13 -0700
Subject: [PATCH 16/26] workaround for
 https://github.com/nv-morpheus/Morpheus/issues/1317

---
 docker/conda/environments/cuda11.8_dev.yml | 1 +
 1 file changed, 1 insertion(+)

diff --git a/docker/conda/environments/cuda11.8_dev.yml b/docker/conda/environments/cuda11.8_dev.yml
index 5785546684..9d25275e13 100644
--- a/docker/conda/environments/cuda11.8_dev.yml
+++ b/docker/conda/environments/cuda11.8_dev.yml
@@ -34,6 +34,7 @@ dependencies:
   - configargparse=1.5
   - cuda-compiler=11.8
   - cuda-nvml-dev=11.8
+  - cuda-python>=11.8,<11.8.3 # workaround for https://github.com/nv-morpheus/Morpheus/issues/1317
   - cuda-toolkit=11.8
   - cudf=23.06
   - cupy>=12.0.0
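
For reference (illustrative only, not part of the patch series): patches 15 and 18 pin down the abstract
`LLMService`/`LLMClient` contract that concrete services such as `NeMoLLMService` implement. A minimal
hypothetical implementation, an echo service introduced here purely to show the required overrides,
would look like:

    from morpheus.llm.services.llm_service import LLMClient
    from morpheus.llm.services.llm_service import LLMService


    class EchoClient(LLMClient):
        """Hypothetical client which simply echoes its prompts back."""

        def generate(self, prompt: str) -> str:
            return prompt

        async def generate_async(self, prompt: str) -> str:
            return prompt

        def generate_batch(self, prompts: list[str]) -> list[str]:
            return list(prompts)

        async def generate_batch_async(self, prompts: list[str]) -> list[str]:
            return list(prompts)


    class EchoService(LLMService):
        """Hypothetical service which hands out EchoClient instances for any model name."""

        def get_client(self, model_name: str, **model_kwargs) -> EchoClient:
            return EchoClient()

With all four abstract methods overridden, `inspect.isabstract(EchoClient)` is False, so the
`test_is_abstract` check added in patch 15 applies only to the base classes, not to implementations.
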
From e07a35f8b8f13d42361b4641435322bf594a03c2 Mon Sep 17 00:00:00 2001
From: David Gardner
Date: Thu, 26 Oct 2023 16:16:58 -0700
Subject: [PATCH 18/26] More docstrings

---
 morpheus/llm/services/llm_service.py | 9 ++++++++-
 1 file changed, 8 insertions(+), 1 deletion(-)

diff --git a/morpheus/llm/services/llm_service.py b/morpheus/llm/services/llm_service.py
index 4229955a2a..16fd6f0d9d 100644
--- a/morpheus/llm/services/llm_service.py
+++ b/morpheus/llm/services/llm_service.py
@@ -21,6 +21,10 @@


 class LLMClient(ABC):
+    """
+    Abstract interface for clients which are able to interact with LLM models. Concrete implementations of this class
+    will have an associated implementation of `LLMService` which is able to construct instances of this class.
+    """

     @abstractmethod
     def generate(self, prompt: str) -> str:
@@ -72,6 +76,9 @@ async def generate_batch_async(self, prompts: list[str]) -> list[str]:


 class LLMService(ABC):
+    """
+    Abstract interface for services which are able to construct clients for interacting with LLM models.
+    """

     @abstractmethod
     def get_client(self, model_name: str, **model_kwargs: dict[str, typing.Any]) -> LLMClient:
@@ -84,6 +91,6 @@ def get_client(self, model_name: str, **model_kwargs: dict[str, typing.Any]) ->
             The name of the model to create a client for.

         model_kwargs : dict[str, typing.Any]
-            Additional keyword arguments to pass to the model when generating text.
+            Additional keyword arguments to pass to the model.
         """
         pass

From 5ce03d2887954edce1d3251a4abbcbbd682bc2f3 Mon Sep 17 00:00:00 2001
From: David Gardner
Date: Thu, 26 Oct 2023 19:46:40 -0700
Subject: [PATCH 20/26] Add docstrings for base llm node class

---
 morpheus/_lib/llm/__init__.pyi | 25 +++++++++++++++++++++++--
 morpheus/_lib/llm/module.cpp   | 29 +++++++++++++++++++++++++++--
 2 files changed, 50 insertions(+), 4 deletions(-)

diff --git a/morpheus/_lib/llm/__init__.pyi b/morpheus/_lib/llm/__init__.pyi
index c2eefa4f3f..fa7a3a0906 100644
--- a/morpheus/_lib/llm/__init__.pyi
+++ b/morpheus/_lib/llm/__init__.pyi
@@ -101,8 +101,29 @@ class LLMContext():
     pass
 class LLMNodeBase():
     def __init__(self) -> None: ...
-    def execute(self, context: LLMContext) -> typing.Awaitable[LLMContext]: ...
-    def get_input_names(self) -> typing.List[str]: ...
+    def execute(self, context: LLMContext) -> typing.Awaitable[LLMContext]:
+        """
+        Execute the current node with the given `context` instance.
+
+        All inputs for the given node should be fetched from the context, typically by calling either
+        `context.get_inputs` to fetch all inputs as a `dict`, or `context.get_input` to fetch a specific input.
+
+        Similarly the output of the node is written to the context using `context.set_output`.
+
+        Parameters
+        ----------
+        context : `morpheus._lib.llm.LLMContext`
+            Context instance to use for the execution
+        """
+    def get_input_names(self) -> typing.List[str]:
+        """
+        Get the input names for the node.
+ + Returns + ------- + list[str] + The input names for the node + """ pass class LLMEngineStage(mrc.core.segment.SegmentObject): def __init__(self, builder: mrc.core.segment.Builder, name: str, engine: LLMEngine) -> None: ... diff --git a/morpheus/_lib/llm/module.cpp b/morpheus/_lib/llm/module.cpp index 2304aeb7c4..a0394a0cc6 100644 --- a/morpheus/_lib/llm/module.cpp +++ b/morpheus/_lib/llm/module.cpp @@ -212,8 +212,33 @@ PYBIND11_MODULE(llm, _module) py::class_, std::shared_ptr>(_module, "LLMNodeBase") .def(py::init_alias<>()) - .def("get_input_names", &LLMNodeBase::get_input_names) - .def("execute", &LLMNodeBase::execute, py::arg("context")); + .def("get_input_names", + &LLMNodeBase::get_input_names, + R"pbdoc( + Get the input names for the node. + + Returns + ------- + list[str] + The input names for the node + )pbdoc") + .def("execute", + &LLMNodeBase::execute, + py::arg("context"), + R"pbdoc( + Execute the current node with the given `context` instance. + + All inputs for the given node should be fetched from the context, typically by calling either + `context.get_inputs` to fetch all inputs as a `dict`, or `context.get_input` to fetch a specific input. + + Similarly the output of the node is written to the context using `context.set_output`. + + Parameters + ---------- + context : `morpheus._lib.llm.LLMContext` + Context instance to use for the execution + + )pbdoc"); py::class_>(_module, "LLMNodeRunner") .def_property_readonly("inputs", &LLMNodeRunner::inputs) From e5c16616ace84bbd3f442de77a82fa38106eb2db Mon Sep 17 00:00:00 2001 From: David Gardner Date: Thu, 26 Oct 2023 19:46:57 -0700 Subject: [PATCH 21/26] Docstrings for constructor --- morpheus/llm/nodes/llm_generate_node.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/morpheus/llm/nodes/llm_generate_node.py b/morpheus/llm/nodes/llm_generate_node.py index e7e5fcb0c1..0f9896526e 100644 --- a/morpheus/llm/nodes/llm_generate_node.py +++ b/morpheus/llm/nodes/llm_generate_node.py @@ -23,6 +23,15 @@ class LLMGenerateNode(LLMNodeBase): + """ + Generates responses from an LLM using the provided `llm_client` instance based on prompts provided as input from + upstream nodes. + + Parameters + ---------- + llm_client : LLMClient + The client instance to use to generate responses. 
+ """ def __init__(self, llm_client: LLMClient) -> None: super().__init__() From 49899802275e7914dcdd26f4ccf2e0b351e06d29 Mon Sep 17 00:00:00 2001 From: David Gardner Date: Fri, 27 Oct 2023 08:31:22 -0700 Subject: [PATCH 22/26] Remove line-breaks from command allowing it to be easily copy/pasted --- examples/llm/vdb_upload/README.md | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/examples/llm/vdb_upload/README.md b/examples/llm/vdb_upload/README.md index d67ec14aeb..e8c933cac2 100644 --- a/examples/llm/vdb_upload/README.md +++ b/examples/llm/vdb_upload/README.md @@ -125,10 +125,7 @@ Before running the pipeline, we need to ensure that the following services are r - From the Morpheus repo root directory, run the following to launch Triton and load the `all-MiniLM-L6-v2` model: ```bash - docker run --rm -ti --gpus=all -p8000:8000 -p8001:8001 -p8002:8002 - -v $PWD/models:/models nvcr.io/nvidia/tritonserver:23.06-py3 tritonserver - --model-repository=/models/triton-model-repo --exit-on-error=false --model-control-mode=explicit - --load-model all-MiniLM-L6-v2 + docker run --rm -ti --gpus=all -p8000:8000 -p8001:8001 -p8002:8002 -v $PWD/models:/models nvcr.io/nvidia/tritonserver:23.06-py3 tritonserver --model-repository=/models/triton-model-repo --exit-on-error=false --model-control-mode=explicit --load-model all-MiniLM-L6-v2 ``` This will launch Triton and only load the `all-MiniLM-L6-v2` model. Once Triton has loaded the model, the following From 78c6d7838f6051509fde704b030a686a677de2ed Mon Sep 17 00:00:00 2001 From: David Gardner Date: Fri, 27 Oct 2023 09:06:57 -0700 Subject: [PATCH 23/26] Fix CR year --- tests/llm/nodes/test_prompt_template_node.py | 2 +- tests/llm/nodes/test_prompt_template_node_pipe.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/llm/nodes/test_prompt_template_node.py b/tests/llm/nodes/test_prompt_template_node.py index 79a2e34ba3..fc7ea02170 100644 --- a/tests/llm/nodes/test_prompt_template_node.py +++ b/tests/llm/nodes/test_prompt_template_node.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); diff --git a/tests/llm/nodes/test_prompt_template_node_pipe.py b/tests/llm/nodes/test_prompt_template_node_pipe.py index 57d5c9063f..38ff5e8bd9 100644 --- a/tests/llm/nodes/test_prompt_template_node_pipe.py +++ b/tests/llm/nodes/test_prompt_template_node_pipe.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
# SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); From 48d6abcccccc30c1c0bf5064308dd5358ccf31c3 Mon Sep 17 00:00:00 2001 From: David Gardner Date: Fri, 27 Oct 2023 09:30:50 -0700 Subject: [PATCH 24/26] Unittests for LLMGenerateNode --- tests/llm/nodes/test_llm_generate_node.py | 42 +++++++++++++++++++++++ 1 file changed, 42 insertions(+) create mode 100644 tests/llm/nodes/test_llm_generate_node.py diff --git a/tests/llm/nodes/test_llm_generate_node.py b/tests/llm/nodes/test_llm_generate_node.py new file mode 100644 index 0000000000..5ea1ab9ea1 --- /dev/null +++ b/tests/llm/nodes/test_llm_generate_node.py @@ -0,0 +1,42 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from unittest import mock + +from _utils.llm import execute_node +from morpheus.llm import LLMNodeBase +from morpheus.llm.nodes.llm_generate_node import LLMGenerateNode +from morpheus.llm.services.llm_service import LLMClient + + +def test_constructor(): + node = LLMGenerateNode(llm_client=mock.MagicMock(LLMClient)) + assert isinstance(node, LLMNodeBase) + + +def test_get_input_names(): + node = LLMGenerateNode(llm_client=mock.MagicMock(LLMClient)) + assert node.get_input_names() == ["prompt"] + + +def test_execute(): + expected_output = ["response1", "response2"] + mock_client = mock.MagicMock(LLMClient) + mock_client.return_value = mock_client + mock_client.generate_batch_async = mock.AsyncMock(return_value=expected_output.copy()) + + node = LLMGenerateNode(llm_client=mock_client) + assert execute_node(node, prompt=["prompt1", "prompt2"]) == expected_output + mock_client.generate_batch_async.assert_called_once_with(["prompt1", "prompt2"]) From 972e90a34a0027673972b05e2e9c2eeb4a915432 Mon Sep 17 00:00:00 2001 From: David Gardner Date: Fri, 27 Oct 2023 09:58:19 -0700 Subject: [PATCH 25/26] Create an end-to-end test --- tests/llm/nodes/conftest.py | 27 ++++++++ tests/llm/nodes/test_llm_generate_node.py | 19 +++--- .../llm/nodes/test_llm_generate_node_pipe.py | 66 +++++++++++++++++++ 3 files changed, 101 insertions(+), 11 deletions(-) create mode 100644 tests/llm/nodes/conftest.py create mode 100644 tests/llm/nodes/test_llm_generate_node_pipe.py diff --git a/tests/llm/nodes/conftest.py b/tests/llm/nodes/conftest.py new file mode 100644 index 0000000000..7be47d2850 --- /dev/null +++ b/tests/llm/nodes/conftest.py @@ -0,0 +1,27 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from unittest import mock + +import pytest + + +@pytest.fixture(name="mock_llm_client") +def mock_llm_client_fixture(): + from morpheus.llm.services.llm_service import LLMClient + mock_client = mock.MagicMock(LLMClient) + mock_client.return_value = mock_client + mock_client.generate_batch_async = mock.AsyncMock() + return mock_client diff --git a/tests/llm/nodes/test_llm_generate_node.py b/tests/llm/nodes/test_llm_generate_node.py index 5ea1ab9ea1..b3f4cf5b8f 100644 --- a/tests/llm/nodes/test_llm_generate_node.py +++ b/tests/llm/nodes/test_llm_generate_node.py @@ -18,25 +18,22 @@ from _utils.llm import execute_node from morpheus.llm import LLMNodeBase from morpheus.llm.nodes.llm_generate_node import LLMGenerateNode -from morpheus.llm.services.llm_service import LLMClient -def test_constructor(): - node = LLMGenerateNode(llm_client=mock.MagicMock(LLMClient)) +def test_constructor(mock_llm_client: mock.MagicMock): + node = LLMGenerateNode(llm_client=mock_llm_client) assert isinstance(node, LLMNodeBase) -def test_get_input_names(): - node = LLMGenerateNode(llm_client=mock.MagicMock(LLMClient)) +def test_get_input_names(mock_llm_client: mock.MagicMock): + node = LLMGenerateNode(llm_client=mock_llm_client) assert node.get_input_names() == ["prompt"] -def test_execute(): +def test_execute(mock_llm_client: mock.MagicMock): expected_output = ["response1", "response2"] - mock_client = mock.MagicMock(LLMClient) - mock_client.return_value = mock_client - mock_client.generate_batch_async = mock.AsyncMock(return_value=expected_output.copy()) + mock_llm_client.generate_batch_async.return_value = expected_output.copy() - node = LLMGenerateNode(llm_client=mock_client) + node = LLMGenerateNode(llm_client=mock_llm_client) assert execute_node(node, prompt=["prompt1", "prompt2"]) == expected_output - mock_client.generate_batch_async.assert_called_once_with(["prompt1", "prompt2"]) + mock_llm_client.generate_batch_async.assert_called_once_with(["prompt1", "prompt2"]) diff --git a/tests/llm/nodes/test_llm_generate_node_pipe.py b/tests/llm/nodes/test_llm_generate_node_pipe.py new file mode 100644 index 0000000000..61e537c648 --- /dev/null +++ b/tests/llm/nodes/test_llm_generate_node_pipe.py @@ -0,0 +1,66 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from unittest import mock + +import pytest + +import cudf + +from _utils import assert_results +from morpheus.config import Config +from morpheus.llm import LLMEngine +from morpheus.llm.llm_engine_stage import LLMEngineStage +from morpheus.llm.nodes.extracter_node import ExtracterNode +from morpheus.llm.nodes.llm_generate_node import LLMGenerateNode +from morpheus.llm.task_handlers.simple_task_handler import SimpleTaskHandler +from morpheus.messages import ControlMessage +from morpheus.pipeline.linear_pipeline import LinearPipeline +from morpheus.stages.input.in_memory_source_stage import InMemorySourceStage +from morpheus.stages.output.compare_dataframe_stage import CompareDataFrameStage +from morpheus.stages.preprocess.deserialize_stage import DeserializeStage + + +def _build_engine(mock_llm_client: mock.MagicMock) -> LLMEngine: + engine = LLMEngine() + engine.add_node("extracter", node=ExtracterNode()) + engine.add_node("generate", inputs=["/extracter"], node=LLMGenerateNode(llm_client=mock_llm_client)) + engine.add_task_handler(inputs=["/generate"], handler=SimpleTaskHandler()) + + return engine + + +@pytest.mark.use_python +def test_pipeline(config: Config, mock_llm_client: mock.MagicMock): + expected_output = ["response1", "response2"] + mock_llm_client.generate_batch_async.return_value = expected_output.copy() + + values = {'prompt': ["prompt1", "prompt2"]} + input_df = cudf.DataFrame(values) + expected_df = input_df.copy(deep=True) + expected_df["response"] = expected_output + + task_payload = {"task_type": "llm_engine", "task_dict": {"input_keys": sorted(values.keys())}} + + pipe = LinearPipeline(config) + pipe.set_source(InMemorySourceStage(config, dataframes=[input_df])) + pipe.add_stage( + DeserializeStage(config, message_type=ControlMessage, task_type="llm_engine", task_payload=task_payload)) + pipe.add_stage(LLMEngineStage(config, engine=_build_engine(mock_llm_client=mock_llm_client))) + sink = pipe.add_stage(CompareDataFrameStage(config, compare_df=expected_df)) + + pipe.run() + + assert_results(sink.get_results()) \ No newline at end of file From 8e7b7164e0fd80e263a2a55caff1bb008cff9a61 Mon Sep 17 00:00:00 2001 From: David Gardner Date: Fri, 27 Oct 2023 10:04:46 -0700 Subject: [PATCH 26/26] Formatting --- tests/llm/nodes/test_llm_generate_node_pipe.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/llm/nodes/test_llm_generate_node_pipe.py b/tests/llm/nodes/test_llm_generate_node_pipe.py index 61e537c648..cbc8b1f66b 100644 --- a/tests/llm/nodes/test_llm_generate_node_pipe.py +++ b/tests/llm/nodes/test_llm_generate_node_pipe.py @@ -63,4 +63,4 @@ def test_pipeline(config: Config, mock_llm_client: mock.MagicMock): pipe.run() - assert_results(sink.get_results()) \ No newline at end of file + assert_results(sink.get_results())
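
Taken together, the tests in patches 24-26 exercise the same engine wiring as `examples/llm/completion`.
A condensed sketch of that wiring (illustrative only, not part of the patch series; the mocked client
stands in for a real `LLMClient` exactly as in the `mock_llm_client` fixture of patch 25):

    from unittest import mock

    from morpheus.llm import LLMEngine
    from morpheus.llm.nodes.extracter_node import ExtracterNode
    from morpheus.llm.nodes.llm_generate_node import LLMGenerateNode
    from morpheus.llm.services.llm_service import LLMClient
    from morpheus.llm.task_handlers.simple_task_handler import SimpleTaskHandler

    # A canned client, mirroring the mock_llm_client fixture above.
    llm_client = mock.MagicMock(LLMClient)
    llm_client.generate_batch_async = mock.AsyncMock(return_value=["response1", "response2"])

    # extracter pulls the configured input columns out of the message payload,
    # generate produces one response per prompt, and the task handler writes
    # the responses back to the payload as a 'response' column.
    engine = LLMEngine()
    engine.add_node("extracter", node=ExtracterNode())
    engine.add_node("generate", inputs=["/extracter"], node=LLMGenerateNode(llm_client=llm_client))
    engine.add_task_handler(inputs=["/generate"], handler=SimpleTaskHandler())
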