Skip to content

Commit

Permalink
Skip langchain test for less than 3.7
Browse files Browse the repository at this point in the history
  • Loading branch information
aravind-segu committed Oct 4, 2024
1 parent 2197174 commit 8e66831
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 4 deletions.
16 changes: 14 additions & 2 deletions databricks/sdk/mixins/open_ai_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,18 +11,30 @@ def get_open_ai_client(self):
except Exception:
raise ValueError("Unable to extract authorization token for OpenAI Client")

from openai import OpenAI
try:
from openai import OpenAI
except Exception:
raise ImportError(
"Open AI is not installed. Please install the Databricks SDK with the following command `pip isntall databricks-sdk[openai]`"
)

return OpenAI(base_url=self._api._cfg.host + "/serving-endpoints", api_key=token)

def get_langchain_chat_open_ai_client(self, model):
auth_headers = self._api._cfg.authenticate()

try:
from langchain_openai import ChatOpenAI
except Exception:
raise ImportError(
"Langchain Open AI is not installed. Please install the Databricks SDK with the following command `pip isntall databricks-sdk[openai]` and ensure you are using python>3.7"
)

try:
token = auth_headers["Authorization"][len("Bearer "):]
except Exception:
raise ValueError("Unable to extract authorization token for Langchain OpenAI Client")

from langchain_openai import ChatOpenAI
return ChatOpenAI(model=model,
openai_api_base=self._api._cfg.host + "/serving-endpoints",
openai_api_key=token)
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
"databricks-connect", "pytest-rerunfailures", "openai",
"langchain-openai"],
"notebook": ["ipython>=8,<9", "ipywidgets>=8,<9"],
"openai": ["openai", "langchain-openai"]},
"openai": ["openai", 'langchain-openai; python_version > "3.7"']},
author="Serge Smertin",
author_email="serge.smertin@databricks.com",
description="Databricks SDK for Python (Beta)",
Expand Down
7 changes: 6 additions & 1 deletion tests/test_open_ai_mixin.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
import sys

import pytest

from databricks.sdk.core import Config


Expand All @@ -13,6 +17,7 @@ def test_open_ai_client(monkeypatch):
assert client.api_key == "test_token"


@pytest.mark.skipif(sys.version_info <= (3, 7), reason="Requires Python > 3.7")
def test_langchain_open_ai_client(monkeypatch):
from databricks.sdk import WorkspaceClient

Expand All @@ -22,4 +27,4 @@ def test_langchain_open_ai_client(monkeypatch):
client = w.serving_endpoints.get_langchain_chat_open_ai_client("databricks-meta-llama-3-1-70b-instruct")

assert client.openai_api_base == "https://test_host/serving-endpoints"
assert client.model_name == "databricks-meta-llama-3-1-70b-instruct"
assert client.model_name == "databricks-meta-llama-3-1-70b-instruct"

0 comments on commit 8e66831

Please sign in to comment.