Skip to content

Commit ca7d979

Browse files
SpEcHiDeDanipulokKurimuzonAkuma
committed
(fix): Change raw/base docs types generation to get proper typehints
(KurimuzonAkuma/pyrogram#151) Follow-Up: bf83f57 Co-authored-by: Danipulok <Danipulok@users.noreply.github.com> Co-authored-by: KurimuzonAkuma <KurimuzonAkuma@users.noreply.github.com>
1 parent e8ed29d commit ca7d979

File tree

1 file changed

+86
-9
lines changed

1 file changed

+86
-9
lines changed

compiler/docs/compiler.py

Lines changed: 86 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,9 @@
2121
import re
2222
import shutil
2323

24+
from dataclasses import dataclass
25+
from typing import Literal, Optional
26+
2427
HOME = "compiler/docs"
2528
DESTINATION = "docs/source/telegram"
2629
PYROGRAM_API_DEST = "docs/source/api"
@@ -39,6 +42,64 @@ def snek(s: str):
3942
return re.sub(r"([a-z0-9])([A-Z])", r"\1_\2", s).lower()
4043

4144

45+
def _extract_union_name(node: ast.AST) -> Optional[str]:
46+
"""Extract the name of a variable that is assigned a Union type.
47+
48+
:param node: The AST node to extract the variable name from.
49+
:return: The variable name if it is assigned a Union type, otherwise None.
50+
51+
>>> import ast
52+
>>> parsed_ast = ast.parse("User = Union[raw.types.UserEmpty]")
53+
>>> _extract_union_name(parsed_ast.body[0])
54+
'User'
55+
"""
56+
57+
# Check if the assigned value is a Union type
58+
if isinstance(node, ast.Assign) and isinstance(node.value, ast.Subscript):
59+
if isinstance(node.value.value, ast.Name) and node.value.value.id == "Union":
60+
# Extract variable name
61+
if isinstance(node.targets[0], ast.Name):
62+
return node.targets[0].id # Variable name
63+
64+
65+
def _extract_class_name(node: ast.AST) -> Optional[str]:
66+
"""Extract the name of a class.
67+
68+
:param node: The AST node to extract the class name from.
69+
:return: The class name if it is a class, otherwise None.
70+
71+
>>> import ast
72+
>>> parsed_ast = ast.parse("class User: pass")
73+
>>> _extract_class_name(parsed_ast.body[0])
74+
'User'
75+
"""
76+
77+
if isinstance(node, ast.ClassDef):
78+
return node.name # Class name
79+
80+
81+
NodeType = Literal["class", "union"]
82+
83+
84+
@dataclass
85+
class NodeInfo:
86+
name: str
87+
type: NodeType
88+
89+
90+
def parse_node_info(node: ast.AST) -> Optional[NodeInfo]:
91+
"""Parse an AST node and extract the class or variable name."""
92+
class_name = _extract_class_name(node)
93+
if class_name:
94+
return NodeInfo(name=class_name, type="class")
95+
96+
union_name = _extract_union_name(node)
97+
if union_name:
98+
return NodeInfo(name=union_name, type="union")
99+
100+
return None
101+
102+
42103
def generate(source_path, base):
43104
all_entities = {}
44105

@@ -54,13 +115,13 @@ def build(path, level=0):
54115
p = ast.parse(f.read())
55116

56117
for node in ast.walk(p):
57-
if isinstance(node, ast.ClassDef):
58-
name = node.name
118+
node_info = parse_node_info(node)
119+
if node_info:
59120
break
60121
else:
61122
continue
62123

63-
full_path = os.path.basename(path) + "/" + snek(name).replace("_", "-") + ".rst"
124+
full_path = os.path.basename(path) + "/" + snek(node_info.name).replace("_", "-") + ".rst"
64125

65126
if level:
66127
full_path = base + "/" + full_path
@@ -69,25 +130,41 @@ def build(path, level=0):
69130
if namespace in ["base", "types", "functions"]:
70131
namespace = ""
71132

72-
full_name = f"{(namespace + '.') if namespace else ''}{name}"
133+
full_name = f"{(namespace + '.') if namespace else ''}{node_info.name}"
73134

74135
os.makedirs(os.path.dirname(DESTINATION + "/" + full_path), exist_ok=True)
75136

76137
with open(DESTINATION + "/" + full_path, "w", encoding="utf-8") as f:
138+
title_markup = "=" * len(full_name)
139+
full_class_path = "pyrogram.raw.{}".format(
140+
".".join(full_path.split("/")[:-1]) + "." + node_info.name
141+
)
142+
if node_info.type == "class":
143+
directive_type = "autoclass"
144+
directive_suffix = "()"
145+
directive_option = "members"
146+
elif node_info.type == "union":
147+
directive_type = "autodata"
148+
directive_suffix = ""
149+
directive_option = "annotation"
150+
else:
151+
raise ValueError(f"Unknown node type: `{node_info.type}`")
152+
77153
f.write(
78154
page_template.format(
79155
title=full_name,
80-
title_markup="=" * len(full_name),
81-
full_class_path="pyrogram.raw.{}".format(
82-
".".join(full_path.split("/")[:-1]) + "." + name
83-
)
156+
title_markup=title_markup,
157+
directive_type=directive_type,
158+
full_class_path=full_class_path,
159+
directive_suffix=directive_suffix,
160+
directive_option=directive_option,
84161
)
85162
)
86163

87164
if last not in all_entities:
88165
all_entities[last] = []
89166

90-
all_entities[last].append(name)
167+
all_entities[last].append(node_info.name)
91168

92169
build(source_path)
93170

0 commit comments

Comments
 (0)