Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions pinecone/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@

import logging

# Raise an exception if the user is attempting to use the SDK with deprecated plugins
# installed in their project.
# Raise an exception if the user is attempting to use the SDK with
# deprecated plugins installed in their project.
check_for_deprecated_plugins()

# Silence annoying log messages from the plugin interface
Expand Down
21 changes: 2 additions & 19 deletions pinecone/control/pinecone.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from pinecone.openapi_support.api_client import ApiClient


from pinecone.utils import normalize_host, setup_openapi_client, build_plugin_setup_client
from pinecone.utils import normalize_host, setup_openapi_client, PluginAware
from pinecone.core.openapi.db_control import API_VERSION
from pinecone.models import (
ServerlessSpec,
Expand All @@ -38,13 +38,11 @@
from .types import CreateIndexForModelEmbedTypedDict
from .request_factory import PineconeDBControlRequestFactory

from pinecone_plugin_interface import load_and_install as install_plugins

logger = logging.getLogger(__name__)
""" @private """


class Pinecone(PineconeDBControlInterface):
class Pinecone(PineconeDBControlInterface, PluginAware):
"""
A client for interacting with Pinecone's vector database.

Expand Down Expand Up @@ -113,21 +111,6 @@ def inference(self):
self._inference = _Inference(config=self.config, openapi_config=self.openapi_config)
return self._inference

def load_plugins(self):
"""@private"""
try:
# I don't expect this to ever throw, but wrapping this in a
# try block just in case to make sure a bad plugin doesn't
# halt client initialization.
openapi_client_builder = build_plugin_setup_client(
config=self.config,
openapi_config=self.openapi_config,
pool_threads=self.pool_threads,
)
install_plugins(self, openapi_client_builder)
except Exception as e:
logger.error(f"Error loading plugins: {e}")

def create_index(
self,
name: str,
Expand Down
21 changes: 1 addition & 20 deletions pinecone/control/pinecone_asyncio.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from pinecone.core.openapi.db_control.api.manage_indexes_api import AsyncioManageIndexesApi
from pinecone.openapi_support import AsyncioApiClient

from pinecone.utils import normalize_host, setup_openapi_client, build_plugin_setup_client
from pinecone.utils import normalize_host, setup_openapi_client
from pinecone.core.openapi.db_control import API_VERSION
from pinecone.models import (
ServerlessSpec,
Expand All @@ -36,8 +36,6 @@
from .request_factory import PineconeDBControlRequestFactory
from .pinecone_interface_asyncio import PineconeAsyncioDBControlInterface

from pinecone_plugin_interface import load_and_install as install_plugins

logger = logging.getLogger(__name__)
""" @private """

Expand Down Expand Up @@ -104,8 +102,6 @@ def __init__(
self.index_host_store = IndexHostStore()
""" @private """

self.load_plugins()

async def __aenter__(self):
return self

Expand All @@ -122,21 +118,6 @@ def inference(self):
self._inference = _AsyncioInference(api_client=self.index_api.api_client)
return self._inference

def load_plugins(self):
"""@private"""
try:
# I don't expect this to ever throw, but wrapping this in a
# try block just in case to make sure a bad plugin doesn't
# halt client initialization.
openapi_client_builder = build_plugin_setup_client(
config=self.config,
openapi_config=self.openapi_config,
pool_threads=self.pool_threads,
)
install_plugins(self, openapi_client_builder)
except Exception as e:
logger.error(f"Error loading plugins: {e}")

async def create_index(
self,
name: str,
Expand Down
21 changes: 2 additions & 19 deletions pinecone/data/features/inference/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,8 @@
from pinecone.core.openapi.inference.apis import InferenceApi
from pinecone.core.openapi.inference.models import EmbeddingsList, RerankResult
from pinecone.core.openapi.inference import API_VERSION
from pinecone.utils import setup_openapi_client, build_plugin_setup_client
from pinecone.utils import setup_openapi_client, PluginAware

from pinecone_plugin_interface import load_and_install as install_plugins

from .inference_request_builder import (
InferenceRequestBuilder,
Expand All @@ -18,7 +17,7 @@
logger = logging.getLogger(__name__)


class Inference:
class Inference(PluginAware):
"""
The `Inference` class configures and uses the Pinecone Inference API to generate embeddings and
rank documents.
Expand All @@ -43,24 +42,8 @@ def __init__(self, config, openapi_config, **kwargs):
pool_threads=kwargs.get("pool_threads", 1),
api_version=API_VERSION,
)

