Skip to content

Stream client prototype #2: asyncio-heavy #6145

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
209 changes: 209 additions & 0 deletions cirq-google/cirq_google/engine/engine_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,13 @@
# limitations under the License.

import asyncio
from asyncio.log import logger
import datetime
import sys
import threading
from typing import (
AsyncIterable,
AsyncIterator,
Awaitable,
Callable,
Dict,
Expand All @@ -42,6 +44,8 @@

_M = TypeVar('_M', bound=proto.Message)
_R = TypeVar('_R')
JobPath = str
MessageId = str


class EngineException(Exception):
Expand Down Expand Up @@ -95,6 +99,146 @@ def instance(cls):
return cls._instance


class ResponseDemux:
    """An event demultiplexer for QuantumRunStreamResponses, as part of the async reactor pattern.

    Every subscribed request is stamped with a unique message ID and paired with an
    asyncio future. A response is routed to the future registered under the response's
    message ID, while a stream-level failure is fanned out to every outstanding future.
    """

    def __init__(self):
        # Outstanding futures, keyed by the message ID of the request awaiting a response.
        self._subscribers: Dict[MessageId, asyncio.Future] = {}
        # Counter whose string form is the next message ID to hand out.
        self._next_available_message_id = 0

    def subscribe(self, request: quantum.QuantumRunStreamRequest) -> asyncio.Future:
        """Assigns a fresh message ID to `request` and returns a future for its response.

        Assumes the message ID has not been set.

        Args:
            request: The stream request to track. Its `message_id` is overwritten.

        Returns:
            A future that `publish()` completes with the matching response, or that
            `publish_exception()` fails.

        Raises:
            RuntimeError: If no asyncio event loop is running.
        """
        # Create the future first: if there is no running loop this raises before the
        # request is stamped with an ID that would never be registered.
        response_future: asyncio.Future = asyncio.get_running_loop().create_future()
        request.message_id = str(self._next_available_message_id)
        self._next_available_message_id += 1
        self._subscribers[request.message_id] = response_future
        return response_future

    def unsubscribe(self, request: quantum.QuantumRunStreamRequest) -> None:
        """Stops tracking `request`; a no-op if its message ID is not registered."""
        self._subscribers.pop(request.message_id, None)

    def publish(self, response: quantum.QuantumRunStreamResponse) -> None:
        """Delivers `response` to the future registered under its message ID.

        The subscriber is removed once matched. Responses with an unknown message ID are
        dropped, and a future that is already done (e.g. cancelled) is left untouched.
        """
        future = self._subscribers.pop(response.message_id, None)
        if future is not None and not future.done():
            future.set_result(response)

    def publish_exception(self, exception: BaseException) -> None:
        """Publishes an exception to all outstanding futures.

        `StreamManager._manage_stream` forwards whatever it catches, including
        `asyncio.CancelledError`, so any `BaseException` is accepted here.
        All subscribers are dropped afterwards.
        """
        outstanding, self._subscribers = self._subscribers, {}
        for future in outstanding.values():
            if not future.done():
                future.set_exception(exception)


class StreamManager:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add unit tests for the public API of this class.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Definitely. For this PR I focused on sketching a prototype to get high-level design review. The real implementations will include detailed unit tests.

def __init__(self, grpc_client: quantum.QuantumEngineServiceAsyncClient):
    """Stores the gRPC client and creates the (not yet started) stream state.

    Args:
        grpc_client: The async client whose `quantum_run_stream()` and
            `cancel_quantum_job()` methods this manager calls.
    """
    # TODO consider making the scope of response futures local to the relevant tasks rather than
    # all of StreamManager.
    self._response_demux = ResponseDemux()
    # Stays None until the first `send()` submits the `_manage_stream` loop.
    self._manage_stream_loop_future: Optional[duet.AwaitableFuture] = None
    # Requests waiting to be written to the stream; drained by `_request_iterator()`.
    # NOTE(review): the queue is created here but awaited inside coroutines submitted to
    # `self._executor` - confirm it is not bound to a different event loop on the
    # supported Python versions.
    self._request_queue: asyncio.Queue = asyncio.Queue()
    self._grpc_client = grpc_client

@property
def _executor(self) -> AsyncioExecutor:
    """The shared `AsyncioExecutor` instance on which this manager submits its coroutines."""
    # We must re-use a single Executor due to multi-threading issues in gRPC
    # clients: https://github.com/grpc/grpc/issues/25364.
    shared_executor = AsyncioExecutor.instance()
    return shared_executor

async def _request_iterator(self) -> AsyncIterator[quantum.QuantumRunStreamRequest]:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would suggest not making this a property, since that makes it seem like there is one request iterator, but actually we get a new iterator each time self._request_iterator is accessed. I think it'd be clearer to make this a method and document that we get a new iterator each time. It occurs to me that it might be better to make this a static method or module-level helper function and have it take the queue as a parameter, rather than using attributes on self, because the use of this iterator is actually local to one iteration of the loop in _manage_stream.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure I follow the last part - why would it be better to use a static method or module-level function if it's local to one iteration? Isn't a new iterator created for each quantum_run_stream in both cases?

"""The request iterator for quantum_run_stream().

Every call of this method generates a new iterator."""
while True:
yield await self._request_queue.get()

async def _manage_stream(self):
"""Keeps the stream alive and routes responses to the appropriate request handler"""
while True:
try:
# TODO specify stream timeout below with exponential backoff
response_iterable = await self._grpc_client.quantum_run_stream(
self._request_iterator()
)
async for response in response_iterable:
logger.warning('publishing response to demux')
self._response_demux.publish(response)
except BaseException as e:
# TODO Close the request iterator to close the existing stream.
self._response_demux.publish_exception(e) # Raise to all request tasks
if isinstance(e, asyncio.CancelledError):
break

async def _make_request(
self, project_name: str, program: quantum.QuantumProgram, job: quantum.QuantumJob
) -> quantum.QuantumResult:
"""This method is executed in a separate asyncio Task for each request."""
current_request = quantum.QuantumRunStreamRequest(
parent=project_name,
create_quantum_program_and_job=quantum.CreateQuantumProgramAndJobRequest(
parent=project_name, quantum_program=program, quantum_job=job
),
)
get_result_request = quantum.QuantumRunStreamRequest(
parent=project_name, get_quantum_result=quantum.GetQuantumResultRequest(parent=job.name)
)

response_future: Optional[asyncio.Future] = None
response: Optional[quantum.QuantumRunStreamResponse] = None
while response is None:
try:
logger.warn('Making request')
response_future = self._response_demux.subscribe(current_request)
await self._request_queue.put(current_request)
response = await response_future
logger.warning('Got response')

except GoogleAPICallError:
# TODO how to distinguish between program not found vs job not found?
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These are different Codes in a StreamError.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice, thanks!

# TODO Send a CreateProgramAndJobRequest or CreateJobRequest if either program or
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure what this means. e.g., a "job doesn't exist" should only happen during a GetQuantumResultRequest, but that value won't have enough context to create a new job.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

GetQuantumResultRequest is only sent when a previous CreateProgramAndJobRequest failed for some reason (most likely a stream failure) within this asyncio task, so this task does have the full program and job context to recreate the job.

Getting a "job doesn't exist" means either the "create program" part or the "create job" part failed. We could try recreating just the job first by sending a CreateJobRequest, and if that fails with a "program doesn't exist" error, we can try the full CreateProgramAndJobRequest again.

# job doesn't exist.
# TODO add exponential backoff
logger.warn('Got GoogleAPICallError')
self._response_demux.unsubscribe(current_request)
current_request = get_result_request
continue

# Either when this request is canceled or the _manage_stream() loop is canceled.
except asyncio.CancelledError:
if response_future is not None:
response_future.cancel()
self._response_demux.unsubscribe(current_request)
await self._cancel(job.name)
return quantum.QuantumResult()

if response.result is not None:
logger.warning('Got result')
return response.result
# TODO handle QuantumJob response and retryable StreamError.

async def _cancel(self, job_name: str) -> None:
    """Sends a `CancelQuantumJobRequest` for `job_name` through the gRPC client."""
    cancel_request = quantum.CancelQuantumJobRequest(name=job_name)
    await self._grpc_client.cancel_quantum_job(cancel_request)

def send(
    self, project_name: str, program: quantum.QuantumProgram, job: quantum.QuantumJob
) -> duet.AwaitableFuture[quantum.QuantumResult]:
    """Submits a program and job for execution over the stream.

    If no stream loop future is recorded yet, the `_manage_stream` loop is submitted to
    the executor first and its future is remembered.

    Args:
        project_name: Passed through to `_make_request`.
        program: The program to create and run.
        job: The job to create and run.

    Returns:
        The executor's future for the `_make_request` coroutine handling this job.
    """
    stream_not_started = self._manage_stream_loop_future is None
    if stream_not_started:
        self._manage_stream_loop_future = self._executor.submit(self._manage_stream)
    request_args = (project_name, program, job)
    return self._executor.submit(self._make_request, *request_args)

def stop(self) -> duet.AwaitableFuture[None]:
    """Stops and resets the stream manager.

    Cancels the recorded `_manage_stream` loop future, if any, and forgets it so that
    the next `send()` submits a fresh stream loop instead of assuming one is still
    running.

    Returns:
        The future of the stream loop that was just cancelled, or an already-completed
        future when no stream loop is recorded.
    """
    loop_future = self._manage_stream_loop_future
    if loop_future is None:
        return duet.completed_future(None)
    # Clear the reference before cancelling; otherwise `send()` would keep seeing a
    # non-None future and never restart the loop, leaving later requests unanswered.
    self._manage_stream_loop_future = None
    loop_future.cancel()
    return loop_future


class EngineClient:
"""Client for the Quantum Engine API handling protos and gRPC client.

Expand Down Expand Up @@ -148,6 +292,10 @@ async def make_client():

return self._executor.submit(make_client).result()

@cached_property
def _stream_manager(self) -> StreamManager:
    """The `StreamManager` built around this client's `grpc_client` (cached by the decorator)."""
    manager = StreamManager(self.grpc_client)
    return manager

async def _send_request_async(self, func: Callable[[_M], Awaitable[_R]], request: _M) -> _R:
"""Sends a request by invoking an asyncio callable."""
return await self._run_retry_async(func, request)
Expand Down Expand Up @@ -740,6 +888,56 @@ async def get_job_results_async(

get_job_results = duet.sync(get_job_results_async)

def run_job_over_stream(
self,
project_id: str,
program_id: str,
code: any_pb2.Any,
job_id: Optional[str], # TODO make this non-optional.
processor_ids: Sequence[str],
run_context: any_pb2.Any,
priority: Optional[int] = None,
description: Optional[str] = None,
labels: Optional[Dict[str, str]] = None,
) -> duet.AwaitableFuture[quantum.QuantumResult]:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

_manage_stream first receives these responses through the response iterable, which is then sent through the ResponseDemux to the _make_request task handling their corresponding jobs. Where response.result is checked inside _make_request, there will be additional clauses for failed jobs and stream errors.

For retryable errors, the appropriate stream request will be set as current_request and the loop continues, so that the request is sent over the stream.

Non-retryable errors are raised, which is propagated to the duet.AwaitableFuture for this task. The end user awaiting on this future will see the exception.

# Check program to run and program parameters.
if priority and not 0 <= priority < 1000:
raise ValueError('priority must be between 0 and 1000')

project_name = _project_name(project_id)

program_name = _program_name_from_ids(project_id, program_id) if program_id else ''
program = quantum.QuantumProgram(name=program_name, code=code)
if description:
program.description = description
if labels:
program.labels.update(labels)

job_name = _job_name_from_ids(project_id, program_id, job_id) if job_id else ''
job = quantum.QuantumJob(
name=job_name,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Rather than setting the name to '' (an implementation detail of the proto) if not job_id, WDYT of handling this like the other optional fields below and assigning only if it's set?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Like you mentioned in the other comment, I also think it's a good idea to require job_id to be set, so I think this comment no longer applies

scheduling_config=quantum.SchedulingConfig(
processor_selector=quantum.SchedulingConfig.ProcessorSelector(
processor_names=[
_processor_name_from_ids(project_id, processor_id)
for processor_id in processor_ids
]
)
),
run_context=run_context,
)
if priority:
job.scheduling_config.priority = priority
if description:
job.description = description
if labels:
job.labels.update(labels)

return self._stream_manager.send(project_name, program, job)

def stop_stream(self) -> duet.AwaitableFuture[None]:
    """Stops the stream manager and returns the future produced by its `stop()` method."""
    stream_manager = self._stream_manager
    return stream_manager.stop()

async def list_processors_async(self, project_id: str) -> List[quantum.QuantumProcessor]:
"""Returns a list of Processors that the user has visibility to in the
current Engine project. The names of these processors are used to
Expand Down Expand Up @@ -1133,3 +1331,14 @@ def _date_or_time_to_filter_expr(param_name: str, param: Union[datetime.datetime
f"type {type(param)}. Supported types: datetime.datetime and"
f"datetime.date"
)


def _get_job_path_from_stream_request(request: quantum.QuantumRunStreamRequest) -> str:
if 'create_quantum_program_and_job' in request:
return request.create_quantum_program_and_job.quantum_job.name
elif 'create_quantum_job' in request:
return request.create_quantum_job.quantum_job.name
elif 'get_quantum_result' in request:
return request.get_quantum_result.parent
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As I understand the use of "job path" these would have to be unique across requests, since we use the job path as a key to keep track of futures, so if there is a collision we could lose a future and the client will hang waiting for it to complete. Is "parent" here unique across requests? What if a user makes two get requests for the same result?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

"Parent" here is also the full job path.

Thanks for catching the future collision issue. One idea is to reuse the existing future in ResponseDemux if the subscriber already exists, and still increment the message ID.

else:
raise ValueError(f'Unrecognized request type in request: {request}')
117 changes: 117 additions & 0 deletions cirq-google/cirq_google/engine/engine_client_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for EngineClient."""
from typing import AsyncIterable, AsyncIterator, Awaitable, List

import asyncio
from asyncio.log import logger
import datetime
from unittest import mock

Expand All @@ -25,6 +28,7 @@

from cirq_google.engine.engine_client import EngineClient, EngineException
from cirq_google.cloud import quantum
from cirq_google.cloud.quantum_v1alpha1.types import engine


def setup_mock_(client_constructor):
Expand All @@ -33,6 +37,48 @@ def setup_mock_(client_constructor):
return grpc_client


def setup_fake_quantum_run_stream_client(client_constructor, responses_and_exceptions):
    """Makes `client_constructor` return a `_FakeQuantumRunStream` and returns that fake.

    Args:
        client_constructor: The mocked client constructor whose `return_value` is replaced.
        responses_and_exceptions: The script handed to the fake's constructor.
    """
    fake_stream_client = client_constructor.return_value = _FakeQuantumRunStream(
        responses_and_exceptions
    )
    return fake_stream_client


class _FakeQuantumRunStream:
def __init__(
self, responses_and_exceptions: List[engine.QuantumRunStreamResponse | BaseException]
):
self.stream_request_count = 0
self.cancel_requests: List[engine.CancelQuantumJobRequest] = []
self.responses_and_exceptions = responses_and_exceptions

def add_responses_and_exceptions(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this appending behavior needed, or could we make the list immutable and provide it at construction time?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point, updated

self, responses_and_exceptions: List[engine.QuantumRunStreamResponse | BaseException]
):
self.responses_and_exceptions.extend(responses_and_exceptions)

async def quantum_run_stream(
self, requests: AsyncIterator[engine.QuantumRunStreamRequest] = None, **kwargs
) -> Awaitable[AsyncIterable[engine.QuantumRunStreamResponse]]:
"""Fake stream RPC: answers each incoming request with the next scripted item.

Returns an async generator that, for every request read from `requests`, pops the
next entry of `responses_and_exceptions`. Exceptions are raised out of the generator;
responses are stamped with the request's `message_id` and yielded.

NOTE(review): the return annotation describes the un-awaited call; once awaited,
the value is the async generator itself. `requests` defaults to None although it is
iterated unconditionally - confirm callers always pass it.
"""
async def run_async_iterator():
async for request in requests:
# Counts requests consumed across all streams opened on this fake.
self.stream_request_count += 1
# Yield control to the event loop until a scripted item is available.
while not self.responses_and_exceptions:
await asyncio.sleep(0)
logger.warning('Responding')
response_or_exception = self.responses_and_exceptions.pop(0)
if isinstance(response_or_exception, BaseException):
# Raising here ends this generator, i.e. the current stream.
raise response_or_exception
# Echo the request's message ID onto the response before yielding it.
response_or_exception.message_id = request.message_id
yield response_or_exception

# Yield to the event loop once before handing back the response iterator.
await asyncio.sleep(0)
return run_async_iterator()

async def cancel_quantum_job(self, request: engine.CancelQuantumJobRequest) -> None:
    """Records `request` in `cancel_requests`, then yields to the event loop once."""
    recorded_requests = self.cancel_requests
    recorded_requests.append(request)
    await asyncio.sleep(0)


@mock.patch.object(quantum, 'QuantumEngineServiceAsyncClient', autospec=True)
def test_create_program(client_constructor):
grpc_client = setup_mock_(client_constructor)
Expand Down Expand Up @@ -1176,3 +1222,74 @@ def test_list_time_slots(client_constructor):

client = EngineClient()
assert client.list_time_slots('proj', 'processor0') == results


@mock.patch.object(quantum, 'QuantumEngineServiceAsyncClient', autospec=True)
def test_run_job_over_stream_send_job_expects_result_response(client_constructor):
"""A single scripted result response is returned for a job sent over the stream."""
expected_result = quantum.QuantumResult(parent='projects/proj/programs/prog/jobs/job0')
# The fake stream replays this one response for the first request it receives.
mock_responses = [quantum.QuantumRunStreamResponse(result=expected_result)]
fake_client = setup_fake_quantum_run_stream_client(
client_constructor, responses_and_exceptions=mock_responses
)

code = any_pb2.Any()
run_context = any_pb2.Any()
labels = {'hello': 'world'}
client = EngineClient()

actual_result = client.run_job_over_stream(
'proj', 'prog', code, 'job0', ['processor0'], run_context, 10, 'A job', labels
).result(timeout=5)
# NOTE(review): `.result` is referenced but not called, so this line does not wait for
# the stream to stop - confirm whether `.result()` was intended (the stop future is
# cancelled, so calling it may raise).
client.stop_stream().result

assert actual_result == expected_result
# Exactly one request should have been read from the stream.
assert fake_client.stream_request_count == 1


@mock.patch.object(quantum, 'QuantumEngineServiceAsyncClient', autospec=True)
def test_run_job_over_stream_cancel_expects_engine_cancellation_rpc_call(client_constructor):
"""Cancelling the result future is expected to trigger a cancel-job RPC for the job."""
# No scripted responses: the fake stream never answers, so the request stays pending.
fake_client = setup_fake_quantum_run_stream_client(
client_constructor, responses_and_exceptions=[]
)

code = any_pb2.Any()
run_context = any_pb2.Any()
labels = {'hello': 'world'}
client = EngineClient()

result_future = client.run_job_over_stream(
'proj', 'prog', code, 'job0', ['processor0'], run_context, 10, 'A job', labels
)
result_future.cancel()
# NOTE(review): `.result` is referenced but not called, so nothing here waits for the
# cancellation RPC before the assertion below - presumably this relies on timing; confirm.
client.stop_stream().result

assert fake_client.cancel_requests[0] == quantum.CancelQuantumJobRequest(
name='projects/proj/programs/prog/jobs/job0'
)


@mock.patch.object(quantum, 'QuantumEngineServiceAsyncClient', autospec=True)
def test_run_job_over_stream_stream_broken_expects_retry_with_get_quantum_result(
client_constructor,
):
"""After the stream raises once, the job's result is still obtained via a second request."""
expected_result = quantum.QuantumResult(parent='projects/proj/programs/prog/jobs/job0')
# The fake raises the first item out of the stream and yields the second as the
# response to the next request. `exceptions.Aborted` is presumably a
# GoogleAPICallError subclass - confirm against the imports.
mock_responses_and_exceptions = [
exceptions.Aborted('aborted'),
quantum.QuantumRunStreamResponse(result=expected_result),
]
fake_client = setup_fake_quantum_run_stream_client(
client_constructor, responses_and_exceptions=mock_responses_and_exceptions
)

code = any_pb2.Any()
run_context = any_pb2.Any()
labels = {'hello': 'world'}
client = EngineClient()

actual_result = client.run_job_over_stream(
'proj', 'prog', code, 'job0', ['processor0'], run_context, 10, 'A job', labels
).result(timeout=1)
# NOTE(review): `.result` is referenced but not called, so this line does not wait for
# the stream to stop - confirm whether `.result()` was intended.
client.stop_stream().result

assert actual_result == expected_result
# One request for the failed attempt plus one for the retry.
assert fake_client.stream_request_count == 2