diff --git a/llama-index-integrations/embeddings/llama-index-embeddings-openai/llama_index/embeddings/openai/base.py b/llama-index-integrations/embeddings/llama-index-embeddings-openai/llama_index/embeddings/openai/base.py
index b6c29c6ac540c..d01a429b03b39 100644
--- a/llama-index-integrations/embeddings/llama-index-embeddings-openai/llama_index/embeddings/openai/base.py
+++ b/llama-index-integrations/embeddings/llama-index-embeddings-openai/llama_index/embeddings/openai/base.py
@@ -280,6 +280,7 @@ class OpenAIEmbedding(BaseEmbedding):
     _client: Optional[OpenAI] = PrivateAttr()
     _aclient: Optional[AsyncOpenAI] = PrivateAttr()
     _http_client: Optional[httpx.Client] = PrivateAttr()
+    _async_http_client: Optional[httpx.AsyncClient] = PrivateAttr()
 
     def __init__(
         self,
@@ -297,6 +298,7 @@ def __init__(
         callback_manager: Optional[CallbackManager] = None,
         default_headers: Optional[Dict[str, str]] = None,
         http_client: Optional[httpx.Client] = None,
+        async_http_client: Optional[httpx.AsyncClient] = None,
         num_workers: Optional[int] = None,
         **kwargs: Any,
     ) -> None:
@@ -339,6 +341,7 @@ def __init__(
         self._client = None
         self._aclient = None
         self._http_client = http_client
+        self._async_http_client = async_http_client
 
     def _resolve_credentials(
         self,
@@ -358,24 +361,24 @@ def _get_client(self) -> OpenAI:
 
     def _get_aclient(self) -> AsyncOpenAI:
         if not self.reuse_client:
-            return AsyncOpenAI(**self._get_credential_kwargs())
+            return AsyncOpenAI(**self._get_credential_kwargs(is_async=True))
 
         if self._aclient is None:
-            self._aclient = AsyncOpenAI(**self._get_credential_kwargs())
+            self._aclient = AsyncOpenAI(**self._get_credential_kwargs(is_async=True))
         return self._aclient
 
     @classmethod
     def class_name(cls) -> str:
         return "OpenAIEmbedding"
 
-    def _get_credential_kwargs(self) -> Dict[str, Any]:
+    def _get_credential_kwargs(self, is_async: bool = False) -> Dict[str, Any]:
         return {
             "api_key": self.api_key,
             "base_url": self.api_base,
             "max_retries": self.max_retries,
             "timeout": self.timeout,
             "default_headers": self.default_headers,
-            "http_client": self._http_client,
+            "http_client": self._async_http_client if is_async else self._http_client,
         }
 
     def _get_query_embedding(self, query: str) -> List[float]:
diff --git a/llama-index-integrations/embeddings/llama-index-embeddings-openai/pyproject.toml b/llama-index-integrations/embeddings/llama-index-embeddings-openai/pyproject.toml
index 417a1fbe20774..74b6846e8e2ac 100644
--- a/llama-index-integrations/embeddings/llama-index-embeddings-openai/pyproject.toml
+++ b/llama-index-integrations/embeddings/llama-index-embeddings-openai/pyproject.toml
@@ -27,7 +27,7 @@ exclude = ["**/BUILD"]
 license = "MIT"
 name = "llama-index-embeddings-openai"
 readme = "README.md"
-version = "0.1.10"
+version = "0.1.11"
 
 [tool.poetry.dependencies]
 python = ">=3.8.1,<4.0"