Skip to content

Commit

Permalink
refactor code generation tools , include --check command
Browse files Browse the repository at this point in the history
in particular it looks like CI was not picking up on the
"git diff" oriented commands, which were failing to run due
to pathing issues.  As we were setting cwd for black/zimports
relative to sqlalchemy library, and tox installs it in
the venv, black/zimports would fail to run from tox, and
since these are subprocess.run we didn't pick up the
failure.

This overall locks down how zimports/black are run
so that we are definitely from the source root, by using
the location of tools/ to determine the root.

Fixes: sqlalchemy#8892
Change-Id: I7c54b747edd5a80e0c699b8456febf66d8b62375
  • Loading branch information
zzzeek committed Jan 18, 2023
1 parent f91a25c commit cd96ffe
Show file tree
Hide file tree
Showing 10 changed files with 307 additions and 163 deletions.
12 changes: 12 additions & 0 deletions lib/sqlalchemy/orm/scoping.py
Original file line number Diff line number Diff line change
Expand Up @@ -877,6 +877,7 @@ def get(
with_for_update: Optional[ForUpdateArg] = None,
identity_token: Optional[Any] = None,
execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT,
bind_arguments: Optional[_BindArguments] = None,
) -> Optional[_O]:
r"""Return an instance based on the given primary key identifier,
or ``None`` if not found.
Expand Down Expand Up @@ -975,6 +976,13 @@ def get(
:ref:`orm_queryguide_execution_options` - ORM-specific execution
options
:param bind_arguments: dictionary of additional arguments to determine
the bind. May include "mapper", "bind", or other custom arguments.
Contents of this dictionary are passed to the
:meth:`.Session.get_bind` method.
.. versionadded: 2.0.0rc1
:return: The object instance, or ``None``.
Expand All @@ -988,15 +996,18 @@ def get(
with_for_update=with_for_update,
identity_token=identity_token,
execution_options=execution_options,
bind_arguments=bind_arguments,
)

def get_bind(
self,
mapper: Optional[_EntityBindKey[_O]] = None,
*,
clause: Optional[ClauseElement] = None,
bind: Optional[_SessionBind] = None,
_sa_skip_events: Optional[bool] = None,
_sa_skip_for_implicit_returning: bool = False,
**kw: Any,
) -> Union[Engine, Connection]:
r"""Return a "bind" to which this :class:`.Session` is bound.
Expand Down Expand Up @@ -1082,6 +1093,7 @@ def get_bind(
bind=bind,
_sa_skip_events=_sa_skip_events,
_sa_skip_for_implicit_returning=_sa_skip_for_implicit_returning,
**kw,
)

def is_modified(
Expand Down
44 changes: 0 additions & 44 deletions lib/sqlalchemy/util/langhelpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
import inspect
import itertools
import operator
import os
import re
import sys
import textwrap
Expand All @@ -34,7 +33,6 @@
from typing import Iterator
from typing import List
from typing import Mapping
from typing import no_type_check
from typing import NoReturn
from typing import Optional
from typing import overload
Expand Down Expand Up @@ -2180,45 +2178,3 @@ def has_compiled_ext(raise_=False):
)
else:
return False


@no_type_check
def console_scripts(
path: str, options: dict, ignore_output: bool = False
) -> None:

import subprocess
import shlex
from pathlib import Path

is_posix = os.name == "posix"

entrypoint_name = options["entrypoint"]

for entry in compat.importlib_metadata_get("console_scripts"):
if entry.name == entrypoint_name:
impl = entry
break
else:
raise Exception(
f"Could not find entrypoint console_scripts.{entrypoint_name}"
)
cmdline_options_str = options.get("options", "")
cmdline_options_list = shlex.split(cmdline_options_str, posix=is_posix) + [
path
]

kw = {}
if ignore_output:
kw["stdout"] = kw["stderr"] = subprocess.DEVNULL

subprocess.run(
[
sys.executable,
"-c",
"import %s; %s.%s()" % (impl.module, impl.module, impl.attr),
]
+ cmdline_options_list,
cwd=Path(__file__).parent.parent,
**kw,
)
198 changes: 198 additions & 0 deletions lib/sqlalchemy/util/tool_support.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,198 @@
# util/tool_support.py
# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: allow-untyped-defs, allow-untyped-calls
"""support routines for the helpers in tools/.
These aren't imported by the enclosing util package as the are not
needed for normal library use.
"""
from __future__ import annotations

from argparse import ArgumentParser
from argparse import Namespace
import contextlib
import difflib
import os
from pathlib import Path
import shlex
import shutil
import subprocess
import sys
from typing import Any
from typing import Dict
from typing import Iterator
from typing import Optional

from . import compat


