Skip to content

Commit

Permalink
Plumb through future_imports
Browse files Browse the repository at this point in the history
  • Loading branch information
thatch committed Mar 12, 2020
1 parent 522eb5e commit 0c7d8b4
Show file tree
Hide file tree
Showing 8 changed files with 66 additions and 36 deletions.
6 changes: 3 additions & 3 deletions libcst/_parser/detect_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from dataclasses import dataclass
from io import BytesIO
from tokenize import detect_encoding as py_tokenize_detect_encoding
from typing import Iterable, Iterator, Pattern, Set, Union
from typing import FrozenSet, Iterable, Iterator, Pattern, Set, Union

from libcst._nodes.whitespace import NEWLINE_RE
from libcst._parser.parso.python.token import PythonTokenTypes, TokenType
Expand Down Expand Up @@ -84,7 +84,7 @@ def _detect_trailing_newline(source_str: str) -> bool:
)


def _detect_future_imports(tokens: Iterable[Token]) -> Set[str]:
def _detect_future_imports(tokens: Iterable[Token]) -> FrozenSet[str]:
"""
Finds __future__ imports in their proper locations.
Expand Down Expand Up @@ -113,7 +113,7 @@ def _detect_future_imports(tokens: Iterable[Token]) -> Set[str]:
state = 0
else:
break
return future_imports
return frozenset(future_imports)


def detect_config(
Expand Down
2 changes: 1 addition & 1 deletion libcst/_parser/entrypoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def _parse(
detect_default_newline=detect_default_newline,
)
validate_grammar()
grammar = get_grammar(config.parsed_python_version)
grammar = get_grammar(config.parsed_python_version, config.future_imports)

parser = PythonCSTParser(
tokens=detection_result.tokens,
Expand Down
41 changes: 32 additions & 9 deletions libcst/_parser/grammar.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
# pyre-strict
import re
from functools import lru_cache
from typing import Iterator, Mapping, Optional, Tuple
from typing import FrozenSet, Iterator, Mapping, Optional, Tuple, Union

from libcst._parser.conversions.expression import (
convert_arg_assign_comp_for,
Expand Down Expand Up @@ -138,6 +138,7 @@
from libcst._parser.parso.python.token import PythonTokenTypes, TokenType
from libcst._parser.parso.utils import PythonVersionInfo, parse_version_string
from libcst._parser.production_decorator import get_productions
from libcst._parser.types.config import AutoConfig
from libcst._parser.types.conversions import NonterminalConversion, TerminalConversion
from libcst._parser.types.production import Production

Expand Down Expand Up @@ -269,7 +270,7 @@
)


def get_grammar_str(version: PythonVersionInfo) -> str:
def get_grammar_str(version: PythonVersionInfo, future_imports: FrozenSet[str]) -> str:
"""
Returns a BNF-like grammar text that `parso.pgen2.generator.generate_grammar` can
handle.
Expand All @@ -278,7 +279,7 @@ def get_grammar_str(version: PythonVersionInfo) -> str:
debugging the grammar.
"""
lines = []
for p in get_nonterminal_productions(version):
for p in get_nonterminal_productions(version, future_imports):
lines.append(str(p))
return "\n".join(lines) + "\n"

Expand All @@ -287,8 +288,13 @@ def get_grammar_str(version: PythonVersionInfo) -> str:
# of how we're defining our grammar, efficient cache invalidation is harder, though not
# impossible.
@lru_cache()
def get_grammar(version: PythonVersionInfo) -> "Grammar[TokenType]":
return generate_grammar(get_grammar_str(version), PythonTokenTypes)
def get_grammar(
    version: PythonVersionInfo, future_imports: Union[FrozenSet[str], AutoConfig],
) -> "Grammar[TokenType]":
    """
    Builds the parser grammar for the given Python version and set of
    ``__future__`` imports.

    ``future_imports`` may be an ``AutoConfig`` value instead of a real set, in
    which case it is treated as an empty set. The grammar text comes from
    ``get_grammar_str`` and is handed to ``generate_grammar`` together with
    ``PythonTokenTypes``.
    """
    # For easier testing, if not provided assume no __future__ imports
    resolved_imports = (
        frozenset() if isinstance(future_imports, AutoConfig) else future_imports
    )
    grammar_text = get_grammar_str(version, resolved_imports)
    return generate_grammar(grammar_text, PythonTokenTypes)


@lru_cache()
Expand Down Expand Up @@ -360,16 +366,31 @@ def _should_include(
return True


def get_nonterminal_productions(version: PythonVersionInfo) -> Iterator[Production]:
def _should_include_future(
future: Optional[str], future_imports: FrozenSet[str],
) -> bool:
if future is None:
return True
if future[:1] == "!":
return future[1:] not in future_imports
return future in future_imports


def get_nonterminal_productions(
version: PythonVersionInfo, future_imports: FrozenSet[str]
) -> Iterator[Production]:
for conversion in _NONTERMINAL_CONVERSIONS_SEQUENCE:
for production in get_productions(conversion):
if _should_include(production.version, version):
yield production
if not _should_include(production.version, version):
continue
if not _should_include_future(production.future, future_imports):
continue
yield production


@lru_cache()
def get_nonterminal_conversions(
version: PythonVersionInfo,
version: PythonVersionInfo, future_imports: FrozenSet[str],
) -> Mapping[str, NonterminalConversion]:
"""
Returns a mapping from nonterminal production name to the conversion function that
Expand All @@ -380,6 +401,8 @@ def get_nonterminal_conversions(
for fn_production in get_productions(fn):
if not _should_include(fn_production.version, version):
continue
if not _should_include_future(fn_production.future, future_imports):
continue
if fn_production.name in conversions:
raise Exception(
f"Found duplicate '{fn_production.name}' production in grammar"
Expand Down
8 changes: 6 additions & 2 deletions libcst/_parser/production_decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,11 @@
# We could version our grammar at a later point by adding a version metadata kwarg to
# this decorator.
def with_production(
production_name: str, children: str, *, version: Optional[str] = None
production_name: str,
children: str,
*,
version: Optional[str] = None,
future: Optional[str] = None,
) -> Callable[[_NonterminalConversionT], _NonterminalConversionT]:
"""
Attaches a bit of grammar to a conversion function. The parser extracts all of these
Expand All @@ -38,7 +42,7 @@ def inner(fn: _NonterminalConversionT) -> _NonterminalConversionT:
+ f"'{fn_name}'."
)
# pyre-ignore: Pyre doesn't know about this magic field we added
fn.productions.append(Production(production_name, children, version))
fn.productions.append(Production(production_name, children, version, future))
return fn

return inner
Expand Down
4 changes: 3 additions & 1 deletion libcst/_parser/python_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,9 @@ def __init__(
)
self.config = config
self.terminal_conversions = get_terminal_conversions()
self.nonterminal_conversions = get_nonterminal_conversions(config.version)
self.nonterminal_conversions = get_nonterminal_conversions(
config.version, config.future_imports
)

def convert_nonterminal(self, nonterminal: str, children: Sequence[Any]) -> Any:
return self.nonterminal_conversions[nonterminal](self.config, children)
Expand Down
34 changes: 17 additions & 17 deletions libcst/_parser/tests/test_detect_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ class TestDetectConfig(UnitTest):
default_newline="\n",
has_trailing_newline=False,
version=PythonVersionInfo(3, 7),
future_imports=set(),
future_imports=frozenset(),
),
},
"detect_trailing_newline_disabled": {
Expand All @@ -43,7 +43,7 @@ class TestDetectConfig(UnitTest):
default_newline="\n",
has_trailing_newline=False,
version=PythonVersionInfo(3, 7),
future_imports=set(),
future_imports=frozenset(),
),
},
"detect_default_newline_disabled": {
Expand All @@ -58,7 +58,7 @@ class TestDetectConfig(UnitTest):
default_newline="\n",
has_trailing_newline=False,
version=PythonVersionInfo(3, 7),
future_imports=set(),
future_imports=frozenset(),
),
},
"newline_inferred": {
Expand All @@ -73,7 +73,7 @@ class TestDetectConfig(UnitTest):
default_newline="\r\n",
has_trailing_newline=True,
version=PythonVersionInfo(3, 7),
future_imports=set(),
future_imports=frozenset(),
),
},
"newline_partial_given": {
Expand All @@ -90,7 +90,7 @@ class TestDetectConfig(UnitTest):
default_newline="\n", # The given partial disables inference
has_trailing_newline=True,
version=PythonVersionInfo(3, 7),
future_imports=set(),
future_imports=frozenset(),
),
},
"indent_inferred": {
Expand All @@ -105,7 +105,7 @@ class TestDetectConfig(UnitTest):
default_newline="\n",
has_trailing_newline=True,
version=PythonVersionInfo(3, 7),
future_imports=set(),
future_imports=frozenset(),
),
},
"indent_partial_given": {
Expand All @@ -122,7 +122,7 @@ class TestDetectConfig(UnitTest):
default_newline="\n",
has_trailing_newline=True,
version=PythonVersionInfo(3, 7),
future_imports=set(),
future_imports=frozenset(),
),
},
"encoding_inferred": {
Expand All @@ -142,7 +142,7 @@ class TestDetectConfig(UnitTest):
default_newline="\n",
has_trailing_newline=True,
version=PythonVersionInfo(3, 7),
future_imports=set(),
future_imports=frozenset(),
),
},
"encoding_partial_given": {
Expand All @@ -164,7 +164,7 @@ class TestDetectConfig(UnitTest):
default_newline="\n",
has_trailing_newline=True,
version=PythonVersionInfo(3, 7),
future_imports=set(),
future_imports=frozenset(),
),
},
"encoding_str_not_bytes_disables_inference": {
Expand All @@ -184,7 +184,7 @@ class TestDetectConfig(UnitTest):
default_newline="\n",
has_trailing_newline=True,
version=PythonVersionInfo(3, 7),
future_imports=set(),
future_imports=frozenset(),
),
},
"encoding_non_ascii_compatible_utf_16_with_bom": {
Expand All @@ -199,7 +199,7 @@ class TestDetectConfig(UnitTest):
default_newline="\n",
has_trailing_newline=False,
version=PythonVersionInfo(3, 7),
future_imports=set(),
future_imports=frozenset(),
),
},
"detect_trailing_newline_missing_newline": {
Expand All @@ -214,7 +214,7 @@ class TestDetectConfig(UnitTest):
default_newline="\n",
has_trailing_newline=False,
version=PythonVersionInfo(3, 7),
future_imports=set(),
future_imports=frozenset(),
),
},
"detect_trailing_newline_has_newline": {
Expand All @@ -229,7 +229,7 @@ class TestDetectConfig(UnitTest):
default_newline="\n",
has_trailing_newline=True,
version=PythonVersionInfo(3, 7),
future_imports=set(),
future_imports=frozenset(),
),
},
"detect_trailing_newline_missing_newline_after_line_continuation": {
Expand All @@ -244,7 +244,7 @@ class TestDetectConfig(UnitTest):
default_newline="\n",
has_trailing_newline=False,
version=PythonVersionInfo(3, 7),
future_imports=set(),
future_imports=frozenset(),
),
},
"detect_trailing_newline_has_newline_after_line_continuation": {
Expand All @@ -259,7 +259,7 @@ class TestDetectConfig(UnitTest):
default_newline="\n",
has_trailing_newline=True,
version=PythonVersionInfo(3, 7),
future_imports=set(),
future_imports=frozenset(),
),
},
"future_imports_in_correct_position": {
Expand All @@ -279,7 +279,7 @@ class TestDetectConfig(UnitTest):
default_newline="\n",
has_trailing_newline=True,
version=PythonVersionInfo(3, 7),
future_imports={"a"},
future_imports=frozenset({"a"}),
),
},
"future_imports_in_mixed_position": {
Expand All @@ -302,7 +302,7 @@ class TestDetectConfig(UnitTest):
default_newline="\n",
has_trailing_newline=True,
version=PythonVersionInfo(3, 7),
future_imports={"a", "b"},
future_imports=frozenset({"a", "b"}),
),
},
}
Expand Down
6 changes: 3 additions & 3 deletions libcst/_parser/types/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import re
from dataclasses import dataclass, field, fields
from enum import Enum
from typing import List, Pattern, Sequence, Set, Union
from typing import FrozenSet, List, Pattern, Sequence, Union

from libcst._add_slots import add_slots
from libcst._nodes.whitespace import NEWLINE_RE
Expand Down Expand Up @@ -45,7 +45,7 @@ class ParserConfig(BaseWhitespaceParserConfig):
default_newline: str
has_trailing_newline: bool
version: PythonVersionInfo
future_imports: Set[str]
future_imports: FrozenSet[str]


class AutoConfig(Enum):
Expand Down Expand Up @@ -98,7 +98,7 @@ class PartialParserConfig:
encoding: Union[str, AutoConfig] = AutoConfig.token

#: Detected ``__future__`` import names
future_imports: Union[Set[str], AutoConfig] = AutoConfig.token
future_imports: Union[FrozenSet[str], AutoConfig] = AutoConfig.token

#: The indentation of the file, expressed as a series of tabs and/or spaces. This
#: value is inferred from the contents of the parsed source code by default.
Expand Down
1 change: 1 addition & 0 deletions libcst/_parser/types/production.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ class Production:
name: str
children: str
version: Optional[str]
future: Optional[str]

def __str__(self) -> str:
return f"{self.name}: {self.children}"

0 comments on commit 0c7d8b4

Please sign in to comment.