core: rename MLIR-specific lexing constructs #3592

Merged (1 commit), Dec 9, 2024
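
In short: the MLIR-specific lexing constructs in xdsl.utils.mlir_lexer are renamed, so that Lexer becomes MLIRLexer, Token becomes MLIRToken, and Kind becomes MLIRTokenKind. A minimal sketch of the renamed API, using only names that appear in this diff (the input string is illustrative):

    from xdsl.utils.lexer import Input
    from xdsl.utils.mlir_lexer import MLIRLexer, MLIRToken, MLIRTokenKind

    # Input pairs the source text with a display name, as in the tests below.
    source = Input("%0 = arith.addi %a, %b : i32", "<unknown>")
    lexer = MLIRLexer(source)

    # lex() returns one MLIRToken at a time until MLIRTokenKind.EOF.
    token: MLIRToken = lexer.lex()
    while token.kind is not MLIRTokenKind.EOF:
        token = lexer.lex()
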
6 changes: 3 additions & 3 deletions bench/parser/bench_lexer.py
@@ -11,15 +11,15 @@
 from collections.abc import Iterable

 from xdsl.utils.lexer import Input
-from xdsl.utils.mlir_lexer import Kind, Lexer
+from xdsl.utils.mlir_lexer import MLIRLexer, MLIRTokenKind


 def lex_file(file: Input):
     """
     Lex the given file
     """
-    lexer = Lexer(file)
-    while lexer.lex().kind is not Kind.EOF:
+    lexer = MLIRLexer(file)
+    while lexer.lex().kind is not MLIRTokenKind.EOF:
         pass

92 changes: 46 additions & 46 deletions tests/test_lexer.py
@@ -2,18 +2,18 @@

 from xdsl.utils.exceptions import ParseError
 from xdsl.utils.lexer import Input
-from xdsl.utils.mlir_lexer import Kind, Lexer, Token
+from xdsl.utils.mlir_lexer import MLIRLexer, MLIRToken, MLIRTokenKind


-def get_token(input: str) -> Token:
+def get_token(input: str) -> MLIRToken:
     file = Input(input, "<unknown>")
-    lexer = Lexer(file)
+    lexer = MLIRLexer(file)
     token = lexer.lex()
     return token


 def assert_single_token(
-    input: str, expected_kind: Kind, expected_text: str | None = None
+    input: str, expected_kind: MLIRTokenKind, expected_text: str | None = None
 ):
     if expected_text is None:
         expected_text = input
@@ -26,37 +26,37 @@ def assert_single_token(

 def assert_token_fail(input: str):
     file = Input(input, "<unknown>")
-    lexer = Lexer(file)
+    lexer = MLIRLexer(file)
     with pytest.raises(ParseError):
         lexer.lex()


 @pytest.mark.parametrize(
     "text,kind",
     [
-        ("->", Kind.ARROW),
-        (":", Kind.COLON),
-        (",", Kind.COMMA),
-        ("...", Kind.ELLIPSIS),
-        ("=", Kind.EQUAL),
-        (">", Kind.GREATER),
-        ("{", Kind.L_BRACE),
-        ("(", Kind.L_PAREN),
-        ("[", Kind.L_SQUARE),
-        ("<", Kind.LESS),
-        ("-", Kind.MINUS),
-        ("+", Kind.PLUS),
-        ("?", Kind.QUESTION),
-        ("}", Kind.R_BRACE),
-        (")", Kind.R_PAREN),
-        ("]", Kind.R_SQUARE),
-        ("*", Kind.STAR),
-        ("|", Kind.VERTICAL_BAR),
-        ("{-#", Kind.FILE_METADATA_BEGIN),
-        ("#-}", Kind.FILE_METADATA_END),
+        ("->", MLIRTokenKind.ARROW),
+        (":", MLIRTokenKind.COLON),
+        (",", MLIRTokenKind.COMMA),
+        ("...", MLIRTokenKind.ELLIPSIS),
+        ("=", MLIRTokenKind.EQUAL),
+        (">", MLIRTokenKind.GREATER),
+        ("{", MLIRTokenKind.L_BRACE),
+        ("(", MLIRTokenKind.L_PAREN),
+        ("[", MLIRTokenKind.L_SQUARE),
+        ("<", MLIRTokenKind.LESS),
+        ("-", MLIRTokenKind.MINUS),
+        ("+", MLIRTokenKind.PLUS),
+        ("?", MLIRTokenKind.QUESTION),
+        ("}", MLIRTokenKind.R_BRACE),
+        (")", MLIRTokenKind.R_PAREN),
+        ("]", MLIRTokenKind.R_SQUARE),
+        ("*", MLIRTokenKind.STAR),
+        ("|", MLIRTokenKind.VERTICAL_BAR),
+        ("{-#", MLIRTokenKind.FILE_METADATA_BEGIN),
+        ("#-}", MLIRTokenKind.FILE_METADATA_END),
     ],
 )
-def test_punctuation(text: str, kind: Kind):
+def test_punctuation(text: str, kind: MLIRTokenKind):
     assert_single_token(text, kind)

@@ -69,7 +69,7 @@ def test_punctuation_fail(text: str):
     "text", ['""', '"@"', '"foo"', '"\\""', '"\\n"', '"\\\\"', '"\\t"']
 )
 def test_str_literal(text: str):
-    assert_single_token(text, Kind.STRING_LIT)
+    assert_single_token(text, MLIRTokenKind.STRING_LIT)


 @pytest.mark.parametrize("text", ['"', '"\\"', '"\\a"', '"\n"', '"\v"', '"\f"'])
@@ -82,7 +82,7 @@ def test_str_literal_fail(text: str):
 )
 def test_bare_ident(text: str):
     """bare-id ::= (letter|[_]) (letter|digit|[_$.])*"""
-    assert_single_token(text, Kind.BARE_IDENT)
+    assert_single_token(text, MLIRTokenKind.BARE_IDENT)


 @pytest.mark.parametrize(
@@ -109,7 +109,7 @@ def test_bare_ident(text: str):
 )
 def test_at_ident(text: str):
     """at-ident ::= `@` (bare-id | string-literal)"""
-    assert_single_token(text, Kind.AT_IDENT)
+    assert_single_token(text, MLIRTokenKind.AT_IDENT)


 @pytest.mark.parametrize(
@@ -129,10 +129,10 @@ def test_prefixed_ident(text: str):
     """percent-ident ::= `%` (digit+ | (letter|[$._-]) (letter|[$._-]|digit)*)"""
     """caret-ident ::= `^` (digit+ | (letter|[$._-]) (letter|[$._-]|digit)*)"""
     """exclamation-ident ::= `!` (digit+ | (letter|[$._-]) (letter|[$._-]|digit)*)"""
-    assert_single_token("#" + text, Kind.HASH_IDENT)
-    assert_single_token("%" + text, Kind.PERCENT_IDENT)
-    assert_single_token("^" + text, Kind.CARET_IDENT)
-    assert_single_token("!" + text, Kind.EXCLAMATION_IDENT)
+    assert_single_token("#" + text, MLIRTokenKind.HASH_IDENT)
+    assert_single_token("%" + text, MLIRTokenKind.PERCENT_IDENT)
+    assert_single_token("^" + text, MLIRTokenKind.CARET_IDENT)
+    assert_single_token("!" + text, MLIRTokenKind.EXCLAMATION_IDENT)


 @pytest.mark.parametrize("text", ["+", '""', "#", "%", "^", "!", "\n", ""])
@@ -155,46 +155,46 @@ def test_prefixed_ident_fail(text: str):
 )
 def test_prefixed_ident_split(text: str, expected: str):
     """Check that the prefixed identifier is split at the right character."""
-    assert_single_token("#" + text, Kind.HASH_IDENT, "#" + expected)
-    assert_single_token("%" + text, Kind.PERCENT_IDENT, "%" + expected)
-    assert_single_token("^" + text, Kind.CARET_IDENT, "^" + expected)
-    assert_single_token("!" + text, Kind.EXCLAMATION_IDENT, "!" + expected)
+    assert_single_token("#" + text, MLIRTokenKind.HASH_IDENT, "#" + expected)
+    assert_single_token("%" + text, MLIRTokenKind.PERCENT_IDENT, "%" + expected)
+    assert_single_token("^" + text, MLIRTokenKind.CARET_IDENT, "^" + expected)
+    assert_single_token("!" + text, MLIRTokenKind.EXCLAMATION_IDENT, "!" + expected)


 @pytest.mark.parametrize("text", ["0", "01", "123456789", "99", "0x1234", "0xabcdef"])
 def test_integer_literal(text: str):
-    assert_single_token(text, Kind.INTEGER_LIT)
+    assert_single_token(text, MLIRTokenKind.INTEGER_LIT)


 @pytest.mark.parametrize(
     "text,expected", [("0a", "0"), ("0xg", "0"), ("0xfg", "0xf"), ("0xf.", "0xf")]
 )
 def test_integer_literal_split(text: str, expected: str):
-    assert_single_token(text, Kind.INTEGER_LIT, expected)
+    assert_single_token(text, MLIRTokenKind.INTEGER_LIT, expected)


 @pytest.mark.parametrize(
     "text", ["0.", "1.", "0.2", "38.1243", "92.54e43", "92.5E43", "43.3e-54", "32.E+25"]
 )
 def test_float_literal(text: str):
-    assert_single_token(text, Kind.FLOAT_LIT)
+    assert_single_token(text, MLIRTokenKind.FLOAT_LIT)


 @pytest.mark.parametrize(
     "text,expected", [("3.9e", "3.9"), ("4.5e+", "4.5"), ("5.8e-", "5.8")]
 )
 def test_float_literal_split(text: str, expected: str):
-    assert_single_token(text, Kind.FLOAT_LIT, expected)
+    assert_single_token(text, MLIRTokenKind.FLOAT_LIT, expected)


 @pytest.mark.parametrize("text", ["0", " 0", "  0", "\n0", "\t0", "// Comment\n0"])
 def test_whitespace_skip(text: str):
-    assert_single_token(text, Kind.INTEGER_LIT, "0")
+    assert_single_token(text, MLIRTokenKind.INTEGER_LIT, "0")


 @pytest.mark.parametrize("text", ["", " ", "\n\n", "// Comment\n"])
 def test_eof(text: str):
-    assert_single_token(text, Kind.EOF, "")
+    assert_single_token(text, MLIRTokenKind.EOF, "")


 @pytest.mark.parametrize(
@@ -209,7 +209,7 @@ def test_eof(text: str):
 )
 def test_token_get_int_value(text: str, expected: int):
     token = get_token(text)
-    assert token.kind == Kind.INTEGER_LIT
+    assert token.kind == MLIRTokenKind.INTEGER_LIT
     assert token.kind.get_int_value(token.span) == expected

@@ -228,7 +228,7 @@ def test_token_get_int_value(text: str, expected: int):
 )
 def test_token_get_float_value(text: str, expected: float):
     token = get_token(text)
-    assert token.kind == Kind.FLOAT_LIT
+    assert token.kind == MLIRTokenKind.FLOAT_LIT
     assert token.kind.get_float_value(token.span) == expected

@@ -246,5 +246,5 @@ def test_token_get_float_value(text: str, expected: float):
 )
 def test_token_get_string_literal_value(text: str, expected: float):
     token = get_token(text)
-    assert token.kind == Kind.STRING_LIT
+    assert token.kind == MLIRTokenKind.STRING_LIT
     assert token.kind.get_string_literal_value(token.span) == expected
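
Note: the literal-value accessors exercised above live on MLIRTokenKind and take the token's span. A small sketch, using only calls that appear in these tests (the hex literal is illustrative):

    from xdsl.utils.lexer import Input
    from xdsl.utils.mlir_lexer import MLIRLexer, MLIRTokenKind

    # Lex a single hexadecimal integer literal and decode its value.
    token = MLIRLexer(Input("0x2A", "<unknown>")).lex()
    assert token.kind == MLIRTokenKind.INTEGER_LIT
    assert token.kind.get_int_value(token.span) == 42
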
37 changes: 20 additions & 17 deletions tests/test_parser.py
@@ -30,7 +30,7 @@
 from xdsl.parser import Parser
 from xdsl.printer import Printer
 from xdsl.utils.exceptions import ParseError, VerifyException
-from xdsl.utils.mlir_lexer import Kind, PunctuationSpelling
+from xdsl.utils.mlir_lexer import MLIRTokenKind, PunctuationSpelling
 from xdsl.utils.str_enum import StrEnum

 # pyright: reportPrivateUsage=false
@@ -584,51 +584,54 @@ def test_parse_comma_separated_list_error_delimiters(


 @pytest.mark.parametrize(
-    "punctuation", list(Kind.get_punctuation_spelling_to_kind_dict().values())
+    "punctuation", list(MLIRTokenKind.get_punctuation_spelling_to_kind_dict().values())
 )
-def test_is_punctuation_true(punctuation: Kind):
+def test_is_punctuation_true(punctuation: MLIRTokenKind):
     assert punctuation.is_punctuation()


-@pytest.mark.parametrize("punctuation", [Kind.BARE_IDENT, Kind.EOF, Kind.INTEGER_LIT])
-def test_is_punctuation_false(punctuation: Kind):
+@pytest.mark.parametrize(
+    "punctuation",
+    [MLIRTokenKind.BARE_IDENT, MLIRTokenKind.EOF, MLIRTokenKind.INTEGER_LIT],
+)
+def test_is_punctuation_false(punctuation: MLIRTokenKind):
     assert not punctuation.is_punctuation()


 @pytest.mark.parametrize(
-    "punctuation", list(Kind.get_punctuation_spelling_to_kind_dict().values())
+    "punctuation", list(MLIRTokenKind.get_punctuation_spelling_to_kind_dict().values())
 )
-def test_is_spelling_of_punctuation_true(punctuation: Kind):
+def test_is_spelling_of_punctuation_true(punctuation: MLIRTokenKind):
     value = cast(PunctuationSpelling, punctuation.value)
-    assert Kind.is_spelling_of_punctuation(value)
+    assert MLIRTokenKind.is_spelling_of_punctuation(value)


 @pytest.mark.parametrize("punctuation", [">-", "o", "4", "$", "_", "@"])
 def test_is_spelling_of_punctuation_false(punctuation: str):
-    assert not Kind.is_spelling_of_punctuation(punctuation)
+    assert not MLIRTokenKind.is_spelling_of_punctuation(punctuation)


 @pytest.mark.parametrize(
-    "punctuation", list(Kind.get_punctuation_spelling_to_kind_dict().values())
+    "punctuation", list(MLIRTokenKind.get_punctuation_spelling_to_kind_dict().values())
 )
-def test_get_punctuation_kind(punctuation: Kind):
+def test_get_punctuation_kind(punctuation: MLIRTokenKind):
     value = cast(PunctuationSpelling, punctuation.value)
     assert punctuation.get_punctuation_kind_from_spelling(value) == punctuation


 @pytest.mark.parametrize(
-    "punctuation", list(Kind.get_punctuation_spelling_to_kind_dict().keys())
+    "punctuation", list(MLIRTokenKind.get_punctuation_spelling_to_kind_dict().keys())
 )
 def test_parse_punctuation(punctuation: PunctuationSpelling):
     parser = Parser(MLContext(), punctuation)

     res = parser.parse_punctuation(punctuation)
     assert res == punctuation
-    assert parser._parse_token(Kind.EOF, "").kind == Kind.EOF
+    assert parser._parse_token(MLIRTokenKind.EOF, "").kind == MLIRTokenKind.EOF


 @pytest.mark.parametrize(
-    "punctuation", list(Kind.get_punctuation_spelling_to_kind_dict().keys())
+    "punctuation", list(MLIRTokenKind.get_punctuation_spelling_to_kind_dict().keys())
 )
 def test_parse_punctuation_fail(punctuation: PunctuationSpelling):
     parser = Parser(MLContext(), "e +")
@@ -639,17 +642,17 @@ def test_parse_punctuation_fail(punctuation: PunctuationSpelling):


 @pytest.mark.parametrize(
-    "punctuation", list(Kind.get_punctuation_spelling_to_kind_dict().keys())
+    "punctuation", list(MLIRTokenKind.get_punctuation_spelling_to_kind_dict().keys())
 )
 def test_parse_optional_punctuation(punctuation: PunctuationSpelling):
     parser = Parser(MLContext(), punctuation)
     res = parser.parse_optional_punctuation(punctuation)
     assert res == punctuation
-    assert parser._parse_token(Kind.EOF, "").kind == Kind.EOF
+    assert parser._parse_token(MLIRTokenKind.EOF, "").kind == MLIRTokenKind.EOF


 @pytest.mark.parametrize(
-    "punctuation", list(Kind.get_punctuation_spelling_to_kind_dict().keys())
+    "punctuation", list(MLIRTokenKind.get_punctuation_spelling_to_kind_dict().keys())
 )
 def test_parse_optional_punctuation_fail(punctuation: PunctuationSpelling):
     parser = Parser(MLContext(), "e +")
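
Note: the punctuation helpers renamed in this file are utilities on MLIRTokenKind itself. A short sketch of the round trip these tests cover, from spelling to kind and back:

    from xdsl.utils.mlir_lexer import MLIRTokenKind

    spelling_to_kind = MLIRTokenKind.get_punctuation_spelling_to_kind_dict()
    kind = spelling_to_kind["->"]
    assert kind == MLIRTokenKind.ARROW
    assert kind.is_punctuation()
    assert MLIRTokenKind.is_spelling_of_punctuation("->")
    # Called on a member, exactly as in test_get_punctuation_kind above.
    assert kind.get_punctuation_kind_from_spelling("->") == kind
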
4 changes: 2 additions & 2 deletions xdsl/ir/core.py
@@ -32,7 +32,7 @@

 from xdsl.traits import IsTerminator, NoTerminator, OpTrait, OpTraitInvT
 from xdsl.utils.exceptions import VerifyException
-from xdsl.utils.mlir_lexer import Lexer
+from xdsl.utils.mlir_lexer import MLIRLexer
 from xdsl.utils.str_enum import StrEnum

 # Used for cyclic dependencies in type hints
@@ -404,7 +404,7 @@ def _check_enum_constraints(
         raise TypeError("Only direct inheritance from EnumAttribute is allowed.")

     for v in enum_type:
-        if Lexer.bare_identifier_suffix_regex.fullmatch(v) is None:
+        if MLIRLexer.bare_identifier_suffix_regex.fullmatch(v) is None:
             raise ValueError(
                 "All StrEnum values of an EnumAttribute must be parsable as an identifer."
             )
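
Note: this hunk is the one call site outside the lexer and its tests; it enforces that every StrEnum value of an EnumAttribute lexes as a bare identifier. An illustrative check against the renamed regex (both example values are made up):

    from xdsl.utils.mlir_lexer import MLIRLexer

    # A value built from letters, digits and [_$.] passes the check...
    assert MLIRLexer.bare_identifier_suffix_regex.fullmatch("nearest_even")
    # ...while one containing a space would make _check_enum_constraints raise.
    assert MLIRLexer.bare_identifier_suffix_regex.fullmatch("nearest even") is None
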
14 changes: 7 additions & 7 deletions xdsl/irdl/declarative_assembly_format_parser.py
@@ -80,19 +80,19 @@
 )
 from xdsl.parser import BaseParser, ParserState
 from xdsl.utils.lexer import Input
-from xdsl.utils.mlir_lexer import Kind, Lexer, Token
+from xdsl.utils.mlir_lexer import MLIRLexer, MLIRToken, MLIRTokenKind


 @dataclass
-class FormatLexer(Lexer):
+class FormatLexer(MLIRLexer):
     """
     A lexer for the declarative assembly format.
     The differences with the MLIR lexer are the following:
     * It can parse '`' or '$' as tokens. The token will have the `BARE_IDENT` kind.
     * Bare identifiers may also may contain `-`.
     """

-    def lex(self) -> Token:
+    def lex(self) -> MLIRToken:
         """Lex a token from the input, and returns it."""
         # First, skip whitespaces
         self._consume_whitespace()
@@ -102,13 +102,13 @@ def lex(self) -> Token:

         # Handle end of file
         if current_char is None:
-            return self._form_token(Kind.EOF, start_pos)
+            return self._form_token(MLIRTokenKind.EOF, start_pos)

         # We parse '`', `\\` and '$' as a BARE_IDENT.
         # This is a hack to reuse the MLIR lexer.
         if current_char in ("`", "$", "\\", "^"):
             self._consume_chars()
-            return self._form_token(Kind.BARE_IDENT, start_pos)
+            return self._form_token(MLIRTokenKind.BARE_IDENT, start_pos)
         return super().lex()

     # Authorize `-` in bare identifier
@@ -168,7 +168,7 @@ def parse_format(self) -> FormatProgram:
         unambiguous and refer to all elements exactly once.
         """
         elements: list[FormatDirective] = []
-        while self._current_token.kind != Kind.EOF:
+        while self._current_token.kind != MLIRTokenKind.EOF:
             elements.append(self.parse_format_directive())

         self.add_reserved_attrs_to_directive(elements)
@@ -717,7 +717,7 @@ def parse_keyword_or_punctuation(self) -> FormatDirective:
         if self._current_token.kind.is_punctuation():
             punctuation = self._consume_token().text
             self.parse_characters("`")
-            assert Kind.is_spelling_of_punctuation(punctuation)
+            assert MLIRTokenKind.is_spelling_of_punctuation(punctuation)
             return PunctuationDirective(punctuation)

         # Identifier case
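
Note: as its docstring says, FormatLexer only adds lexing of '`', '$', '\' and '^' as BARE_IDENT tokens on top of MLIRLexer. A sketch of that behaviour, assuming FormatLexer is imported from the module patched above:

    from xdsl.irdl.declarative_assembly_format_parser import FormatLexer
    from xdsl.utils.lexer import Input
    from xdsl.utils.mlir_lexer import MLIRTokenKind

    # A lone backtick is not a valid MLIR token, but FormatLexer lexes it
    # as a BARE_IDENT so the format parser can treat it like any keyword.
    lexer = FormatLexer(Input("`,` $lhs", "<format>"))
    first = lexer.lex()
    assert first.kind == MLIRTokenKind.BARE_IDENT
    assert first.text == "`"
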