From 8e668316008606b1e30f8d6a331ace2d6ac6bde4 Mon Sep 17 00:00:00 2001 From: aravind-segu Date: Thu, 3 Oct 2024 21:41:53 -0700 Subject: [PATCH] Skip langchain test for less than 3.7 --- databricks/sdk/mixins/open_ai_client.py | 16 ++++++++++++++-- setup.py | 2 +- tests/test_open_ai_mixin.py | 7 ++++++- 3 files changed, 21 insertions(+), 4 deletions(-) diff --git a/databricks/sdk/mixins/open_ai_client.py b/databricks/sdk/mixins/open_ai_client.py index 5bd61268..084983ca 100644 --- a/databricks/sdk/mixins/open_ai_client.py +++ b/databricks/sdk/mixins/open_ai_client.py @@ -11,18 +11,30 @@ def get_open_ai_client(self): except Exception: raise ValueError("Unable to extract authorization token for OpenAI Client") - from openai import OpenAI + try: + from openai import OpenAI + except Exception: + raise ImportError( + "Open AI is not installed. Please install the Databricks SDK with the following command `pip isntall databricks-sdk[openai]`" + ) + return OpenAI(base_url=self._api._cfg.host + "/serving-endpoints", api_key=token) def get_langchain_chat_open_ai_client(self, model): auth_headers = self._api._cfg.authenticate() + try: + from langchain_openai import ChatOpenAI + except Exception: + raise ImportError( + "Langchain Open AI is not installed. Please install the Databricks SDK with the following command `pip isntall databricks-sdk[openai]` and ensure you are using python>3.7" + ) + try: token = auth_headers["Authorization"][len("Bearer "):] except Exception: raise ValueError("Unable to extract authorization token for Langchain OpenAI Client") - from langchain_openai import ChatOpenAI return ChatOpenAI(model=model, openai_api_base=self._api._cfg.host + "/serving-endpoints", openai_api_key=token) diff --git a/setup.py b/setup.py index 03e29a3b..51dcd844 100644 --- a/setup.py +++ b/setup.py @@ -20,7 +20,7 @@ "databricks-connect", "pytest-rerunfailures", "openai", "langchain-openai"], "notebook": ["ipython>=8,<9", "ipywidgets>=8,<9"], - "openai": ["openai", "langchain-openai"]}, + "openai": ["openai", 'langchain-openai; python_version > "3.7"']}, author="Serge Smertin", author_email="serge.smertin@databricks.com", description="Databricks SDK for Python (Beta)", diff --git a/tests/test_open_ai_mixin.py b/tests/test_open_ai_mixin.py index abe1c20f..2d184267 100644 --- a/tests/test_open_ai_mixin.py +++ b/tests/test_open_ai_mixin.py @@ -1,3 +1,7 @@ +import sys + +import pytest + from databricks.sdk.core import Config @@ -13,6 +17,7 @@ def test_open_ai_client(monkeypatch): assert client.api_key == "test_token" +@pytest.mark.skipif(sys.version_info <= (3, 7), reason="Requires Python > 3.7") def test_langchain_open_ai_client(monkeypatch): from databricks.sdk import WorkspaceClient @@ -22,4 +27,4 @@ def test_langchain_open_ai_client(monkeypatch): client = w.serving_endpoints.get_langchain_chat_open_ai_client("databricks-meta-llama-3-1-70b-instruct") assert client.openai_api_base == "https://test_host/serving-endpoints" - assert client.model_name == "databricks-meta-llama-3-1-70b-instruct" \ No newline at end of file + assert client.model_name == "databricks-meta-llama-3-1-70b-instruct"