# [feat] PineconeGrpcFuture implements concurrent.futures.Future (#410)
## Problem

`GRPCIndex` has long had limited and poorly documented support for async
operations via the futures interface of the
[`grpc`](https://grpc.github.io/grpc/python/grpc.html#grpc.UnaryUnaryMultiCallable.future)
library. While recently trying to implement `query_namespaces` on top of
these futures, I discovered that the grpc futures implementation is
unfortunately not compatible with the `concurrent.futures` package in the
standard library. That makes them nearly useless for anything beyond trivial
cases, because the grpc library itself provides no utilities for
synchronization or for waiting on groups of futures.
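
For context, this is the kind of coordination the standard library already provides and that we want these futures to participate in. The snippet below is a plain `concurrent.futures` sketch (no Pinecone or grpc involved); `wait` and `as_completed` only understand `concurrent.futures.Future` instances, which is exactly what the bare grpc future is not.

```python
import time
from concurrent.futures import ThreadPoolExecutor, as_completed, wait

def slow_task(n: int) -> int:
    time.sleep(0.1 * n)
    return n

with ThreadPoolExecutor(max_workers=3) as pool:
    futures = [pool.submit(slow_task, n) for n in range(3)]

    # Block until everything finishes (or the timeout expires) ...
    done, not_done = wait(futures, timeout=5)

    # ... or consume results in completion order.
    results = [f.result() for f in as_completed(futures, timeout=10)]
```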

## Solution

A class called `PineconeGrpcFuture` was added in the past as a minimal
wrapper around the
[future](https://grpc.github.io/grpc/python/grpc.html#future-interfaces)
emitted by `grpc`. These future objects represent an asynchronous
computation and allow you to register callbacks with `add_done_callback`,
similar to calling `then()` on a JavaScript promise.
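
As a rough sketch of that callback style (assuming `future` here is a `PineconeGrpcFuture` returned by one of the `async_req=True` calls shown further down; the callback name is just for illustration):

```python
def on_done(future):
    # Invoked once the request resolves, whether it succeeded or failed.
    if future.exception() is not None:
        print(f"request failed: {future.exception()}")
    else:
        print(f"request succeeded: {future.result()}")

future.add_done_callback(on_done)
```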

The original purpose of our `PineconeGrpcFuture` wrapper class seems to
have been to implement some very basic error mapping, but in this diff I
extended the class to implement the `concurrent.futures.Future` interface.
This allows instances of `PineconeGrpcFuture` to be used with the
`concurrent.futures.as_completed` and `concurrent.futures.wait` utilities
(something the raw grpc future cannot do out of the box), which makes them
dramatically more ergonomic to work with.
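
A hedged usage sketch of what this unlocks, assuming an existing `GRPCIndex` instance named `index`; the ids and namespace are made up for illustration:

```python
from concurrent.futures import as_completed

# Kick off several fetches without blocking on each one.
futures = [
    index.fetch(ids=["1", "2"], namespace="example-ns", async_req=True),
    index.fetch(ids=["3", "4"], namespace="example-ns", async_req=True),
]

# Consume responses as they arrive, with a single overall timeout.
for future in as_completed(futures, timeout=10):
    response = future.result()
    print(response.namespace, list(response.vectors.keys()))
```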

For the unit tests of `PineconeGrpcFuture`, I had to lean heavily on
mocking because the various grpc classes are tightly coupled and can't
easily be set up without performing real network calls. Mock-heavy tests
alone don't give me huge confidence that everything works as expected, so
as a sanity check I also added integration test coverage for `upsert`,
`fetch`, and `delete` using `concurrent.futures.wait` and
`concurrent.futures.as_completed`.

## Type of Change

- [x] Bug fix (non-breaking change which fixes an issue)
- [x] New feature (non-breaking change which adds functionality)
- [ ] Breaking change (fix or feature that would cause existing
functionality to not work as expected)
- [x] This change requires a documentation update
- [ ] Infrastructure change (CI configs, etc)
- [ ] Non-code change (docs, etc)
- [ ] None of the above: (explain here)

## Test Plan

Added unit and integration tests
jhamon authored Nov 4, 2024
1 parent 36373a1 commit 463c30d
Showing 9 changed files with 885 additions and 32 deletions.

### `pinecone/grpc/__init__.py` (4 additions, 0 deletions)

```diff
@@ -47,20 +47,24 @@
 from .index_grpc import GRPCIndex
 from .pinecone import PineconeGRPC
 from .config import GRPCClientConfig
+from .future import PineconeGrpcFuture

 from pinecone.core.grpc.protos.vector_service_pb2 import (
     Vector as GRPCVector,
     SparseValues as GRPCSparseValues,
     Vector,
     SparseValues,
+    DeleteResponse as GRPCDeleteResponse,
 )

 __all__ = [
     "GRPCIndex",
     "PineconeGRPC",
+    "GRPCDeleteResponse",
     "GRPCClientConfig",
     "GRPCVector",
     "GRPCSparseValues",
     "Vector",
     "SparseValues",
+    "PineconeGrpcFuture",
 ]
```

### `pinecone/grpc/future.py` (73 additions, 21 deletions)

```diff
@@ -1,34 +1,86 @@
-from grpc._channel import _MultiThreadedRendezvous
+from concurrent.futures import Future as ConcurrentFuture
+from typing import Optional
+from grpc import Future as GrpcFuture, RpcError
 from pinecone.exceptions.exceptions import PineconeException


-class PineconeGrpcFuture:
-    def __init__(self, delegate):
-        self._delegate = delegate
+class PineconeGrpcFuture(ConcurrentFuture):
+    def __init__(
+        self, grpc_future: GrpcFuture, timeout: Optional[int] = None, result_transformer=None
+    ):
+        super().__init__()
+        self._grpc_future = grpc_future
+        self._result_transformer = result_transformer
+        if timeout is not None:
+            self._default_timeout = timeout  # seconds
+        else:
+            self._default_timeout = 5  # seconds

-    def cancel(self):
-        return self._delegate.cancel()
+        # Sync initial state, in case the gRPC future is already done
+        self._sync_state(self._grpc_future)

-    def cancelled(self):
-        return self._delegate.cancelled()
+        # Add callback to subscribe to updates from the gRPC future
+        self._grpc_future.add_done_callback(self._sync_state)

-    def running(self):
-        return self._delegate.running()
+    @property
+    def grpc_future(self):
+        return self._grpc_future

-    def done(self):
-        return self._delegate.done()
+    def _sync_state(self, grpc_future):
+        if self.done():
+            return

-    def add_done_callback(self, fun):
-        return self._delegate.add_done_callback(fun)
+        if grpc_future.cancelled():
+            self.cancel()
+        elif grpc_future.exception(timeout=self._default_timeout):
+            self.set_exception(grpc_future.exception())
+        elif grpc_future.done():
+            try:
+                result = grpc_future.result(timeout=self._default_timeout)
+                self.set_result(result)
+            except Exception as e:
+                self.set_exception(e)
+        elif grpc_future.running():
+            self.set_running_or_notify_cancel()

-    def result(self, timeout=None):
-        try:
-            return self._delegate.result(timeout=timeout)
-        except _MultiThreadedRendezvous as e:
-            raise PineconeException(e._state.debug_error_string) from e
+    def set_result(self, result):
+        if self._result_transformer:
+            result = self._result_transformer(result)
+        return super().set_result(result)
+
+    def cancel(self):
+        self._grpc_future.cancel()
+        return super().cancel()

     def exception(self, timeout=None):
-        return self._delegate.exception(timeout=timeout)
+        exception = super().exception(timeout=self._timeout(timeout))
+        if isinstance(exception, RpcError):
+            return self._wrap_rpc_exception(exception)
+        return exception

     def traceback(self, timeout=None):
-        return self._delegate.traceback(timeout=timeout)
+        # This is not part of the ConcurrentFuture interface, but keeping it for
+        # backward compatibility
+        return self._grpc_future.traceback(timeout=self._timeout(timeout))
+
+    def result(self, timeout=None):
+        try:
+            return super().result(timeout=self._timeout(timeout))
+        except RpcError as e:
+            raise self._wrap_rpc_exception(e) from e
+
+    def _timeout(self, timeout: Optional[int] = None) -> int:
+        if timeout is not None:
+            return timeout
+        else:
+            return self._default_timeout
+
+    def _wrap_rpc_exception(self, e):
+        if e._state and e._state.debug_error_string:
+            return PineconeException(e._state.debug_error_string)
+        else:
+            return PineconeException("Unknown GRPC error")
+
+    def __del__(self):
+        self._grpc_future.cancel()
+        self = None  # release the reference to the grpc future
```

### `pinecone/grpc/index_grpc.py` (13 additions, 5 deletions)

```diff
@@ -282,8 +282,12 @@ def delete(
         return self.runner.run(self.stub.Delete, request, timeout=timeout)

     def fetch(
-        self, ids: Optional[List[str]], namespace: Optional[str] = None, **kwargs
-    ) -> FetchResponse:
+        self,
+        ids: Optional[List[str]],
+        namespace: Optional[str] = None,
+        async_req: Optional[bool] = False,
+        **kwargs,
+    ) -> Union[FetchResponse, PineconeGrpcFuture]:
         """
         The fetch operation looks up and returns vectors, by ID, from a single namespace.
         The returned vectors include the vector data and/or metadata.
@@ -304,9 +308,13 @@ def fetch(
         args_dict = self._parse_non_empty_args([("namespace", namespace)])

         request = FetchRequest(ids=ids, **args_dict, **kwargs)
-        response = self.runner.run(self.stub.Fetch, request, timeout=timeout)
-        json_response = json_format.MessageToDict(response)
-        return parse_fetch_response(json_response)
+
+        if async_req:
+            future = self.runner.run(self.stub.Fetch.future, request, timeout=timeout)
+            return PineconeGrpcFuture(future, result_transformer=parse_fetch_response)
+        else:
+            response = self.runner.run(self.stub.Fetch, request, timeout=timeout)
+            return parse_fetch_response(response)

     def query(
         self,
```

### `pinecone/grpc/utils.py` (9 additions, 4 deletions)

```diff
@@ -1,4 +1,7 @@
 from typing import Optional
+from google.protobuf import json_format
+from google.protobuf.message import Message
+
 import uuid

 from pinecone.core.openapi.data.models import (
@@ -35,10 +38,12 @@ def parse_sparse_values(sparse_values: dict):
     )


-def parse_fetch_response(response: dict):
+def parse_fetch_response(response: Message):
+    json_response = json_format.MessageToDict(response)
+
     vd = {}
-    vectors = response.get("vectors", {})
-    namespace = response.get("namespace", "")
+    vectors = json_response.get("vectors", {})
+    namespace = json_response.get("namespace", "")

     for id, vec in vectors.items():
         vd[id] = _Vector(
@@ -52,7 +57,7 @@ def parse_fetch_response(response: dict):
     return FetchResponse(
         vectors=vd,
         namespace=namespace,
-        usage=parse_usage(response.get("usage", {})),
+        usage=parse_usage(json_response.get("usage", {})),
         _check_type=False,
     )
```

### `tests/integration/data/test_delete_future.py` (new file, 34 additions)

```python
import os
import pytest
from pinecone import Vector
from ..helpers import poll_stats_for_namespace, random_string

if os.environ.get("USE_GRPC") == "true":
    from pinecone.grpc import GRPCDeleteResponse


class TestDeleteFuture:
    @pytest.mark.skipif(
        os.getenv("USE_GRPC") != "true", reason="PineconeGrpcFutures only returned from grpc client"
    )
    def test_delete_future(self, idx):
        namespace = random_string(10)

        idx.upsert(
            vectors=[
                Vector(id="id1", values=[0.1, 0.2]),
                Vector(id="id2", values=[0.1, 0.2]),
                Vector(id="id3", values=[0.1, 0.2]),
            ],
            namespace=namespace,
        )
        poll_stats_for_namespace(idx, namespace, 3)

        delete_one = idx.delete(ids=["id1"], namespace=namespace, async_req=True)
        delete_namespace = idx.delete(namespace=namespace, delete_all=True, async_req=True)

        from concurrent.futures import as_completed

        for future in as_completed([delete_one, delete_namespace], timeout=10):
            resp = future.result()
            assert isinstance(resp, GRPCDeleteResponse)
```

### `tests/integration/data/test_fetch_future.py` (new file, 101 additions)

```python
import os
import pytest

if os.environ.get("USE_GRPC") == "true":
    from pinecone.grpc import PineconeGrpcFuture


@pytest.mark.skipif(
    os.getenv("USE_GRPC") != "true", reason="PineconeGrpcFutures only returned from grpc client"
)
class TestFetchFuture:
    def setup_method(self):
        self.expected_dimension = 2

    def test_fetch_multiple_by_id(self, idx, namespace):
        target_namespace = namespace

        results = idx.fetch(ids=["1", "2", "4"], namespace=target_namespace, async_req=True)
        assert isinstance(results, PineconeGrpcFuture)

        from concurrent.futures import wait, FIRST_COMPLETED

        done, _ = wait([results], return_when=FIRST_COMPLETED)

        results = done.pop().result()
        assert results.usage is not None
        assert results.usage["read_units"] is not None
        assert results.usage["read_units"] > 0

        assert results.namespace == target_namespace
        assert len(results.vectors) == 3
        assert results.vectors["1"].id == "1"
        assert results.vectors["2"].id == "2"
        # Metadata included, if set
        assert results.vectors["1"].metadata is None
        assert results.vectors["2"].metadata is None
        assert results.vectors["4"].metadata is not None
        assert results.vectors["4"].metadata["genre"] == "action"
        assert results.vectors["4"].metadata["runtime"] == 120
        # Values included
        assert results.vectors["1"].values is not None
        assert len(results.vectors["1"].values) == self.expected_dimension

    def test_fetch_single_by_id(self, idx, namespace):
        target_namespace = namespace

        future = idx.fetch(ids=["1"], namespace=target_namespace, async_req=True)

        from concurrent.futures import wait, FIRST_COMPLETED

        done, _ = wait([future], return_when=FIRST_COMPLETED)
        results = done.pop().result()

        assert results.namespace == target_namespace
        assert len(results.vectors) == 1
        assert results.vectors["1"].id == "1"
        assert results.vectors["1"].metadata is None
        assert results.vectors["1"].values is not None
        assert len(results.vectors["1"].values) == self.expected_dimension

    def test_fetch_nonexistent_id(self, idx, namespace):
        target_namespace = namespace

        # Fetch id that is missing
        future = idx.fetch(ids=["100"], namespace=target_namespace, async_req=True)

        from concurrent.futures import wait, FIRST_COMPLETED

        done, _ = wait([future], return_when=FIRST_COMPLETED)
        results = done.pop().result()

        assert results.namespace == target_namespace
        assert len(results.vectors) == 0

    def test_fetch_nonexistent_namespace(self, idx):
        target_namespace = "nonexistent-namespace"

        # Fetch from namespace with no vectors
        future = idx.fetch(ids=["1"], namespace=target_namespace, async_req=True)

        from concurrent.futures import wait, FIRST_COMPLETED

        done, _ = wait([future], return_when=FIRST_COMPLETED)
        results = done.pop().result()

        assert results.namespace == target_namespace
        assert len(results.vectors) == 0

    def test_fetch_unspecified_namespace(self, idx):
        # Fetch without specifying namespace gives default namespace results
        future = idx.fetch(ids=["1", "4"], async_req=True)

        from concurrent.futures import wait, FIRST_COMPLETED

        done, _ = wait([future], return_when=FIRST_COMPLETED)
        results = done.pop().result()

        assert results.namespace == ""
        assert results.vectors["1"].id == "1"
        assert results.vectors["1"].values is not None
        assert results.vectors["4"].metadata is not None
```