self.load_plugins()

def load_plugins(self):
"""@private"""
try:
# I don't expect this to ever throw, but wrapping this in a
# try block just in case to make sure a bad plugin doesn't
# halt client initialization.
openapi_client_builder = build_plugin_setup_client(
config=self.config,
openapi_config=self.openapi_config,
pool_threads=self.pool_threads,
)
install_plugins(self, openapi_client_builder)
except Exception as e:
logger.error(f"Error loading plugins: {e}")

def embed(
self,
model: Union[EmbedModelEnum, str],
Expand Down
30 changes: 7 additions & 23 deletions pinecone/data/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,16 +33,15 @@
from ..utils import (
setup_openapi_client,
parse_non_empty_args,
build_plugin_setup_client,
validate_and_convert_errors,
PluginAware,
)
from .query_results_aggregator import QueryResultsAggregator, QueryNamespacesResults
from pinecone.openapi_support import OPENAPI_ENDPOINT_PARAMS

from multiprocessing.pool import ApplyResult
from concurrent.futures import as_completed

from pinecone_plugin_interface import load_and_install as install_plugins

logger = logging.getLogger(__name__)

Expand All @@ -52,7 +51,7 @@ def parse_query_response(response: QueryResponse):
return response


class Index(IndexInterface, ImportFeatureMixin):
class Index(IndexInterface, ImportFeatureMixin, PluginAware):
"""
A client for interacting with a Pinecone index via REST API.
For improved performance, use the Pinecone GRPC index client.
Expand All @@ -70,17 +69,17 @@ def __init__(
self.config = ConfigBuilder.build(
api_key=api_key, host=host, additional_headers=additional_headers, **kwargs
)
self._openapi_config = ConfigBuilder.build_openapi_config(self.config, openapi_config)
self._pool_threads = pool_threads
self.openapi_config = ConfigBuilder.build_openapi_config(self.config, openapi_config)
self.pool_threads = pool_threads

if kwargs.get("connection_pool_maxsize", None):
self._openapi_config.connection_pool_maxsize = kwargs.get("connection_pool_maxsize")
self.openapi_config.connection_pool_maxsize = kwargs.get("connection_pool_maxsize")

self._vector_api = setup_openapi_client(
api_client_klass=ApiClient,
api_klass=VectorOperationsApi,
config=self.config,
openapi_config=self._openapi_config,
openapi_config=self.openapi_config,
pool_threads=pool_threads,
api_version=API_VERSION,
)
Expand All @@ -90,22 +89,7 @@ def __init__(
# Pass the same api_client to the ImportFeatureMixin
super().__init__(api_client=self._api_client)

self._load_plugins()

def _load_plugins(self):
"""@private"""
try:
# I don't expect this to ever throw, but wrapping this in a
# try block just in case to make sure a bad plugin doesn't
# halt client initialization.
openapi_client_builder = build_plugin_setup_client(
config=self.config,
openapi_config=self._openapi_config,
pool_threads=self._pool_threads,
)
install_plugins(self, openapi_client_builder)
except Exception as e:
logger.error(f"Error loading plugins in Index: {e}")
self.load_plugins()

def _openapi_kwargs(self, kwargs):
return {k: v for k, v in kwargs.items() if k in OPENAPI_ENDPOINT_PARAMS}
Expand Down
26 changes: 2 additions & 24 deletions pinecone/data/index_asyncio.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,7 @@
SearchRecordsResponse,
)

from ..utils import (
setup_openapi_client,
parse_non_empty_args,
build_plugin_setup_client,
validate_and_convert_errors,
)
from ..utils import setup_openapi_client, parse_non_empty_args, validate_and_convert_errors
from .types import (
SparseVectorTypedDict,
VectorTypedDict,
Expand All @@ -47,7 +42,7 @@
from .vector_factory import VectorFactory
from .query_results_aggregator import QueryNamespacesResults
from .features.bulk_import import ImportFeatureMixinAsyncio
from pinecone_plugin_interface import load_and_install as install_plugins


logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -107,23 +102,6 @@ def __init__(
# This is important for async context management to work correctly
super().__init__(api_client=self._api_client)

self._load_plugins()

def _load_plugins(self):
"""@private"""
try:
# I don't expect this to ever throw, but wrapping this in a
# try block just in case to make sure a bad plugin doesn't
# halt client initialization.
openapi_client_builder = build_plugin_setup_client(
config=self.config,
openapi_config=self._openapi_config,
pool_threads=self._pool_threads,
)
install_plugins(self, openapi_client_builder)
except Exception as e:
logger.error(f"Error loading plugins in Index: {e}")

async def __aenter__(self):
return self

Expand Down
18 changes: 0 additions & 18 deletions pinecone/grpc/base.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from abc import ABC, abstractmethod
from typing import Optional

import logging
import grpc
from grpc._channel import Channel

Expand All @@ -12,10 +11,6 @@
from .grpc_runner import GrpcRunner
from concurrent.futures import ThreadPoolExecutor

from pinecone_plugin_interface import load_and_install as install_plugins

_logger = logging.getLogger(__name__)


class GRPCIndexBase(ABC):
"""
Expand Down Expand Up @@ -48,19 +43,6 @@ def __init__(
self._channel = channel or self._gen_channel()
self.stub = self.stub_class(self._channel)

self._load_plugins()

def _load_plugins(self):
"""@private"""
try:

def stub_openapi_client_builder(*args, **kwargs):
pass

install_plugins(self, stub_openapi_client_builder)
except Exception as e:
_logger.error(f"Error loading plugins in GRPCIndex: {e}")

@property
def threadpool_executor(self):
if self._pool is None:
Expand Down
2 changes: 2 additions & 0 deletions pinecone/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,10 @@
from .docslinks import docslinks
from .repr_overrides import install_json_repr_override
from .error_handling import validate_and_convert_errors
from .plugin_aware import PluginAware

__all__ = [
"PluginAware",
"check_kwargs",
"__version__",
"get_user_agent",
Expand Down
22 changes: 22 additions & 0 deletions pinecone/utils/plugin_aware.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
from .setup_openapi_client import build_plugin_setup_client
from pinecone_plugin_interface import load_and_install as install_plugins
import logging

logger = logging.getLogger(__name__)


class PluginAware:
def load_plugins(self):
"""@private"""
try:
# I don't expect this to ever throw, but wrapping this in a
# try block just in case to make sure a bad plugin doesn't
# halt client initialization.
openapi_client_builder = build_plugin_setup_client(
config=self.config,
openapi_config=self.openapi_config,
pool_threads=self.pool_threads,
)
install_plugins(self, openapi_client_builder)
except Exception as e:
logger.error(f"Error loading plugins: {e}")
13 changes: 3 additions & 10 deletions tests/unit/test_control.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
ServerlessSpec as ServerlessSpecOpenApi,
IndexModelStatus,
)
from pinecone.utils import PluginAware

from pinecone.core.openapi.db_control.api.manage_indexes_api import ManageIndexesApi

import time
Expand Down Expand Up @@ -78,19 +80,10 @@ def index_list_response():

class TestControl:
def test_plugins_are_installed(self):
with patch("pinecone.control.pinecone.install_plugins") as mock_install_plugins:
with patch.object(PluginAware, "load_plugins") as mock_install_plugins:
Pinecone(api_key="asdf")
mock_install_plugins.assert_called_once()

def test_bad_plugin_doesnt_break_sdk(self):
with patch(
"pinecone.control.pinecone.install_plugins", side_effect=Exception("bad plugin")
):
try:
Pinecone(api_key="asdf")
except Exception as e:
assert False, f"Unexpected exception: {e}"

def test_default_host(self):
p = Pinecone(api_key="123-456-789")
assert p.index_api.api_client.configuration.host == "https://api.pinecone.io"
Expand Down
Loading