From bace64741b790a779c81e52827d7e20c788deaef Mon Sep 17 00:00:00 2001 From: Sajith Ariyarathna Date: Thu, 2 Nov 2023 06:34:59 +0530 Subject: [PATCH] Add `client.server` API to the Client (#1036) * Add client.server API * Format code --- client/h2ogpt_client/_core.py | 11 +++++++++-- client/h2ogpt_client/_server.py | 18 ++++++++++++++++++ client/tests/test_client.py | 5 +++++ 3 files changed, 32 insertions(+), 2 deletions(-) create mode 100644 client/h2ogpt_client/_server.py diff --git a/client/h2ogpt_client/_core.py b/client/h2ogpt_client/_core.py index 0e1ffdfca..d3b0a0e33 100644 --- a/client/h2ogpt_client/_core.py +++ b/client/h2ogpt_client/_core.py @@ -12,6 +12,7 @@ PromptType, ) from h2ogpt_client._models import Model, Models +from h2ogpt_client._server import Server class Client: @@ -38,6 +39,7 @@ def __init__( self._text_completion = TextCompletionCreator(self) self._chat_completion = ChatCompletionCreator(self) self._models = Models(self) + self._server = Server(self) @property def text_completion(self) -> "TextCompletionCreator": @@ -50,10 +52,15 @@ def chat_completion(self) -> "ChatCompletionCreator": return self._chat_completion @property - def models(self) -> "Models": - """LL models""" + def models(self) -> Models: + """LL models.""" return self._models + @property + def server(self) -> Server: + """h2oGPT server.""" + return self._server + def _predict(self, *args, api_name: str) -> Any: return self._client.submit(*args, api_name=api_name).result() diff --git a/client/h2ogpt_client/_server.py b/client/h2ogpt_client/_server.py new file mode 100644 index 000000000..c00ec1c65 --- /dev/null +++ b/client/h2ogpt_client/_server.py @@ -0,0 +1,18 @@ +from h2ogpt_client import _core + + +class Server: + """h2oGPT server.""" + + def __init__(self, client: "_core.Client"): + self._client = client + + @property + def address(self) -> str: + """h2oGPT server address.""" + return self._client._client.src + + @property + def hash(self) -> str: + """h2oGPT server system hash.""" + return str(self._client._predict(api_name="/system_hash")) diff --git a/client/tests/test_client.py b/client/tests/test_client.py index 7be0ec650..f26c5f142 100644 --- a/client/tests/test_client.py +++ b/client/tests/test_client.py @@ -84,6 +84,11 @@ def test_available_models(client): print(models) +def test_server_properties(client, server_url): + assert client.server.address.startswith(server_url) + assert client.server.hash + + def test_parameters_order(client, eval_func_param_names): text_completion = client.text_completion.create() assert eval_func_param_names == list(text_completion._parameters.keys())