class code_writer_cmd:
parser: ArgumentParser
args: Namespace
suppress_output: bool
diffs_detected: bool
source_root: Path
pyproject_toml_path: Path

def __init__(self, tool_script: str):
self.source_root = Path(tool_script).parent.parent
self.pyproject_toml_path = self.source_root / Path("pyproject.toml")
assert self.pyproject_toml_path.exists()

self.parser = ArgumentParser()
self.parser.add_argument(
"--stdout",
action="store_true",
help="Write to stdout instead of saving to file",
)
self.parser.add_argument(
"-c",
"--check",
help="Don't write the files back, just return the "
"status. Return code 0 means nothing would change. "
"Return code 1 means some files would be reformatted",
action="store_true",
)

def run_zimports(self, tempfile: str) -> None:
self._run_console_script(
str(tempfile),
{
"entrypoint": "zimports",
"options": f"--toml-config {self.pyproject_toml_path}",
},
)

def run_black(self, tempfile: str) -> None:
self._run_console_script(
str(tempfile),
{
"entrypoint": "black",
"options": f"--config {self.pyproject_toml_path}",
},
)

def _run_console_script(self, path: str, options: Dict[str, Any]) -> None:
"""Run a Python console application from within the process.
Used for black, zimports
"""

is_posix = os.name == "posix"

entrypoint_name = options["entrypoint"]

for entry in compat.importlib_metadata_get("console_scripts"):
if entry.name == entrypoint_name:
impl = entry
break
else:
raise Exception(
f"Could not find entrypoint console_scripts.{entrypoint_name}"
)
cmdline_options_str = options.get("options", "")
cmdline_options_list = shlex.split(
cmdline_options_str, posix=is_posix
) + [path]

kw: Dict[str, Any] = {}
if self.suppress_output:
kw["stdout"] = kw["stderr"] = subprocess.DEVNULL

subprocess.run(
[
sys.executable,
"-c",
"import %s; %s.%s()" % (impl.module, impl.module, impl.attr),
]
+ cmdline_options_list,
cwd=str(self.source_root),
**kw,
)

def write_status(self, *text: str) -> None:
if not self.suppress_output:
sys.stderr.write(" ".join(text))

def write_output_file_from_text(
self, text: str, destination_path: str
) -> None:
if self.args.check:
self._run_diff(destination_path, source=text)
elif self.args.stdout:
print(text)
else:
self.write_status(f"Writing {destination_path}...")
Path(destination_path).write_text(text)
self.write_status("done\n")

def write_output_file_from_tempfile(
self, tempfile: str, destination_path: str
) -> None:
if self.args.check:
self._run_diff(destination_path, source_file=tempfile)
os.unlink(tempfile)
elif self.args.stdout:
with open(tempfile) as tf:
print(tf.read())
os.unlink(tempfile)
else:
self.write_status(f"Writing {destination_path}...")
shutil.move(tempfile, destination_path)
self.write_status("done\n")

def _run_diff(
self,
destination_path: str,
*,
source: Optional[str] = None,
source_file: Optional[str] = None,
) -> None:
if source_file:
with open(source_file) as tf:
source_lines = list(tf)
elif source is not None:
source_lines = source.splitlines(keepends=True)
else:
assert False, "source or source_file is required"

with open(destination_path) as dp:
d = difflib.unified_diff(
list(dp),
source_lines,
fromfile=destination_path,
tofile="<proposed changes>",
n=3,
lineterm="\n",
)
d_as_list = list(d)
if d_as_list:
self.diffs_detected = True
print("".join(d_as_list))

@contextlib.contextmanager
def add_arguments(self) -> Iterator[ArgumentParser]:
yield self.parser

@contextlib.contextmanager
def run_program(self) -> Iterator[None]:
self.args = self.parser.parse_args()
if self.args.check:
self.diffs_detected = False
self.suppress_output = True
elif self.args.stdout:
self.suppress_output = True
else:
self.suppress_output = False
yield

if self.args.check and self.diffs_detected:
sys.exit(1)
else:
sys.exit(0)
1 change: 1 addition & 0 deletions test/orm/test_scoping.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,7 @@ def test_methods_etc(self):
with_for_update=None,
identity_token=None,
execution_options=util.EMPTY_DICT,
bind_arguments=None,
),
],
)
Expand Down
3 changes: 3 additions & 0 deletions tools/format_docs_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,10 @@
that it extracts from the documentation.
.. versionadded:: 2.0
"""
# mypy: ignore-errors

from argparse import ArgumentParser
from argparse import RawDescriptionHelpFormatter
from collections.abc import Iterator
Expand Down
Loading

0 comments on commit cd96ffe

Please sign in to comment.