From 685d1b16831d10904c3eb4b56230312136ca4eec Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?B=C3=A9n=C3=A9dikt=20Tran?=
<10796600+picnixz@users.noreply.github.com>
Date: Sun, 28 May 2023 18:52:14 +0200
Subject: [PATCH] Add partial support for PEP 695 and PEP 696 syntax (#11438)
---
CHANGES | 4 +
sphinx/addnodes.py | 11 +-
sphinx/domains/python.py | 154 +++++++++++++++++++++++++-
sphinx/writers/html5.py | 6 +-
sphinx/writers/manpage.py | 4 +-
sphinx/writers/texinfo.py | 4 +-
sphinx/writers/text.py | 4 +-
tests/test_domain_py.py | 227 +++++++++++++++++++++++++++++++++++++-
8 files changed, 402 insertions(+), 12 deletions(-)
diff --git a/CHANGES b/CHANGES
index e66f11cbf7e..95def0f3059 100644
--- a/CHANGES
+++ b/CHANGES
@@ -18,6 +18,10 @@ Deprecated
Features added
--------------
+* #11438: Add support for the :rst:dir:`py:class` and :rst:dir:`py:function`
+ directives for PEP 695 (generic classes and functions declarations) and
+ PEP 696 (default type parameters).
+ Patch by Bénédikt Tran.
* #11415: Add a checksum to JavaScript and CSS asset URIs included within
generated HTML, using the CRC32 algorithm.
* :meth:`~sphinx.application.Sphinx.require_sphinx` now allows the version
diff --git a/sphinx/addnodes.py b/sphinx/addnodes.py
index e92d32a0ef8..5f9daea88a8 100644
--- a/sphinx/addnodes.py
+++ b/sphinx/addnodes.py
@@ -253,9 +253,17 @@ class desc_parameterlist(nodes.Part, nodes.Inline, nodes.FixedTextElement):
In that case each parameter will then be written on its own, indented line.
"""
child_text_separator = ', '
+ list_left_delim = '('
+ list_right_delim = ')'
def astext(self):
- return f'({super().astext()})'
+ return f'{self.list_left_delim}{super().astext()}{self.list_right_delim}'
+
+
+class desc_tparameterlist(desc_parameterlist):
+ """Node for a general type parameter list."""
+ list_left_delim = '['
+ list_right_delim = ']'
class desc_parameter(nodes.Part, nodes.Inline, nodes.FixedTextElement):
@@ -537,6 +545,7 @@ def setup(app: Sphinx) -> dict[str, Any]:
app.add_node(desc_type)
app.add_node(desc_returns)
app.add_node(desc_parameterlist)
+ app.add_node(desc_tparameterlist)
app.add_node(desc_parameter)
app.add_node(desc_optional)
app.add_node(desc_annotation)
diff --git a/sphinx/domains/python.py b/sphinx/domains/python.py
index 3fda5270351..6c99766afaf 100644
--- a/sphinx/domains/python.py
+++ b/sphinx/domains/python.py
@@ -23,6 +23,7 @@
from sphinx.domains import Domain, Index, IndexEntry, ObjType
from sphinx.environment import BuildEnvironment
from sphinx.locale import _, __
+from sphinx.pycode.parser import Token, TokenProcessor
from sphinx.roles import XRefRole
from sphinx.util import logging
from sphinx.util.docfields import Field, GroupedField, TypedField
@@ -39,10 +40,11 @@
logger = logging.getLogger(__name__)
-# REs for Python signatures
+# REs for Python signatures (supports PEP 695)
py_sig_re = re.compile(
r'''^ ([\w.]*\.)? # class name(s)
(\w+) \s* # thing name
+ (?: \[\s*(.*)\s*])? # optional: type parameters list (PEP 695)
(?: \(\s*(.*)\s*\) # optional: arguments
(?:\s* -> \s* (.*))? # return annotation
)? $ # and nothing more
@@ -257,6 +259,146 @@ def _unparse_pep_604_annotation(node: ast.Subscript) -> list[Node]:
return [type_to_xref(annotation, env)]
+class _TypeParameterListParser(TokenProcessor):
+ def __init__(self, sig: str) -> None:
+ from token import ERRORTOKEN
+
+ # By default, tokenizing "[T = dict[str, Any]]" gives "dict[str,Any]"
+ # instead of dict[str, Any]. In particular, the default parameter value
+ # will not be formatted properly. Therefore, whitespaces are replaced
+ # by some sentinel (bad) value that can be detected as an ERRORTOKEN.
+ self.virtual_ws = '\x00'
+ self.virtual_ws_tok = [ERRORTOKEN, self.virtual_ws]
+
+ signature = ''.join(sig.splitlines()).replace(' ', self.virtual_ws)
+ super().__init__([signature])
+ # Each item is a tuple (name, kind, default, bound) mimicking
+ # inspect.Parameter to allow default values on VAR_POSITIONAL
+ # or VAR_KEYWORD parameters.
+ self.tparams: list[tuple[str, int, Any, Any]] = []
+ # When true, (leading) whitespaces are dropped when fetching a token.
+ # Set it to false when parsing a type bound or a default value so that
+ # they are properly rendered.
+ self.ignore_ws = False
+
+ def fetch_token(self) -> Token | None:
+ if not self.ignore_ws:
+ return super().fetch_token()
+
+ while super().fetch_token() == self.virtual_ws_tok:
+ assert self.current
+ return self.current
+
+ def fetch_tparam_spec(self) -> list[Token]:
+ from token import DEDENT, INDENT, OP
+
+ tokens = []
+ self.ignore_ws = False
+ while self.fetch_token():
+ tokens.append(self.current)
+ for ldelim, rdelim in ['()', '{}', '[]']:
+ if self.current == [OP, ldelim]:
+ tokens += self.fetch_until([OP, rdelim])
+ break
+ else:
+ if self.current == INDENT:
+ tokens += self.fetch_until(DEDENT)
+ elif self.current.match([OP, ':'], [OP, '='], [OP, ',']):
+ tokens.pop()
+ break
+ self.ignore_ws = True
+ return tokens
+
+ def parse(self) -> None:
+ from token import NAME, OP
+
+ def build_identifier(tokens: Iterable[Token]) -> str:
+ ws = self.virtual_ws_tok
+ ident = ''.join(' ' if tok == ws else tok.value for tok in tokens)
+ return ident.strip()
+
+ while self.fetch_token():
+ if self.current == NAME:
+ tpname: list[str] = build_identifier([self.current])
+ if self.previous and self.previous.match([OP, '*'], [OP, '**']):
+ if self.previous == [OP, '*']:
+ tpkind = Parameter.VAR_POSITIONAL
+ else:
+ tpkind = Parameter.VAR_KEYWORD
+ else:
+ tpkind = Parameter.POSITIONAL_OR_KEYWORD
+
+ tpbound, tpdefault = Parameter.empty, Parameter.empty
+
+ self.fetch_token() # whitespaces (before) will be ignored
+ if self.current and self.current.match([OP, ':'], [OP, '=']):
+ if self.current == [OP, ':']:
+ tpbound = build_identifier(self.fetch_tparam_spec())
+ if self.current == [OP, '=']:
+ tpdefault = build_identifier(self.fetch_tparam_spec())
+
+ if tpkind != Parameter.POSITIONAL_OR_KEYWORD and tpbound != Parameter.empty:
+ raise SyntaxError('type parameter bound or constraint is not allowed '
+ f'for {tpkind.description} parameters')
+
+ tparam = (tpname, tpkind, tpdefault, tpbound)
+ self.tparams.append(tparam)
+
+
+def _parse_tplist(
+ tplist: str, env: BuildEnvironment | None = None,
+ multi_line_parameter_list: bool = False,
+) -> addnodes.desc_tparameterlist:
+ """Parse a list of type parameters according to PEP 695."""
+ tparams = addnodes.desc_tparameterlist(tplist)
+ tparams['multi_line_parameter_list'] = multi_line_parameter_list
+ # formal parameter names are interpreted as type parameter names and
+ # type annotations are interpreted as type parameter bounds
+ parser = _TypeParameterListParser(tplist)
+ parser.parse()
+ for (tpname, tpkind, tpdefault, tpbound) in parser.tparams:
+ # no positional-only or keyword-only allowed in a type parameters list
+ assert tpkind not in {Parameter.POSITIONAL_ONLY, Parameter.KEYWORD_ONLY}
+
+ node = addnodes.desc_parameter()
+ if tpkind == Parameter.VAR_POSITIONAL:
+ node += addnodes.desc_sig_operator('', '*')
+ elif tpkind == Parameter.VAR_KEYWORD:
+ node += addnodes.desc_sig_operator('', '**')
+ node += addnodes.desc_sig_name('', tpname)
+
+ if tpbound is not Parameter.empty:
+ type_bound = _parse_annotation(tpbound, env)
+ if not type_bound:
+ continue
+
+ node += addnodes.desc_sig_punctuation('', ':')
+ node += addnodes.desc_sig_space()
+
+ type_bound_expr = addnodes.desc_sig_name('', '', *type_bound) # type: ignore
+
+ # add delimiters around type bounds written as e.g., "(T1, T2)"
+ if tpbound.startswith('(') and tpbound.endswith(')'):
+ node += addnodes.desc_sig_punctuation('', '(')
+ node += type_bound_expr
+ node += addnodes.desc_sig_punctuation('', ')')
+ else:
+ node += type_bound_expr
+
+ if tpdefault is not Parameter.empty:
+ if tpbound is not Parameter.empty or tpkind != Parameter.POSITIONAL_OR_KEYWORD:
+ node += addnodes.desc_sig_space()
+ node += addnodes.desc_sig_operator('', '=')
+ node += addnodes.desc_sig_space()
+ else:
+ node += addnodes.desc_sig_operator('', '=')
+ node += nodes.inline('', tpdefault, classes=['default_value'],
+ support_smartquotes=False)
+
+ tparams += node
+ return tparams
+
+
def _parse_arglist(
arglist: str, env: BuildEnvironment | None = None, multi_line_parameter_list: bool = False,
) -> addnodes.desc_parameterlist:
@@ -514,7 +656,7 @@ def handle_signature(self, sig: str, signode: desc_signature) -> tuple[str, str]
m = py_sig_re.match(sig)
if m is None:
raise ValueError
- prefix, name, arglist, retann = m.groups()
+ prefix, name, tplist, arglist, retann = m.groups()
# determine module and class name (if applicable), as well as full name
modname = self.options.get('module', self.env.ref_context.get('py:module'))
@@ -570,6 +712,14 @@ def handle_signature(self, sig: str, signode: desc_signature) -> tuple[str, str]
signode += addnodes.desc_addname(nodetext, nodetext)
signode += addnodes.desc_name(name, name)
+
+ if tplist:
+ try:
+ signode += _parse_tplist(tplist, self.env, multi_line_parameter_list)
+ except Exception as exc:
+ logger.warning("could not parse tplist (%r): %s", tplist, exc,
+ location=signode)
+
if arglist:
try:
signode += _parse_arglist(arglist, self.env, multi_line_parameter_list)
diff --git a/sphinx/writers/html5.py b/sphinx/writers/html5.py
index e7d932286c5..14198b66ed0 100644
--- a/sphinx/writers/html5.py
+++ b/sphinx/writers/html5.py
@@ -149,7 +149,8 @@ def depart_desc_returns(self, node: Element) -> None:
self.body.append('')
def visit_desc_parameterlist(self, node: Element) -> None:
- self.body.append('(')
+ list_left_delim = node.list_left_delim # type: ignore[attr-defined]
+ self.body.append(f'{list_left_delim}')
self.is_first_param = True
self.optional_param_level = 0
self.params_left_at_level = 0
@@ -170,7 +171,8 @@ def visit_desc_parameterlist(self, node: Element) -> None:
def depart_desc_parameterlist(self, node: Element) -> None:
if node.get('multi_line_parameter_list'):
self.body.append('\n\n')
- self.body.append(')')
+ list_right_delim = node.list_right_delim # type: ignore[attr-defined]
+ self.body.append(f'{list_right_delim}')
# If required parameters are still to come, then put the comma after
# the parameter. Otherwise, put the comma before. This ensures that
diff --git a/sphinx/writers/manpage.py b/sphinx/writers/manpage.py
index 1e57f48addc..66b3ffa6e92 100644
--- a/sphinx/writers/manpage.py
+++ b/sphinx/writers/manpage.py
@@ -184,11 +184,11 @@ def depart_desc_returns(self, node: Element) -> None:
pass
def visit_desc_parameterlist(self, node: Element) -> None:
- self.body.append('(')
+ self.body.append(node.list_left_delim) # type: ignore[attr-defined]
self.first_param = 1
def depart_desc_parameterlist(self, node: Element) -> None:
- self.body.append(')')
+ self.body.append(node.list_right_delim) # type: ignore[attr-defined]
def visit_desc_parameter(self, node: Element) -> None:
if not self.first_param:
diff --git a/sphinx/writers/texinfo.py b/sphinx/writers/texinfo.py
index 927e74f3487..c421846613a 100644
--- a/sphinx/writers/texinfo.py
+++ b/sphinx/writers/texinfo.py
@@ -1462,11 +1462,11 @@ def depart_desc_returns(self, node: Element) -> None:
pass
def visit_desc_parameterlist(self, node: Element) -> None:
- self.body.append(' (')
+ self.body.append(f' {node.list_left_delim}') # type: ignore[attr-defined]
self.first_param = 1
def depart_desc_parameterlist(self, node: Element) -> None:
- self.body.append(')')
+ self.body.append(node.list_right_delim) # type: ignore[attr-defined]
def visit_desc_parameter(self, node: Element) -> None:
if not self.first_param:
diff --git a/sphinx/writers/text.py b/sphinx/writers/text.py
index 8e3d9df240d..616151a2ea7 100644
--- a/sphinx/writers/text.py
+++ b/sphinx/writers/text.py
@@ -593,7 +593,7 @@ def depart_desc_returns(self, node: Element) -> None:
pass
def visit_desc_parameterlist(self, node: Element) -> None:
- self.add_text('(')
+ self.add_text(node.list_left_delim) # type: ignore[attr-defined]
self.is_first_param = True
self.optional_param_level = 0
self.params_left_at_level = 0
@@ -609,7 +609,7 @@ def visit_desc_parameterlist(self, node: Element) -> None:
self.param_separator = self.param_separator.rstrip()
def depart_desc_parameterlist(self, node: Element) -> None:
- self.add_text(')')
+ self.add_text(node.list_right_delim) # type: ignore[attr-defined]
def visit_desc_parameter(self, node: Element) -> None:
on_separate_line = self.multi_line_parameter_list
diff --git a/tests/test_domain_py.py b/tests/test_domain_py.py
index 2b84f01c00d..4bbfe02b75a 100644
--- a/tests/test_domain_py.py
+++ b/tests/test_domain_py.py
@@ -1,5 +1,7 @@
"""Tests the Python Domain"""
+from __future__ import annotations
+
import re
from unittest.mock import Mock
@@ -26,6 +28,7 @@
desc_sig_punctuation,
desc_sig_space,
desc_signature,
+ desc_tparameterlist,
pending_xref,
)
from sphinx.domains import IndexEntry
@@ -45,7 +48,7 @@ def parse(sig):
m = py_sig_re.match(sig)
if m is None:
raise ValueError
- name_prefix, name, arglist, retann = m.groups()
+ name_prefix, generics, name, arglist, retann = m.groups()
signode = addnodes.desc_signature(sig, '')
_pseudo_parse_arglist(signode, arglist)
return signode.astext()
@@ -1840,3 +1843,225 @@ def test_short_literal_types(app):
[desc_content, ()],
)],
))
+
+
+def test_function_pep_695(app):
+ text = """.. py:function:: func[\
+ S,\
+ T: int,\
+ U: (int, str),\
+ R: int | int,\
+ A: int | Annotated[int, ctype("char")],\
+ *V,\
+ **P\
+ ]
+ """
+ doctree = restructuredtext.parse(app, text)
+ assert_node(doctree, (
+ addnodes.index,
+ [desc, (
+ [desc_signature, (
+ [desc_name, 'func'],
+ [desc_tparameterlist, (
+ [desc_parameter, ([desc_sig_name, 'S'])],
+ [desc_parameter, (
+ [desc_sig_name, 'T'],
+ [desc_sig_punctuation, ':'],
+ desc_sig_space,
+ [desc_sig_name, ([pending_xref, 'int'])],
+ )],
+ [desc_parameter, (
+ [desc_sig_name, 'U'],
+ [desc_sig_punctuation, ':'],
+ desc_sig_space,
+ [desc_sig_punctuation, '('],
+ [desc_sig_name, (
+ [pending_xref, 'int'],
+ [desc_sig_punctuation, ','],
+ desc_sig_space,
+ [pending_xref, 'str'],
+ )],
+ [desc_sig_punctuation, ')'],
+ )],
+ [desc_parameter, (
+ [desc_sig_name, 'R'],
+ [desc_sig_punctuation, ':'],
+ desc_sig_space,
+ [desc_sig_name, (
+ [pending_xref, 'int'],
+ desc_sig_space,
+ [desc_sig_punctuation, '|'],
+ desc_sig_space,
+ [pending_xref, 'int'],
+ )],
+ )],
+ [desc_parameter, (
+ [desc_sig_name, 'A'],
+ [desc_sig_punctuation, ':'],
+ desc_sig_space,
+ [desc_sig_name, ([pending_xref, 'int | Annotated[int, ctype("char")]'])],
+ )],
+ [desc_parameter, (
+ [desc_sig_operator, '*'],
+ [desc_sig_name, 'V'],
+ )],
+ [desc_parameter, (
+ [desc_sig_operator, '**'],
+ [desc_sig_name, 'P'],
+ )],
+ )],
+ [desc_parameterlist, ()],
+ )],
+ [desc_content, ()],
+ )],
+ ))
+
+
+def test_class_def_pep_695(app):
+ # Non-concrete unbound generics are allowed at runtime but type checkers
+ # should fail (https://peps.python.org/pep-0695/#type-parameter-scopes)
+ text = """.. py:class:: Class[S: Sequence[T], T, KT, VT](Dict[KT, VT])"""
+ doctree = restructuredtext.parse(app, text)
+ assert_node(doctree, (
+ addnodes.index,
+ [desc, (
+ [desc_signature, (
+ [desc_annotation, ('class', desc_sig_space)],
+ [desc_name, 'Class'],
+ [desc_tparameterlist, (
+ [desc_parameter, (
+ [desc_sig_name, 'S'],
+ [desc_sig_punctuation, ':'],
+ desc_sig_space,
+ [desc_sig_name, (
+ [pending_xref, 'Sequence'],
+ [desc_sig_punctuation, '['],
+ [pending_xref, 'T'],
+ [desc_sig_punctuation, ']'],
+ )],
+ )],
+ [desc_parameter, ([desc_sig_name, 'T'])],
+ [desc_parameter, ([desc_sig_name, 'KT'])],
+ [desc_parameter, ([desc_sig_name, 'VT'])],
+ )],
+ [desc_parameterlist, ([desc_parameter, 'Dict[KT, VT]'])],
+ )],
+ [desc_content, ()],
+ )],
+ ))
+
+def test_class_def_pre_696(app):
+ # test default values for type variables without using PEP 696
+ text = """.. py:class:: Class[\
+ T, KT, VT,\
+ J: int,\
+ S: str = str,\
+ L: (T, tuple[T, ...], collections.abc.Iterable[T]) = set[T],\
+ Q: collections.abc.Mapping[KT, VT] = dict[KT, VT],\
+ *V = *tuple[*Ts, bool],\
+ **P = [int, Annotated[int, ValueRange(3, 10), ctype("char")]]\
+ ](Other[T, KT, VT, J, S, L, Q, *V, **P])
+ """
+ doctree = restructuredtext.parse(app, text)
+ assert_node(doctree, (
+ addnodes.index,
+ [desc, (
+ [desc_signature, (
+ [desc_annotation, ('class', desc_sig_space)],
+ [desc_name, 'Class'],
+ [desc_tparameterlist, (
+ [desc_parameter, ([desc_sig_name, 'T'])],
+ [desc_parameter, ([desc_sig_name, 'KT'])],
+ [desc_parameter, ([desc_sig_name, 'VT'])],
+ # J: int
+ [desc_parameter, (
+ [desc_sig_name, 'J'],
+ [desc_sig_punctuation, ':'],
+ desc_sig_space,
+ [desc_sig_name, ([pending_xref, 'int'])],
+ )],
+ # S: str = str
+ [desc_parameter, (
+ [desc_sig_name, 'S'],
+ [desc_sig_punctuation, ':'],
+ desc_sig_space,
+ [desc_sig_name, ([pending_xref, 'str'])],
+ desc_sig_space,
+ [desc_sig_operator, '='],
+ desc_sig_space,
+ [nodes.inline, 'str'],
+ )],
+ [desc_parameter, (
+ [desc_sig_name, 'L'],
+ [desc_sig_punctuation, ':'],
+ desc_sig_space,
+ [desc_sig_punctuation, '('],
+ [desc_sig_name, (
+ # T
+ [pending_xref, 'T'],
+ [desc_sig_punctuation, ','],
+ desc_sig_space,
+ # tuple[T, ...]
+ [pending_xref, 'tuple'],
+ [desc_sig_punctuation, '['],
+ [pending_xref, 'T'],
+ [desc_sig_punctuation, ','],
+ desc_sig_space,
+ [desc_sig_punctuation, '...'],
+ [desc_sig_punctuation, ']'],
+ [desc_sig_punctuation, ','],
+ desc_sig_space,
+ # collections.abc.Iterable[T]
+ [pending_xref, 'collections.abc.Iterable'],
+ [desc_sig_punctuation, '['],
+ [pending_xref, 'T'],
+ [desc_sig_punctuation, ']'],
+ )],
+ [desc_sig_punctuation, ')'],
+ desc_sig_space,
+ [desc_sig_operator, '='],
+ desc_sig_space,
+ [nodes.inline, 'set[T]'],
+ )],
+ [desc_parameter, (
+ [desc_sig_name, 'Q'],
+ [desc_sig_punctuation, ':'],
+ desc_sig_space,
+ [desc_sig_name, (
+ [pending_xref, 'collections.abc.Mapping'],
+ [desc_sig_punctuation, '['],
+ [pending_xref, 'KT'],
+ [desc_sig_punctuation, ','],
+ desc_sig_space,
+ [pending_xref, 'VT'],
+ [desc_sig_punctuation, ']'],
+ )],
+ desc_sig_space,
+ [desc_sig_operator, '='],
+ desc_sig_space,
+ [nodes.inline, 'dict[KT, VT]'],
+ )],
+ [desc_parameter, (
+ [desc_sig_operator, '*'],
+ [desc_sig_name, 'V'],
+ desc_sig_space,
+ [desc_sig_operator, '='],
+ desc_sig_space,
+ [nodes.inline, '*tuple[*Ts, bool]'],
+ )],
+ [desc_parameter, (
+ [desc_sig_operator, '**'],
+ [desc_sig_name, 'P'],
+ desc_sig_space,
+ [desc_sig_operator, '='],
+ desc_sig_space,
+ [nodes.inline, '[int, Annotated[int, ValueRange(3, 10), ctype("char")]]'],
+ )],
+ )],
+ [desc_parameterlist, (
+ [desc_parameter, 'Other[T, KT, VT, J, S, L, Q, *V, **P]'],
+ )],
+ )],
+ [desc_content, ()],
+ )],
+ ))