Skip to content

Commit b79ba08

Browse files
committed
Add type stub template to API generator
1 parent cd4c57c commit b79ba08

File tree

3 files changed

+86
-11
lines changed

3 files changed

+86
-11
lines changed

utils/generate_api.py

Lines changed: 60 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,16 @@
6262
/ "rest-api-spec"
6363
/ "api"
6464
)
65+
GLOBAL_QUERY_PARAMS = {
66+
"pretty": "Optional[bool]",
67+
"human": "Optional[bool]",
68+
"error_trace": "Optional[bool]",
69+
"format": "Optional[str]",
70+
"filter_path": "Optional[Union[str, Collection[str]]]",
71+
"request_timeout": "Optional[Union[int, float]]",
72+
"ignore": "Optional[Union[int, Collection[int]]]",
73+
"opaque_id": "Optional[str]",
74+
}
6575

6676
jinja_env = Environment(
6777
loader=FileSystemLoader([CODE_ROOT / "utils" / "templates"]),
@@ -78,15 +88,20 @@ def blacken(filename):
7888

7989
@lru_cache()
8090
def is_valid_url(url):
81-
return http.request("HEAD", url).status == 200
91+
return 200 <= http.request("HEAD", url).status < 400
8292

8393

8494
class Module:
85-
def __init__(self, namespace):
95+
def __init__(self, namespace, is_pyi=False):
8696
self.namespace = namespace
97+
self.is_pyi = is_pyi
8798
self._apis = []
8899
self.parse_orig()
89100

101+
if not is_pyi:
102+
self.pyi = Module(namespace, is_pyi=True)
103+
self.pyi.orders = self.orders[:]
104+
90105
def add(self, api):
91106
self._apis.append(api)
92107

@@ -128,17 +143,23 @@ def dump(self):
128143
f.write(self.header)
129144
for api in self._apis:
130145
f.write(api.to_python())
131-
blacken(self.filepath)
146+
147+
if not self.is_pyi:
148+
self.pyi.dump()
132149

133150
@property
134151
def filepath(self):
135-
return CODE_ROOT / f"elasticsearch/_async/client/{self.namespace}.py"
152+
return (
153+
CODE_ROOT
154+
/ f"elasticsearch/_async/client/{self.namespace}.py{'i' if self.is_pyi else ''}"
155+
)
136156

137157

138158
class API:
139-
def __init__(self, namespace, name, definition):
159+
def __init__(self, namespace, name, definition, is_pyi=False):
140160
self.namespace = namespace
141161
self.name = name
162+
self.is_pyi = is_pyi
142163

143164
# overwrite the dict to maintain key order
144165
definition["params"] = {
@@ -187,6 +208,7 @@ def all_parts(self):
187208
parts[p]["required"] = all(
188209
p in url.get("parts", {}) for url in self._def["url"]["paths"]
189210
)
211+
parts[p]["type"] = "Any"
190212

191213
for k, sub in SUBSTITUTIONS.items():
192214
if k in parts:
@@ -233,6 +255,19 @@ def query_params(self):
233255
if k not in self.all_parts
234256
)
235257

258+
@property
259+
def all_func_params(self):
260+
"""Parameters that will be in the '@query_params' decorator list
261+
and parameters that will be in the function signature.
262+
This doesn't include
263+
"""
264+
params = list(self._def.get("params", {}).keys())
265+
for url in self._def["url"]["paths"]:
266+
params.extend(url.get("parts", {}).keys())
267+
if self.body:
268+
params.append("body")
269+
return params
270+
236271
@property
237272
def path(self):
238273
return max(
@@ -279,12 +314,18 @@ def required_parts(self):
279314
return required
280315

281316
def to_python(self):
282-
try:
283-
t = jinja_env.get_template(f"overrides/{self.namespace}/{self.name}")
284-
except TemplateNotFound:
285-
t = jinja_env.get_template("base")
317+
if self.is_pyi:
318+
t = jinja_env.get_template("base_pyi")
319+
else:
320+
try:
321+
t = jinja_env.get_template(f"overrides/{self.namespace}/{self.name}")
322+
except TemplateNotFound:
323+
t = jinja_env.get_template("base")
324+
286325
return t.render(
287-
api=self, substitutions={v: k for k, v in SUBSTITUTIONS.items()}
326+
api=self,
327+
substitutions={v: k for k, v in SUBSTITUTIONS.items()},
328+
global_query_params=GLOBAL_QUERY_PARAMS,
288329
)
289330

290331

@@ -313,6 +354,7 @@ def read_modules():
313354
modules[namespace] = Module(namespace)
314355

315356
modules[namespace].add(API(namespace, name, api))
357+
modules[namespace].pyi.add(API(namespace, name, api, is_pyi=True))
316358

317359
return modules
318360

@@ -340,7 +382,14 @@ def dump_modules(modules):
340382
filepaths = []
341383
for root, _, filenames in os.walk(CODE_ROOT / "elasticsearch/_async"):
342384
for filename in filenames:
343-
if filename.endswith(".py") and filename != "utils.py":
385+
if (
386+
filename.rpartition(".")[-1]
387+
in (
388+
"py",
389+
"pyi",
390+
)
391+
and not filename.startswith("utils.py")
392+
):
344393
filepaths.append(os.path.join(root, filename))
345394

346395
unasync.unasync_files(filepaths, rules)

utils/templates/base_pyi

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
2+
async def {{ api.name }}(self, {% include "func_params_pyi" %}) -> {% if api.method == 'HEAD' %}bool{% else %}Any{% endif %}: ...

utils/templates/func_params_pyi

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
{% for p, info in api.all_parts.items() %}
2+
{% if info.required %}{{ p }}: {{ info.type }}, {% endif %}
3+
{% endfor %}
4+
5+
{% if api.body %}
6+
body{% if not api.body.required %}: Optional[Any]=...{% else %}: Any{% endif %},
7+
{% endif %}
8+
9+
{% for p, info in api.all_parts.items() %}
10+
{% if not info.required %}{{ p }}: Optional[{{ info.type }}]=..., {% endif %}
11+
{% endfor %}
12+
13+
{% for p in api.query_params %}
14+
{{ p }}: Optional[Any]=...,
15+
{% endfor %}
16+
17+
{% for p, p_type in global_query_params.items() %}
18+
{% if p not in api.all_func_params %}
19+
{{ p }}: {{ p_type }}=...,
20+
{% endif %}
21+
{% endfor %}
22+
23+
params: Optional[Mapping[str, Any]]=...,
24+
headers: Optional[Mapping[str, str]]=...

0 commit comments

Comments
 (0)