diff --git a/libcst/_parser/detect_config.py b/libcst/_parser/detect_config.py index 1d5c171d0..b4ce02f2f 100644 --- a/libcst/_parser/detect_config.py +++ b/libcst/_parser/detect_config.py @@ -10,7 +10,7 @@ from dataclasses import dataclass from io import BytesIO from tokenize import detect_encoding as py_tokenize_detect_encoding -from typing import Iterable, Iterator, Pattern, Set, Union +from typing import FrozenSet, Iterable, Iterator, Pattern, Set, Union from libcst._nodes.whitespace import NEWLINE_RE from libcst._parser.parso.python.token import PythonTokenTypes, TokenType @@ -84,7 +84,7 @@ def _detect_trailing_newline(source_str: str) -> bool: ) -def _detect_future_imports(tokens: Iterable[Token]) -> Set[str]: +def _detect_future_imports(tokens: Iterable[Token]) -> FrozenSet[str]: """ Finds __future__ imports in their proper locations. @@ -113,7 +113,7 @@ def _detect_future_imports(tokens: Iterable[Token]) -> Set[str]: state = 0 else: break - return future_imports + return frozenset(future_imports) def detect_config( diff --git a/libcst/_parser/entrypoints.py b/libcst/_parser/entrypoints.py index ddd50aa68..6e0b54937 100644 --- a/libcst/_parser/entrypoints.py +++ b/libcst/_parser/entrypoints.py @@ -40,7 +40,7 @@ def _parse( detect_default_newline=detect_default_newline, ) validate_grammar() - grammar = get_grammar(config.parsed_python_version) + grammar = get_grammar(config.parsed_python_version, config.future_imports) parser = PythonCSTParser( tokens=detection_result.tokens, diff --git a/libcst/_parser/grammar.py b/libcst/_parser/grammar.py index 987dceb13..2123f738d 100644 --- a/libcst/_parser/grammar.py +++ b/libcst/_parser/grammar.py @@ -6,7 +6,7 @@ # pyre-strict import re from functools import lru_cache -from typing import Iterator, Mapping, Optional, Tuple +from typing import FrozenSet, Iterator, Mapping, Optional, Tuple, Union from libcst._parser.conversions.expression import ( convert_arg_assign_comp_for, @@ -138,6 +138,7 @@ from libcst._parser.parso.python.token import PythonTokenTypes, TokenType from libcst._parser.parso.utils import PythonVersionInfo, parse_version_string from libcst._parser.production_decorator import get_productions +from libcst._parser.types.config import AutoConfig from libcst._parser.types.conversions import NonterminalConversion, TerminalConversion from libcst._parser.types.production import Production @@ -269,7 +270,7 @@ ) -def get_grammar_str(version: PythonVersionInfo) -> str: +def get_grammar_str(version: PythonVersionInfo, future_imports: FrozenSet[str]) -> str: """ Returns an BNF-like grammar text that `parso.pgen2.generator.generate_grammar` can handle. @@ -278,7 +279,7 @@ def get_grammar_str(version: PythonVersionInfo) -> str: debugging the grammar. """ lines = [] - for p in get_nonterminal_productions(version): + for p in get_nonterminal_productions(version, future_imports): lines.append(str(p)) return "\n".join(lines) + "\n" @@ -287,8 +288,13 @@ def get_grammar_str(version: PythonVersionInfo) -> str: # of how we're defining our grammar, efficient cache invalidation is harder, though not # impossible. @lru_cache() -def get_grammar(version: PythonVersionInfo) -> "Grammar[TokenType]": - return generate_grammar(get_grammar_str(version), PythonTokenTypes) +def get_grammar( + version: PythonVersionInfo, future_imports: Union[FrozenSet[str], AutoConfig], +) -> "Grammar[TokenType]": + if isinstance(future_imports, AutoConfig): + # For easier testing, if not provided assume no __future__ imports + future_imports = frozenset(()) + return generate_grammar(get_grammar_str(version, future_imports), PythonTokenTypes) @lru_cache() @@ -360,16 +366,31 @@ def _should_include( return True -def get_nonterminal_productions(version: PythonVersionInfo) -> Iterator[Production]: +def _should_include_future( + future: Optional[str], future_imports: FrozenSet[str], +) -> bool: + if future is None: + return True + if future[:1] == "!": + return future[1:] not in future_imports + return future in future_imports + + +def get_nonterminal_productions( + version: PythonVersionInfo, future_imports: FrozenSet[str] +) -> Iterator[Production]: for conversion in _NONTERMINAL_CONVERSIONS_SEQUENCE: for production in get_productions(conversion): - if _should_include(production.version, version): - yield production + if not _should_include(production.version, version): + continue + if not _should_include_future(production.future, future_imports): + continue + yield production @lru_cache() def get_nonterminal_conversions( - version: PythonVersionInfo, + version: PythonVersionInfo, future_imports: FrozenSet[str], ) -> Mapping[str, NonterminalConversion]: """ Returns a mapping from nonterminal production name to the conversion function that @@ -380,6 +401,8 @@ def get_nonterminal_conversions( for fn_production in get_productions(fn): if not _should_include(fn_production.version, version): continue + if not _should_include_future(fn_production.future, future_imports): + continue if fn_production.name in conversions: raise Exception( f"Found duplicate '{fn_production.name}' production in grammar" diff --git a/libcst/_parser/production_decorator.py b/libcst/_parser/production_decorator.py index 9860152a2..804ce0c81 100644 --- a/libcst/_parser/production_decorator.py +++ b/libcst/_parser/production_decorator.py @@ -18,7 +18,11 @@ # We could version our grammar at a later point by adding a version metadata kwarg to # this decorator. def with_production( - production_name: str, children: str, *, version: Optional[str] = None + production_name: str, + children: str, + *, + version: Optional[str] = None, + future: Optional[str] = None, ) -> Callable[[_NonterminalConversionT], _NonterminalConversionT]: """ Attaches a bit of grammar to a conversion function. The parser extracts all of these @@ -38,7 +42,7 @@ def inner(fn: _NonterminalConversionT) -> _NonterminalConversionT: + f"'{fn_name}'." ) # pyre-ignore: Pyre doesn't know about this magic field we added - fn.productions.append(Production(production_name, children, version)) + fn.productions.append(Production(production_name, children, version, future)) return fn return inner diff --git a/libcst/_parser/python_parser.py b/libcst/_parser/python_parser.py index e5208b83d..3a05072be 100644 --- a/libcst/_parser/python_parser.py +++ b/libcst/_parser/python_parser.py @@ -35,7 +35,9 @@ def __init__( ) self.config = config self.terminal_conversions = get_terminal_conversions() - self.nonterminal_conversions = get_nonterminal_conversions(config.version) + self.nonterminal_conversions = get_nonterminal_conversions( + config.version, config.future_imports + ) def convert_nonterminal(self, nonterminal: str, children: Sequence[Any]) -> Any: return self.nonterminal_conversions[nonterminal](self.config, children) diff --git a/libcst/_parser/tests/test_detect_config.py b/libcst/_parser/tests/test_detect_config.py index c2cb0430c..aa6024c61 100644 --- a/libcst/_parser/tests/test_detect_config.py +++ b/libcst/_parser/tests/test_detect_config.py @@ -28,7 +28,7 @@ class TestDetectConfig(UnitTest): default_newline="\n", has_trailing_newline=False, version=PythonVersionInfo(3, 7), - future_imports=set(), + future_imports=frozenset(), ), }, "detect_trailing_newline_disabled": { @@ -43,7 +43,7 @@ class TestDetectConfig(UnitTest): default_newline="\n", has_trailing_newline=False, version=PythonVersionInfo(3, 7), - future_imports=set(), + future_imports=frozenset(), ), }, "detect_default_newline_disabled": { @@ -58,7 +58,7 @@ class TestDetectConfig(UnitTest): default_newline="\n", has_trailing_newline=False, version=PythonVersionInfo(3, 7), - future_imports=set(), + future_imports=frozenset(), ), }, "newline_inferred": { @@ -73,7 +73,7 @@ class TestDetectConfig(UnitTest): default_newline="\r\n", has_trailing_newline=True, version=PythonVersionInfo(3, 7), - future_imports=set(), + future_imports=frozenset(), ), }, "newline_partial_given": { @@ -90,7 +90,7 @@ class TestDetectConfig(UnitTest): default_newline="\n", # The given partial disables inference has_trailing_newline=True, version=PythonVersionInfo(3, 7), - future_imports=set(), + future_imports=frozenset(), ), }, "indent_inferred": { @@ -105,7 +105,7 @@ class TestDetectConfig(UnitTest): default_newline="\n", has_trailing_newline=True, version=PythonVersionInfo(3, 7), - future_imports=set(), + future_imports=frozenset(), ), }, "indent_partial_given": { @@ -122,7 +122,7 @@ class TestDetectConfig(UnitTest): default_newline="\n", has_trailing_newline=True, version=PythonVersionInfo(3, 7), - future_imports=set(), + future_imports=frozenset(), ), }, "encoding_inferred": { @@ -142,7 +142,7 @@ class TestDetectConfig(UnitTest): default_newline="\n", has_trailing_newline=True, version=PythonVersionInfo(3, 7), - future_imports=set(), + future_imports=frozenset(), ), }, "encoding_partial_given": { @@ -164,7 +164,7 @@ class TestDetectConfig(UnitTest): default_newline="\n", has_trailing_newline=True, version=PythonVersionInfo(3, 7), - future_imports=set(), + future_imports=frozenset(), ), }, "encoding_str_not_bytes_disables_inference": { @@ -184,7 +184,7 @@ class TestDetectConfig(UnitTest): default_newline="\n", has_trailing_newline=True, version=PythonVersionInfo(3, 7), - future_imports=set(), + future_imports=frozenset(), ), }, "encoding_non_ascii_compatible_utf_16_with_bom": { @@ -199,7 +199,7 @@ class TestDetectConfig(UnitTest): default_newline="\n", has_trailing_newline=False, version=PythonVersionInfo(3, 7), - future_imports=set(), + future_imports=frozenset(), ), }, "detect_trailing_newline_missing_newline": { @@ -214,7 +214,7 @@ class TestDetectConfig(UnitTest): default_newline="\n", has_trailing_newline=False, version=PythonVersionInfo(3, 7), - future_imports=set(), + future_imports=frozenset(), ), }, "detect_trailing_newline_has_newline": { @@ -229,7 +229,7 @@ class TestDetectConfig(UnitTest): default_newline="\n", has_trailing_newline=True, version=PythonVersionInfo(3, 7), - future_imports=set(), + future_imports=frozenset(), ), }, "detect_trailing_newline_missing_newline_after_line_continuation": { @@ -244,7 +244,7 @@ class TestDetectConfig(UnitTest): default_newline="\n", has_trailing_newline=False, version=PythonVersionInfo(3, 7), - future_imports=set(), + future_imports=frozenset(), ), }, "detect_trailing_newline_has_newline_after_line_continuation": { @@ -259,7 +259,7 @@ class TestDetectConfig(UnitTest): default_newline="\n", has_trailing_newline=True, version=PythonVersionInfo(3, 7), - future_imports=set(), + future_imports=frozenset(), ), }, "future_imports_in_correct_position": { @@ -279,7 +279,7 @@ class TestDetectConfig(UnitTest): default_newline="\n", has_trailing_newline=True, version=PythonVersionInfo(3, 7), - future_imports={"a"}, + future_imports=frozenset({"a"}), ), }, "future_imports_in_mixed_position": { @@ -302,7 +302,7 @@ class TestDetectConfig(UnitTest): default_newline="\n", has_trailing_newline=True, version=PythonVersionInfo(3, 7), - future_imports={"a", "b"}, + future_imports=frozenset({"a", "b"}), ), }, } diff --git a/libcst/_parser/types/config.py b/libcst/_parser/types/config.py index 08691d2fb..9bdcafc86 100644 --- a/libcst/_parser/types/config.py +++ b/libcst/_parser/types/config.py @@ -10,7 +10,7 @@ import re from dataclasses import dataclass, field, fields from enum import Enum -from typing import List, Pattern, Sequence, Set, Union +from typing import FrozenSet, List, Pattern, Sequence, Union from libcst._add_slots import add_slots from libcst._nodes.whitespace import NEWLINE_RE @@ -45,7 +45,7 @@ class ParserConfig(BaseWhitespaceParserConfig): default_newline: str has_trailing_newline: bool version: PythonVersionInfo - future_imports: Set[str] + future_imports: FrozenSet[str] class AutoConfig(Enum): @@ -98,7 +98,7 @@ class PartialParserConfig: encoding: Union[str, AutoConfig] = AutoConfig.token #: Detected ``__future__`` import names - future_imports: Union[Set[str], AutoConfig] = AutoConfig.token + future_imports: Union[FrozenSet[str], AutoConfig] = AutoConfig.token #: The indentation of the file, expressed as a series of tabs and/or spaces. This #: value is inferred from the contents of the parsed source code by default. diff --git a/libcst/_parser/types/production.py b/libcst/_parser/types/production.py index 18c4a60df..ad68d45c3 100644 --- a/libcst/_parser/types/production.py +++ b/libcst/_parser/types/production.py @@ -14,6 +14,7 @@ class Production: name: str children: str version: Optional[str] + future: Optional[str] def __str__(self) -> str: return f"{self.name}: {self.children}"