Skip to content

Commit 55d2bc6

Browse files
authored
feat: add snippet index (#1121)
This PR adds snippet_metadata.proto and another samplegen utils class to store the snippets so that they can be looked up by library template code. PRs to begin generating metadata and add the samples to the library docstrings will follow (I originally planned it for one PR, but the changeset was a bit too big).
1 parent 249f069 commit 55d2bc6

File tree

5 files changed

+1564
-1
lines changed

5 files changed

+1564
-1
lines changed
Lines changed: 175 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,175 @@
1+
# Copyright 2022 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from typing import Optional, Dict
16+
import re
17+
18+
from google.protobuf import json_format
19+
20+
from gapic.schema import api, metadata
21+
from gapic.samplegen_utils import snippet_metadata_pb2 # type: ignore
22+
from gapic.samplegen_utils import types
23+
24+
25+
CLIENT_INIT_RE = re.compile(r"^\s+# Create a client")
26+
REQUEST_INIT_RE = re.compile(r"^\s+# Initialize request argument\(s\)")
27+
REQUEST_EXEC_RE = re.compile(r"^\s+# Make the request")
28+
RESPONSE_HANDLING_RE = re.compile(r"^\s+# Handle response")
29+
30+
31+
class Snippet:
32+
"""A single snippet and its metadata.
33+
34+
Attributes:
35+
sample_str (str): The full text of the code snippet.
36+
metadata (snippet_metadata_pb2.Snippet): The snippet's metadata.
37+
"""
38+
39+
def __init__(self, sample_str: str, sample_metadata):
40+
self.sample_str = sample_str
41+
self.metadata = sample_metadata
42+
self._parse_snippet_segments()
43+
44+
def _parse_snippet_segments(self):
45+
"""Parse sections of the sample string and update metadata"""
46+
self.sample_lines = self.sample_str.splitlines(keepends=True)
47+
48+
self._full_snippet = snippet_metadata_pb2.Snippet.Segment(
49+
type=snippet_metadata_pb2.Snippet.Segment.SegmentType.FULL)
50+
self._short_snippet = snippet_metadata_pb2.Snippet.Segment(
51+
type=snippet_metadata_pb2.Snippet.Segment.SegmentType.SHORT)
52+
self._client_init = snippet_metadata_pb2.Snippet.Segment(
53+
type=snippet_metadata_pb2.Snippet.Segment.SegmentType.CLIENT_INITIALIZATION)
54+
self._request_init = snippet_metadata_pb2.Snippet.Segment(
55+
type=snippet_metadata_pb2.Snippet.Segment.SegmentType.REQUEST_INITIALIZATION)
56+
self._request_exec = snippet_metadata_pb2.Snippet.Segment(
57+
type=snippet_metadata_pb2.Snippet.Segment.SegmentType.REQUEST_EXECUTION)
58+
self._response_handling = snippet_metadata_pb2.Snippet.Segment(
59+
type=snippet_metadata_pb2.Snippet.Segment.SegmentType.RESPONSE_HANDLING,
60+
end=len(self.sample_lines)
61+
)
62+
63+
# Index starts at 1 since these represent line numbers
64+
for i, line in enumerate(self.sample_lines, start=1):
65+
if line.startswith("# [START"): # do not include region tag lines
66+
self._full_snippet.start = i + 1
67+
self._short_snippet.start = self._full_snippet.start
68+
elif line.startswith("# [END"):
69+
self._full_snippet.end = i - 1
70+
self._short_snippet.end = self._full_snippet.end
71+
elif CLIENT_INIT_RE.match(line):
72+
self._client_init.start = i
73+
elif REQUEST_INIT_RE.match(line):
74+
self._client_init.end = i - 1
75+
self._request_init.start = i
76+
elif REQUEST_EXEC_RE.match(line):
77+
self._request_init.end = i - 1
78+
self._request_exec.start = i
79+
elif RESPONSE_HANDLING_RE.match(line):
80+
self._request_exec.end = i - 1
81+
self._response_handling.start = i
82+
83+
self.metadata.segments.extend([self._full_snippet, self._short_snippet, self._client_init,
84+
self._request_init, self._request_exec, self._response_handling])
85+
86+
@property
87+
def full_snippet(self) -> str:
88+
"""The portion between the START and END region tags."""
89+
start_idx = self._full_snippet.start - 1
90+
end_idx = self._full_snippet.end
91+
return "".join(self.sample_lines[start_idx:end_idx])
92+
93+
94+
class SnippetIndex:
95+
"""An index of all the snippets for an API.
96+
97+
Attributes:
98+
metadata_index (snippet_metadata_pb2.Index): The snippet metadata index.
99+
"""
100+
101+
def __init__(self, api_schema: api.API):
102+
self.metadata_index = snippet_metadata_pb2.Index() # type: ignore
103+
104+
# Construct a dictionary to insert samples into based on the API schema
105+
# NOTE: In the future we expect the generator to support configured samples,
106+
# which will result in more than one sample variant per RPC. At that
107+
# time a different data structure (and re-writes of add_snippet and get_snippet)
108+
# will be needed.
109+
self._index: Dict[str, Dict[str, Dict[str, Optional[Snippet]]]] = {}
110+
111+
self._index = {
112+
s.name: {m: {"sync": None, "async": None} for m in s.methods}
113+
for s in api_schema.services.values()
114+
}
115+
116+
def add_snippet(self, snippet: Snippet) -> None:
117+
"""Add a single snippet to the snippet index.
118+
119+
Args:
120+
snippet (Snippet): The code snippet to be added.
121+
122+
Raises:
123+
UnknownService: If the service indicated by the snippet metadata is not found.
124+
RpcMethodNotFound: If the method indicated by the snippet metadata is not found.
125+
"""
126+
service_name = snippet.metadata.client_method.method.service.short_name
127+
rpc_name = snippet.metadata.client_method.method.full_name
128+
129+
service = self._index.get(service_name)
130+
if service is None:
131+
raise types.UnknownService(
132+
"API does not have a service named '{}'.".format(service_name))
133+
134+
method = service.get(rpc_name)
135+
if method is None:
136+
raise types.RpcMethodNotFound(
137+
"API does not have method '{}' in service '{}'".format(rpc_name, service_name))
138+
139+
if getattr(snippet.metadata.client_method, "async"):
140+
method["async"] = snippet
141+
else:
142+
method["sync"] = snippet
143+
144+
self.metadata_index.snippets.append(snippet.metadata)
145+
146+
def get_snippet(self, service_name: str, rpc_name: str, sync: bool = True) -> Optional[Snippet]:
147+
"""Fetch a single snippet from the index.
148+
149+
Args:
150+
service_name (str): The name of the service.
151+
rpc_name (str): The name of the RPC.
152+
sync (bool): True for the sync version of the snippet, False for the async version.
153+
154+
Returns:
155+
Optional[Snippet]: The snippet if it exists, or None.
156+
157+
Raises:
158+
UnknownService: If the service is not found.
159+
RpcMethodNotFound: If the method is not found.
160+
"""
161+
# Fetch a snippet from the snippet metadata index
162+
service = self._index.get(service_name)
163+
if service is None:
164+
raise types.UnknownService(
165+
"API does not have a service named '{}'.".format(service_name))
166+
method = service.get(rpc_name)
167+
if method is None:
168+
raise types.RpcMethodNotFound(
169+
"API does not have method '{}' in service '{}'".format(rpc_name, service_name))
170+
171+
return method["sync" if sync else "async"]
172+
173+
def get_metadata_json(self) -> str:
174+
"""JSON representation of Snippet Index."""
175+
return json_format.MessageToJson(self.metadata_index, sort_keys=True)

0 commit comments

Comments
 (0)