Skip to content

Commit

Permalink
Fix OpenAI Embedding async client bug
Browse files Browse the repository at this point in the history
  • Loading branch information
jials committed Jul 19, 2024
1 parent c817526 commit 1fb30e9
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,7 @@ class OpenAIEmbedding(BaseEmbedding):
_client: Optional[OpenAI] = PrivateAttr()
_aclient: Optional[AsyncOpenAI] = PrivateAttr()
_http_client: Optional[httpx.Client] = PrivateAttr()
_async_http_client: Optional[httpx.AsyncClient] = PrivateAttr()

def __init__(
self,
Expand All @@ -297,6 +298,7 @@ def __init__(
callback_manager: Optional[CallbackManager] = None,
default_headers: Optional[Dict[str, str]] = None,
http_client: Optional[httpx.Client] = None,
async_http_client: Optional[httpx.AsyncClient] = None,
num_workers: Optional[int] = None,
**kwargs: Any,
) -> None:
Expand Down Expand Up @@ -339,6 +341,7 @@ def __init__(
self._client = None
self._aclient = None
self._http_client = http_client
self._async_http_client = async_http_client

def _resolve_credentials(
self,
Expand All @@ -358,24 +361,24 @@ def _get_client(self) -> OpenAI:

def _get_aclient(self) -> AsyncOpenAI:
if not self.reuse_client:
return AsyncOpenAI(**self._get_credential_kwargs())
return AsyncOpenAI(**self._get_credential_kwargs(is_async=True))

if self._aclient is None:
self._aclient = AsyncOpenAI(**self._get_credential_kwargs())
self._aclient = AsyncOpenAI(**self._get_credential_kwargs(is_async=True))
return self._aclient

@classmethod
def class_name(cls) -> str:
return "OpenAIEmbedding"

def _get_credential_kwargs(self) -> Dict[str, Any]:
def _get_credential_kwargs(self, is_async: bool = False) -> Dict[str, Any]:
return {
"api_key": self.api_key,
"base_url": self.api_base,
"max_retries": self.max_retries,
"timeout": self.timeout,
"default_headers": self.default_headers,
"http_client": self._http_client,
"http_client": self._async_http_client if is_async else self._http_client,
}

def _get_query_embedding(self, query: str) -> List[float]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ exclude = ["**/BUILD"]
license = "MIT"
name = "llama-index-embeddings-openai"
readme = "README.md"
version = "0.1.10"
version = "0.1.11"

[tool.poetry.dependencies]
python = ">=3.8.1,<4.0"
Expand Down

0 comments on commit 1fb30e9

Please sign in to comment.