@@ -19,7 +19,8 @@ from google.protobuf import json_format
1919{% endif %}
2020from requests import __version__ as requests_version
2121import dataclasses
22- from typing import Callable, Dict, Optional, Sequence, Tuple, Union
22+ import re
23+ from typing import Callable, Dict, List, Optional, Sequence, Tuple, Union
2324import warnings
2425
2526try:
@@ -65,7 +66,7 @@ class {{ service.name }}RestInterceptor:
6566
6667 .. code-block:
6768 class MyCustom{{ service.name }}Interceptor({{ service.name }}RestInterceptor):
68- {% for _ , method in service .methods |dictsort if not (method .server_streaming or method .client_streaming ) %}
69+ {% for _ , method in service .methods |dictsort if not (method .server_streaming or method .client_streaming ) %}
6970 def pre_{{ method.name|snake_case }}(request, metadata):
7071 logging.log(f"Received request: {request}")
7172 return request, metadata
@@ -81,7 +82,7 @@ class {{ service.name }}RestInterceptor:
8182
8283
8384 """
84- {% for method in service .methods .values ()|sort (attribute ="name" ) if not (method .server_streaming or method .client_streaming ) %}
85+ {% for method in service .methods .values ()|sort (attribute ="name" ) if not (method .server_streaming or method .client_streaming ) %}
8586 def pre_{{ method.name|snake_case }}(self, request: {{method.input.ident}}, metadata: Sequence[Tuple[str, str]]) -> Tuple[{{method.input.ident}}, Sequence[Tuple[str, str]]]:
8687 """Pre-rpc interceptor for {{ method.name|snake_case }}
8788
@@ -175,6 +176,14 @@ class {{service.name}}RestTransport({{service.name}}Transport):
175176 # TODO(yon-mg): resolve other ctor params i.e. scopes, quota, etc.
176177 # TODO: When custom host (api_endpoint) is set, `scopes` must *also* be set on the
177178 # credentials object
179+ maybe_url_match = re.match("^(?P<scheme >http(?:s)?://)?(?P<host >.*)$", host)
180+ if maybe_url_match is None:
181+ raise ValueError(f"Unexpected hostname structure: {host}") # pragma: NO COVER
182+
183+ url_match_items = maybe_url_match.groupdict()
184+
185+ host = f"{url_scheme}://{host}" if not url_match_items["scheme"] else host
186+
178187 super().__init__(
179188 host=host,
180189 credentials=credentials,
@@ -184,7 +193,7 @@ class {{service.name}}RestTransport({{service.name}}Transport):
184193 self._session = AuthorizedSession(
185194 self._credentials, default_host=self.DEFAULT_HOST)
186195 {% if service .has_lro %}
187- self._operations_client = None
196+ self._operations_client: Optional[operations_v1.AbstractOperationsClient] = None
188197 {% endif %}
189198 if client_cert_source_for_mtls:
190199 self._session.configure_mtls_channel(client_cert_source_for_mtls)
@@ -202,7 +211,7 @@ class {{service.name}}RestTransport({{service.name}}Transport):
202211 """
203212 # Only create a new client if we do not already have one.
204213 if self._operations_client is None:
205- http_options = {
214+ http_options: Dict[str, List[Dict[str, str]]] = {
206215 {% for selector , rules in api .http_options .items () %}
207216 {% if selector .startswith ('google.longrunning.Operations' ) %}
208217 '{{ selector }}': [
@@ -238,9 +247,10 @@ class {{service.name}}RestTransport({{service.name}}Transport):
238247 def __hash__(self):
239248 return hash("{{method.name}}")
240249
250+
241251 {% if not (method .server_streaming or method .client_streaming ) %}
242252 {% if method .input .required_fields %}
243- __REQUIRED_FIELDS_DEFAULT_VALUES = {
253+ __REQUIRED_FIELDS_DEFAULT_VALUES: Dict[str, str] = {
244254 {% for req_field in method .input .required_fields if req_field .is_primitive and req_field .name in method .query_params %}
245255 "{{ req_field.name | camel_case }}" : {% if req_field .field_pb .type == 9 %} "{{req_field.field_pb.default_value }}"{% else %} {{ req_field.type.python_type(req_field.field_pb.default_value or 0) }}{% endif %} ,{# default is str #}
246256 {% endfor %}
@@ -258,7 +268,7 @@ class {{service.name}}RestTransport({{service.name}}Transport):
258268 retry: OptionalRetry=gapic_v1.method.DEFAULT,
259269 timeout: float=None,
260270 metadata: Sequence[Tuple[str, str]]=(),
261- ) -> {{method.output.ident}}:
271+ ){% if not method . void %} -> {{method.output.ident}} {% endif % } :
262272 {% if method .http_options and not (method .server_streaming or method .client_streaming ) %}
263273 r"""Call the {{- ' ' -}}
264274 {{ (method.name|snake_case).replace('_',' ')|wrap(
@@ -282,7 +292,7 @@ class {{service.name}}RestTransport({{service.name}}Transport):
282292 {% endif %}
283293 """
284294
285- http_options = [
295+ http_options: List[Dict[str, str]] = [
286296 {% - for rule in method .http_options %} {
287297 'method': '{{ rule.method }}',
288298 'uri': '{{ rule.uri }}',
@@ -330,8 +340,7 @@ class {{service.name}}RestTransport({{service.name}}Transport):
330340 headers = dict(metadata)
331341 headers['Content-Type'] = 'application/json'
332342 response = getattr(self._session, method)(
333- # Replace with proper schema configuration (http/https) logic
334- "https://{host}{uri}".format(host=self._host, uri=uri),
343+ "{host}{uri}".format(host=self._host, uri=uri),
335344 timeout=timeout,
336345 headers=headers,
337346 params=rest_helpers.flatten_query_params(query_params),
@@ -344,6 +353,7 @@ class {{service.name}}RestTransport({{service.name}}Transport):
344353 # subclass.
345354 if response.status_code >= 400:
346355 raise core_exceptions.from_http_response(response)
356+
347357 {% if not method .void %}
348358 # Return the response
349359 {% if method .lro %}
@@ -357,6 +367,7 @@ class {{service.name}}RestTransport({{service.name}}Transport):
357367 {% endif %} {# method.lro #}
358368 resp = self._interceptor.post_{{ method.name|snake_case }}(resp)
359369 return resp
370+
360371 {% endif %} {# method.void #}
361372 {% else %} {# method.http_options and not (method.server_streaming or method.client_streaming) #}
362373 {% if not method .http_options %}
@@ -384,7 +395,9 @@ class {{service.name}}RestTransport({{service.name}}Transport):
384395 if not stub:
385396 stub = self._STUBS["{{method.name | snake_case}}"] = self._{{method.name}}(self._session, self._host, self._interceptor)
386397
387- return stub
398+ # The return type is fine, but mypy isn't sophisticated enough to determine what's going on here.
399+ # In C++ this would require a dynamic_cast
400+ return stub # type: ignore
388401
389402 {% endfor %}
390403
0 commit comments