Stream client prototype #2: asyncio-heavy #6145
@@ -13,11 +13,13 @@
# limitations under the License.

import asyncio
from asyncio.log import logger
import datetime
import sys
import threading
from typing import (
    AsyncIterable,
    AsyncIterator,
    Awaitable,
    Callable,
    Dict,

@@ -42,6 +44,8 @@

_M = TypeVar('_M', bound=proto.Message)
_R = TypeVar('_R')
JobPath = str
MessageId = str


class EngineException(Exception):
@@ -95,6 +99,146 @@ def instance(cls):
        return cls._instance


class ResponseDemux:
    """An event demultiplexer for QuantumRunStreamResponses, as part of the async reactor pattern."""

    def __init__(self):
        self._subscribers: Dict[MessageId, asyncio.Future] = {}
        self._next_available_message_id = 0

    def subscribe(self, request: quantum.QuantumRunStreamRequest) -> asyncio.Future:
        """Assumes the message ID has not been set."""
        request.message_id = str(self._next_available_message_id)
        response_future: asyncio.Future = asyncio.get_running_loop().create_future()
        self._subscribers[request.message_id] = response_future
        self._next_available_message_id += 1
        return response_future

    def unsubscribe(self, request: quantum.QuantumRunStreamRequest) -> None:
        if request.message_id in self._subscribers:
            del self._subscribers[request.message_id]

    def publish(self, response: quantum.QuantumRunStreamResponse) -> None:
        if response.message_id not in self._subscribers:
            return

        future = self._subscribers.pop(response.message_id)
        if not future.done():
            future.set_result(response)

    def publish_exception(self, exception: GoogleAPICallError) -> None:
        """Publishes an exception to all outstanding futures."""
        for future in self._subscribers.values():
            if not future.done():
                future.set_exception(exception)
        self._subscribers = {}
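The ResponseDemux above pairs each outgoing request with a future keyed by message ID, and the stream-reading loop later resolves that future when the matching response arrives. A minimal, self-contained sketch of the same subscribe/publish flow, using plain stdlib objects instead of the QuantumRunStream protos (all names here are illustrative):

```python
import asyncio
from dataclasses import dataclass
from typing import Dict


@dataclass
class FakeMessage:
    message_id: str = ''
    payload: str = ''


class MiniDemux:
    """Same idea as ResponseDemux: one pending future per outstanding message ID."""

    def __init__(self) -> None:
        self._subscribers: Dict[str, asyncio.Future] = {}
        self._next_id = 0

    def subscribe(self, request: FakeMessage) -> asyncio.Future:
        request.message_id = str(self._next_id)
        self._next_id += 1
        future = asyncio.get_running_loop().create_future()
        self._subscribers[request.message_id] = future
        return future

    def publish(self, response: FakeMessage) -> None:
        future = self._subscribers.pop(response.message_id, None)
        if future is not None and not future.done():
            future.set_result(response)


async def main() -> None:
    demux = MiniDemux()
    request = FakeMessage(payload='run job')
    result_future = demux.subscribe(request)

    # Elsewhere (in the real code, the stream-reading loop in _manage_stream) a
    # response carrying the matching message ID arrives and resolves the future.
    demux.publish(FakeMessage(message_id=request.message_id, payload='result'))
    print((await result_future).payload)  # -> result


asyncio.run(main())
```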


class StreamManager:
    def __init__(self, grpc_client: quantum.QuantumEngineServiceAsyncClient):
        self._grpc_client = grpc_client
        self._request_queue: asyncio.Queue = asyncio.Queue()
        self._manage_stream_loop_future: Optional[duet.AwaitableFuture] = None
        # TODO consider making the scope of response futures local to the relevant tasks rather than
        # all of StreamManager.
        self._response_demux = ResponseDemux()

    @property
    def _executor(self) -> AsyncioExecutor:
        # We must re-use a single Executor due to multi-threading issues in gRPC
        # clients: https://github.com/grpc/grpc/issues/25364.
        return AsyncioExecutor.instance()

    async def _request_iterator(self) -> AsyncIterator[quantum.QuantumRunStreamRequest]:
        """The request iterator for quantum_run_stream().

        Every call of this method generates a new iterator."""
        while True:
            yield await self._request_queue.get()

Reviewer comment: I would suggest not making this a property, since that makes it seem like there is one request iterator, but actually we get a new iterator each time.
Author reply: I'm not sure I follow the last part - why would it be better to use a static method or module-level function if it's local to one iteration? Isn't a new iterator created for each …
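The point in this exchange — that every call to an async generator function builds a brand-new, independent iterator — can be checked with a tiny stdlib-only example (names are illustrative):

```python
import asyncio


async def countdown():
    # Each call to countdown() creates an independent async generator.
    for i in (3, 2, 1):
        yield i


async def main():
    first = countdown()
    second = countdown()
    assert await first.__anext__() == 3
    assert await first.__anext__() == 2
    # `second` starts fresh; consuming `first` did not affect it.
    assert await second.__anext__() == 3


asyncio.run(main())
```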

    async def _manage_stream(self):
        """Keeps the stream alive and routes responses to the appropriate request handler."""
        while True:
            try:
                # TODO specify stream timeout below with exponential backoff
                response_iterable = await self._grpc_client.quantum_run_stream(
                    self._request_iterator()
                )
                async for response in response_iterable:
                    logger.warning('publishing response to demux')
                    self._response_demux.publish(response)
            except BaseException as e:
                # TODO Close the request iterator to close the existing stream.
                self._response_demux.publish_exception(e)  # Raise to all request tasks
                if isinstance(e, asyncio.CancelledError):
                    break

    async def _make_request(
        self, project_name: str, program: quantum.QuantumProgram, job: quantum.QuantumJob
    ) -> quantum.QuantumResult:
        """This method is executed in a separate asyncio Task for each request."""
        current_request = quantum.QuantumRunStreamRequest(
            parent=project_name,
            create_quantum_program_and_job=quantum.CreateQuantumProgramAndJobRequest(
                parent=project_name, quantum_program=program, quantum_job=job
            ),
        )
        get_result_request = quantum.QuantumRunStreamRequest(
            parent=project_name, get_quantum_result=quantum.GetQuantumResultRequest(parent=job.name)
        )

        response_future: Optional[asyncio.Future] = None
        response: Optional[quantum.QuantumRunStreamResponse] = None
        while response is None:
            try:
                logger.warn('Making request')
                response_future = self._response_demux.subscribe(current_request)
                await self._request_queue.put(current_request)
                response = await response_future
                logger.warning('Got response')

            except GoogleAPICallError:
                # TODO how to distinguish between program not found vs job not found?
                # TODO Send a CreateProgramAndJobRequest or CreateJobRequest if either program or
                # job doesn't exist.
                # TODO add exponential backoff
                logger.warn('Got GoogleAPICallError')
                self._response_demux.unsubscribe(current_request)
                current_request = get_result_request
                continue

            # Either when this request is canceled or the _manage_stream() loop is canceled.
            except asyncio.CancelledError:
                if response_future is not None:
                    response_future.cancel()
                self._response_demux.unsubscribe(current_request)
                await self._cancel(job.name)
                return quantum.QuantumResult()

        if response.result is not None:
            logger.warning('Got result')
            return response.result
        # TODO handle QuantumJob response and retryable StreamError.

Reviewer comment (on the "program not found vs job not found" TODO): These are different.
Author reply: Nice, thanks!

Reviewer comment (on the TODO to resend a CreateProgramAndJobRequest or CreateJobRequest): Not sure what this means. e.g., a "job doesn't exist" should only happen during a …
Author reply: Getting a "job doesn't exist" means either the "create program" part or the "create job" part failed. We could try recreating just the job first by sending a …
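Two of the TODOs above call for exponential backoff on retries. A minimal sketch of what that could look like; the helper name and delay parameters are illustrative and not part of this PR:

```python
import asyncio
import random


async def sleep_with_backoff(attempt: int, base_delay: float = 0.5, max_delay: float = 30.0) -> None:
    """Sleeps for an exponentially growing, jittered interval before the next retry."""
    delay = min(max_delay, base_delay * (2 ** attempt))
    await asyncio.sleep(delay * random.uniform(0.5, 1.0))


# Inside the retry loop, an attempt counter would grow with each GoogleAPICallError
# before falling back to the GetQuantumResultRequest, e.g.:
#
#     except GoogleAPICallError:
#         await sleep_with_backoff(attempt)
#         attempt += 1
#         current_request = get_result_request
#         continue
```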
    async def _cancel(self, job_name: str) -> None:
        await self._grpc_client.cancel_quantum_job(quantum.CancelQuantumJobRequest(name=job_name))

    def send(
        self, project_name: str, program: quantum.QuantumProgram, job: quantum.QuantumJob
    ) -> duet.AwaitableFuture[quantum.QuantumResult]:
        """Sends a request over the stream and returns a future for the result."""
        if self._manage_stream_loop_future is None:
            self._manage_stream_loop_future = self._executor.submit(self._manage_stream)
        return self._executor.submit(self._make_request, project_name, program, job)

    def stop(self) -> duet.AwaitableFuture[None]:
        """Stops and resets the stream manager."""
        if self._manage_stream_loop_future is None:
            return duet.completed_future(None)
        self._manage_stream_loop_future.cancel()
        return self._manage_stream_loop_future


class EngineClient:
    """Client for the Quantum Engine API handling protos and gRPC client.
@@ -148,6 +292,10 @@ async def make_client():

        return self._executor.submit(make_client).result()

    @cached_property
    def _stream_manager(self) -> StreamManager:
        return StreamManager(self.grpc_client)

    async def _send_request_async(self, func: Callable[[_M], Awaitable[_R]], request: _M) -> _R:
        """Sends a request by invoking an asyncio callable."""
        return await self._run_retry_async(func, request)
@@ -740,6 +888,56 @@ async def get_job_results_async(

    get_job_results = duet.sync(get_job_results_async)

    def run_job_over_stream(
        self,
        project_id: str,
        program_id: str,
        code: any_pb2.Any,
        job_id: Optional[str],  # TODO make this non-optional.
        processor_ids: Sequence[str],
        run_context: any_pb2.Any,
        priority: Optional[int] = None,
        description: Optional[str] = None,
        labels: Optional[Dict[str, str]] = None,
    ) -> duet.AwaitableFuture[quantum.QuantumResult]:
        # Check program to run and program parameters.
        if priority and not 0 <= priority < 1000:
            raise ValueError('priority must be between 0 and 1000')

        project_name = _project_name(project_id)

        program_name = _program_name_from_ids(project_id, program_id) if program_id else ''
        program = quantum.QuantumProgram(name=program_name, code=code)
        if description:
            program.description = description
        if labels:
            program.labels.update(labels)

        job_name = _job_name_from_ids(project_id, program_id, job_id) if job_id else ''
        job = quantum.QuantumJob(
            name=job_name,
            scheduling_config=quantum.SchedulingConfig(
                processor_selector=quantum.SchedulingConfig.ProcessorSelector(
                    processor_names=[
                        _processor_name_from_ids(project_id, processor_id)
                        for processor_id in processor_ids
                    ]
                )
            ),
            run_context=run_context,
        )
        if priority:
            job.scheduling_config.priority = priority
        if description:
            job.description = description
        if labels:
            job.labels.update(labels)

        return self._stream_manager.send(project_name, program, job)

    def stop_stream(self) -> duet.AwaitableFuture[None]:
        return self._stream_manager.stop()

Reviewer comment (on the future returned by run_job_over_stream): How are failed jobs or errors conveyed through this interface? https://source.corp.google.com/piper///depot/google3/google/cloud/quantum/v1alpha1/engine.proto;rcl=389782031;l=1207
Author reply: For retryable errors, the appropriate stream request will be set as … Non-retryable errors are raised, which is propagated to the …

Reviewer comment (on setting the job name): Rather than setting the name to …
Author reply: Like you mentioned in the other comment, I also think it's a good idea to require …
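Following the author's description above (non-retryable errors propagate out of the returned future), here is a rough caller-side sketch of this public API. It is a sketch only: the project, program, processor, and job IDs are placeholders, credentials are assumed to be configured, and the exact blocking and error semantics are assumptions rather than documented behavior of this PR.

```python
from google.api_core.exceptions import GoogleAPICallError
from google.protobuf import any_pb2

from cirq_google.engine.engine_client import EngineClient

client = EngineClient()
result_future = client.run_job_over_stream(
    project_id='my-project',  # placeholder IDs
    program_id='my-program',
    code=any_pb2.Any(),  # a serialized program would go here
    job_id='my-job',
    processor_ids=['my-processor'],
    run_context=any_pb2.Any(),
)

try:
    # Blocks until the stream delivers a QuantumResult for this job.
    result = result_future.result(timeout=300)
except GoogleAPICallError as e:
    # Per the discussion above, non-retryable API errors surface through the future.
    print(f'job failed: {e}')
finally:
    # Tear down the underlying stream when done.
    client.stop_stream()
```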

    async def list_processors_async(self, project_id: str) -> List[quantum.QuantumProcessor]:
        """Returns a list of Processors that the user has visibility to in the
        current Engine project. The names of these processors are used to

@@ -1133,3 +1331,14 @@ def _date_or_time_to_filter_expr(param_name: str, param: Union[datetime.datetime
        f"type {type(param)}. Supported types: datetime.datetime and"
        f"datetime.date"
    )


def _get_job_path_from_stream_request(request: quantum.QuantumRunStreamRequest) -> str:
    if 'create_quantum_program_and_job' in request:
        return request.create_quantum_program_and_job.quantum_job.name
    elif 'create_quantum_job' in request:
        return request.create_quantum_job.quantum_job.name
    elif 'get_quantum_result' in request:
        return request.get_quantum_result.parent
    else:
        raise ValueError(f'Unrecognized request type in request: {request}')

Reviewer comment: As I understand the use of "job path", these would have to be unique across requests, since we use the job path as a key to keep track of futures, so if there is a collision we could lose a future and the client will hang waiting for it to complete. Is "parent" here unique across requests? What if a user makes two get requests for the same result?
Author reply: "Parent" here is also the full job path. Thanks for catching the future collision issue. One idea is to reuse the existing future in …
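The author's idea of reusing an existing future on a key collision could look roughly like the sketch below. This is purely illustrative — the class name is hypothetical, and the demultiplexer in this PR actually keys futures by message ID rather than job path:

```python
import asyncio
from typing import Dict


class JobPathDemux:
    """Hypothetical demux keyed by job path that reuses futures on collisions."""

    def __init__(self) -> None:
        self._futures: Dict[str, asyncio.Future] = {}

    def subscribe(self, job_path: str) -> asyncio.Future:
        # Two requests for the same job path share one future, so a second
        # subscription cannot silently overwrite (and lose) the first.
        if job_path not in self._futures:
            self._futures[job_path] = asyncio.get_running_loop().create_future()
        return self._futures[job_path]

    def publish(self, job_path: str, result: object) -> None:
        future = self._futures.pop(job_path, None)
        if future is not None and not future.done():
            future.set_result(result)
```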
@@ -12,7 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for EngineClient."""
from typing import AsyncIterable, AsyncIterator, Awaitable, List

import asyncio
from asyncio.log import logger
import datetime
from unittest import mock

@@ -25,6 +28,7 @@

from cirq_google.engine.engine_client import EngineClient, EngineException
from cirq_google.cloud import quantum
from cirq_google.cloud.quantum_v1alpha1.types import engine


def setup_mock_(client_constructor):

@@ -33,6 +37,48 @@ def setup_mock_(client_constructor):
    return grpc_client


def setup_fake_quantum_run_stream_client(client_constructor, responses_and_exceptions):
    grpc_client = _FakeQuantumRunStream(responses_and_exceptions)
    client_constructor.return_value = grpc_client
    return grpc_client


class _FakeQuantumRunStream:
    def __init__(
        self, responses_and_exceptions: List[engine.QuantumRunStreamResponse | BaseException]
    ):
        self.stream_request_count = 0
        self.cancel_requests: List[engine.CancelQuantumJobRequest] = []
        self.responses_and_exceptions = responses_and_exceptions

    def add_responses_and_exceptions(
        self, responses_and_exceptions: List[engine.QuantumRunStreamResponse | BaseException]
    ):
        self.responses_and_exceptions.extend(responses_and_exceptions)

    async def quantum_run_stream(
        self, requests: AsyncIterator[engine.QuantumRunStreamRequest] = None, **kwargs
    ) -> Awaitable[AsyncIterable[engine.QuantumRunStreamResponse]]:
        async def run_async_iterator():
            async for request in requests:
                self.stream_request_count += 1
                while not self.responses_and_exceptions:
                    await asyncio.sleep(0)
                logger.warning('Responding')
                response_or_exception = self.responses_and_exceptions.pop(0)
                if isinstance(response_or_exception, BaseException):
                    raise response_or_exception
                response_or_exception.message_id = request.message_id
                yield response_or_exception

        await asyncio.sleep(0)
        return run_async_iterator()

    async def cancel_quantum_job(self, request: engine.CancelQuantumJobRequest) -> None:
        self.cancel_requests.append(request)
        await asyncio.sleep(0)

Reviewer comment (on add_responses_and_exceptions): Is this appending behavior needed, or could we make the list immutable and provide it at construction time?
Author reply: Good point, updated.

@mock.patch.object(quantum, 'QuantumEngineServiceAsyncClient', autospec=True)
def test_create_program(client_constructor):
    grpc_client = setup_mock_(client_constructor)

@@ -1176,3 +1222,74 @@ def test_list_time_slots(client_constructor):

    client = EngineClient()
    assert client.list_time_slots('proj', 'processor0') == results


@mock.patch.object(quantum, 'QuantumEngineServiceAsyncClient', autospec=True)
def test_run_job_over_stream_send_job_expects_result_response(client_constructor):
    expected_result = quantum.QuantumResult(parent='projects/proj/programs/prog/jobs/job0')
    mock_responses = [quantum.QuantumRunStreamResponse(result=expected_result)]
    fake_client = setup_fake_quantum_run_stream_client(
        client_constructor, responses_and_exceptions=mock_responses
    )

    code = any_pb2.Any()
    run_context = any_pb2.Any()
    labels = {'hello': 'world'}
    client = EngineClient()

    actual_result = client.run_job_over_stream(
        'proj', 'prog', code, 'job0', ['processor0'], run_context, 10, 'A job', labels
    ).result(timeout=5)
    client.stop_stream().result

    assert actual_result == expected_result
    assert fake_client.stream_request_count == 1


@mock.patch.object(quantum, 'QuantumEngineServiceAsyncClient', autospec=True)
def test_run_job_over_stream_cancel_expects_engine_cancellation_rpc_call(client_constructor):
    fake_client = setup_fake_quantum_run_stream_client(
        client_constructor, responses_and_exceptions=[]
    )

    code = any_pb2.Any()
    run_context = any_pb2.Any()
    labels = {'hello': 'world'}
    client = EngineClient()

    result_future = client.run_job_over_stream(
        'proj', 'prog', code, 'job0', ['processor0'], run_context, 10, 'A job', labels
    )
    result_future.cancel()
    client.stop_stream().result

    assert fake_client.cancel_requests[0] == quantum.CancelQuantumJobRequest(
        name='projects/proj/programs/prog/jobs/job0'
    )


@mock.patch.object(quantum, 'QuantumEngineServiceAsyncClient', autospec=True)
def test_run_job_over_stream_stream_broken_expects_retry_with_get_quantum_result(
    client_constructor,
):
    expected_result = quantum.QuantumResult(parent='projects/proj/programs/prog/jobs/job0')
    mock_responses_and_exceptions = [
        exceptions.Aborted('aborted'),
        quantum.QuantumRunStreamResponse(result=expected_result),
    ]
    fake_client = setup_fake_quantum_run_stream_client(
        client_constructor, responses_and_exceptions=mock_responses_and_exceptions
    )

    code = any_pb2.Any()
    run_context = any_pb2.Any()
    labels = {'hello': 'world'}
    client = EngineClient()

    actual_result = client.run_job_over_stream(
        'proj', 'prog', code, 'job0', ['processor0'], run_context, 10, 'A job', labels
    ).result(timeout=1)
    client.stop_stream().result

    assert actual_result == expected_result
    assert fake_client.stream_request_count == 2
Reviewer comment: Add unit tests for the public API of this class.
Author reply: Definitely. For this PR I focused on sketching a prototype to get high-level design review. The real implementations will include detailed unit tests.
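For reference, a rough sketch of the kind of StreamManager-level unit test being asked for, reusing the fake stream client defined above. The test name and construction details are illustrative assumptions, not part of this PR:

```python
from google.protobuf import any_pb2

from cirq_google.cloud import quantum
from cirq_google.engine.engine_client import StreamManager


def test_stream_manager_send_returns_result():
    expected_result = quantum.QuantumResult(parent='projects/proj/programs/prog/jobs/job0')
    fake_client = _FakeQuantumRunStream(
        responses_and_exceptions=[quantum.QuantumRunStreamResponse(result=expected_result)]
    )
    manager = StreamManager(fake_client)

    program = quantum.QuantumProgram(name='projects/proj/programs/prog', code=any_pb2.Any())
    job = quantum.QuantumJob(name='projects/proj/programs/prog/jobs/job0')

    actual_result = manager.send('projects/proj', program, job).result(timeout=5)
    # Mirror the existing tests: stop the stream loop when the test is done.
    manager.stop()

    assert actual_result == expected_result
    assert fake_client.stream_request_count == 1
```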