Skip to content

Commit

Permalink
feat: add support for asynchronous rest streaming (#686)
Browse files Browse the repository at this point in the history
* duplicating file to base

* restore original file

* duplicate file to async

* restore original file

* duplicate test file for async

* restore test file

* feat: add support for asynchronous rest streaming

* 🦉 Updates from OwlBot post-processor

See https://github.com/googleapis/repo-automation-bots/blob/main/packages/owl-bot/README.md

* fix naming issue

* fix import module name

* pull auth feature branch

* revert setup file

* address PR comments

* 🦉 Updates from OwlBot post-processor

See https://github.com/googleapis/repo-automation-bots/blob/main/packages/owl-bot/README.md

* run black

* address PR comments

* update nox coverage

* address PR comments

* fix nox session name in workflow

* use https for remote repo

* add context manager methods

* address PR comments

* update auth error versions

* update import error

---------

Co-authored-by: Owl Bot <gcf-owl-bot[bot]@users.noreply.github.com>
  • Loading branch information
ohmayr and gcf-owl-bot[bot] authored Sep 18, 2024
1 parent e542124 commit 1b7bb6d
Show file tree
Hide file tree
Showing 8 changed files with 679 additions and 128 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/unittest.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
option: ["", "_grpc_gcp", "_wo_grpc", "_with_prerelease_deps"]
option: ["", "_grpc_gcp", "_wo_grpc", "_with_prerelease_deps", "_with_auth_aio"]
python:
- "3.7"
- "3.8"
Expand Down
118 changes: 118 additions & 0 deletions google/api_core/_rest_streaming_base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Helpers for server-side streaming in REST."""

from collections import deque
import string
from typing import Deque, Union
import types

import proto
import google.protobuf.message
from google.protobuf.json_format import Parse


class BaseResponseIterator:
"""Base Iterator over REST API responses. This class should not be used directly.
Args:
response_message_cls (Union[proto.Message, google.protobuf.message.Message]): A response
class expected to be returned from an API.
Raises:
ValueError: If `response_message_cls` is not a subclass of `proto.Message` or `google.protobuf.message.Message`.
"""

def __init__(
self,
response_message_cls: Union[proto.Message, google.protobuf.message.Message],
):
self._response_message_cls = response_message_cls
# Contains a list of JSON responses ready to be sent to user.
self._ready_objs: Deque[str] = deque()
# Current JSON response being built.
self._obj = ""
# Keeps track of the nesting level within a JSON object.
self._level = 0
# Keeps track whether HTTP response is currently sending values
# inside of a string value.
self._in_string = False
# Whether an escape symbol "\" was encountered.
self._escape_next = False

self._grab = types.MethodType(self._create_grab(), self)

def _process_chunk(self, chunk: str):
if self._level == 0:
if chunk[0] != "[":
raise ValueError(
"Can only parse array of JSON objects, instead got %s" % chunk
)
for char in chunk:
if char == "{":
if self._level == 1:
# Level 1 corresponds to the outermost JSON object
# (i.e. the one we care about).
self._obj = ""
if not self._in_string:
self._level += 1
self._obj += char
elif char == "}":
self._obj += char
if not self._in_string:
self._level -= 1
if not self._in_string and self._level == 1:
self._ready_objs.append(self._obj)
elif char == '"':
# Helps to deal with an escaped quotes inside of a string.
if not self._escape_next:
self._in_string = not self._in_string
self._obj += char
elif char in string.whitespace:
if self._in_string:
self._obj += char
elif char == "[":
if self._level == 0:
self._level += 1
else:
self._obj += char
elif char == "]":
if self._level == 1:
self._level -= 1
else:
self._obj += char
else:
self._obj += char
self._escape_next = not self._escape_next if char == "\\" else False

def _create_grab(self):
if issubclass(self._response_message_cls, proto.Message):

def grab(this):
return this._response_message_cls.from_json(
this._ready_objs.popleft(), ignore_unknown_fields=True
)

return grab
elif issubclass(self._response_message_cls, google.protobuf.message.Message):

def grab(this):
return Parse(this._ready_objs.popleft(), this._response_message_cls())

return grab
else:
raise ValueError(
"Response message class must be a subclass of proto.Message or google.protobuf.message.Message."
)
82 changes: 8 additions & 74 deletions google/api_core/rest_streaming.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,17 +14,15 @@

"""Helpers for server-side streaming in REST."""

from collections import deque
import string
from typing import Deque, Union
from typing import Union

import proto
import requests
import google.protobuf.message
from google.protobuf.json_format import Parse
from google.api_core._rest_streaming_base import BaseResponseIterator


class ResponseIterator:
class ResponseIterator(BaseResponseIterator):
"""Iterator over REST API responses.
Args:
Expand All @@ -33,7 +31,8 @@ class ResponseIterator:
class expected to be returned from an API.
Raises:
ValueError: If `response_message_cls` is not a subclass of `proto.Message` or `google.protobuf.message.Message`.
ValueError:
- If `response_message_cls` is not a subclass of `proto.Message` or `google.protobuf.message.Message`.
"""

def __init__(
Expand All @@ -42,68 +41,16 @@ def __init__(
response_message_cls: Union[proto.Message, google.protobuf.message.Message],
):
self._response = response
self._response_message_cls = response_message_cls
# Inner iterator over HTTP response's content.
self._response_itr = self._response.iter_content(decode_unicode=True)
# Contains a list of JSON responses ready to be sent to user.
self._ready_objs: Deque[str] = deque()
# Current JSON response being built.
self._obj = ""
# Keeps track of the nesting level within a JSON object.
self._level = 0
# Keeps track whether HTTP response is currently sending values
# inside of a string value.
self._in_string = False
# Whether an escape symbol "\" was encountered.
self._escape_next = False
super(ResponseIterator, self).__init__(
response_message_cls=response_message_cls
)

def cancel(self):
"""Cancel existing streaming operation."""
self._response.close()

def _process_chunk(self, chunk: str):
if self._level == 0:
if chunk[0] != "[":
raise ValueError(
"Can only parse array of JSON objects, instead got %s" % chunk
)
for char in chunk:
if char == "{":
if self._level == 1:
# Level 1 corresponds to the outermost JSON object
# (i.e. the one we care about).
self._obj = ""
if not self._in_string:
self._level += 1
self._obj += char
elif char == "}":
self._obj += char
if not self._in_string:
self._level -= 1
if not self._in_string and self._level == 1:
self._ready_objs.append(self._obj)
elif char == '"':
# Helps to deal with an escaped quotes inside of a string.
if not self._escape_next:
self._in_string = not self._in_string
self._obj += char
elif char in string.whitespace:
if self._in_string:
self._obj += char
elif char == "[":
if self._level == 0:
self._level += 1
else:
self._obj += char
elif char == "]":
if self._level == 1:
self._level -= 1
else:
self._obj += char
else:
self._obj += char
self._escape_next = not self._escape_next if char == "\\" else False

def __next__(self):
while not self._ready_objs:
try:
Expand All @@ -115,18 +62,5 @@ def __next__(self):
raise e
return self._grab()

def _grab(self):
# Add extra quotes to make json.loads happy.
if issubclass(self._response_message_cls, proto.Message):
return self._response_message_cls.from_json(
self._ready_objs.popleft(), ignore_unknown_fields=True
)
elif issubclass(self._response_message_cls, google.protobuf.message.Message):
return Parse(self._ready_objs.popleft(), self._response_message_cls())
else:
raise ValueError(
"Response message class must be a subclass of proto.Message or google.protobuf.message.Message."
)

def __iter__(self):
return self
83 changes: 83 additions & 0 deletions google/api_core/rest_streaming_async.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Helpers for asynchronous server-side streaming in REST."""

from typing import Union

import proto

try:
import google.auth.aio.transport
except ImportError as e: # pragma: NO COVER
raise ImportError(
"google-auth>=2.35.0 is required to use asynchronous rest streaming."
) from e

import google.protobuf.message
from google.api_core._rest_streaming_base import BaseResponseIterator


class AsyncResponseIterator(BaseResponseIterator):
"""Asynchronous Iterator over REST API responses.
Args:
response (google.auth.aio.transport.Response): An API response object.
response_message_cls (Union[proto.Message, google.protobuf.message.Message]): A response
class expected to be returned from an API.
Raises:
ValueError:
- If `response_message_cls` is not a subclass of `proto.Message` or `google.protobuf.message.Message`.
"""

def __init__(
self,
response: google.auth.aio.transport.Response,
response_message_cls: Union[proto.Message, google.protobuf.message.Message],
):
self._response = response
self._chunk_size = 1024
self._response_itr = self._response.content().__aiter__()
super(AsyncResponseIterator, self).__init__(
response_message_cls=response_message_cls
)

async def __aenter__(self):
return self

async def cancel(self):
"""Cancel existing streaming operation."""
await self._response.close()

async def __anext__(self):
while not self._ready_objs:
try:
chunk = await self._response_itr.__anext__()
chunk = chunk.decode("utf-8")
self._process_chunk(chunk)
except StopAsyncIteration as e:
if self._level > 0:
raise ValueError("i Unfinished stream: %s" % self._obj)
raise e
except ValueError as e:
raise e
return self._grab()

def __aiter__(self):
return self

async def __aexit__(self, exc_type, exc, tb):
"""Cancel existing async streaming operation."""
await self._response.close()
14 changes: 13 additions & 1 deletion noxfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
"unit",
"unit_grpc_gcp",
"unit_wo_grpc",
"unit_with_auth_aio",
"cover",
"pytype",
"mypy",
Expand Down Expand Up @@ -109,7 +110,7 @@ def install_prerelease_dependencies(session, constraints_path):
session.install(*other_deps)


def default(session, install_grpc=True, prerelease=False):
def default(session, install_grpc=True, prerelease=False, install_auth_aio=False):
"""Default unit test session.
This is intended to be run **without** an interpreter set, so
Expand Down Expand Up @@ -144,6 +145,11 @@ def default(session, install_grpc=True, prerelease=False):
f"{constraints_dir}/constraints-{session.python}.txt",
)

if install_auth_aio:
session.install(
"google-auth @ git+https://git@github.com/googleapis/google-auth-library-python@8833ad6f92c3300d6645355994c7db2356bd30ad"
)

# Print out package versions of dependencies
session.run(
"python", "-c", "import google.protobuf; print(google.protobuf.__version__)"
Expand Down Expand Up @@ -229,6 +235,12 @@ def unit_wo_grpc(session):
default(session, install_grpc=False)


@nox.session(python=PYTHON_VERSIONS)
def unit_with_auth_aio(session):
"""Run the unit test suite with google.auth.aio installed"""
default(session, install_auth_aio=True)


@nox.session(python=DEFAULT_PYTHON_VERSION)
def lint_setup_py(session):
"""Verify that setup.py is valid (including RST check)."""
Expand Down
Loading

0 comments on commit 1b7bb6d

Please sign in to comment.