Skip to content

Commit

Permalink
Merge pull request #114 from 15r10nk/fix_dirty_equals
Browse files Browse the repository at this point in the history
fix: dirty_equals can be compared multiple times
  • Loading branch information
15r10nk authored Sep 24, 2024
2 parents 84c496c + 26ec5e8 commit 2e3a182
Show file tree
Hide file tree
Showing 8 changed files with 397 additions and 59 deletions.
49 changes: 49 additions & 0 deletions changelog.d/20240917_192956_15r10nk-git_fix_dirty_equals.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
<!--
A new scriv changelog fragment.
Uncomment the section that is right (remove the HTML comment wrapper).
-->

<!--
### Removed
- A bullet item for the Removed category.
-->
<!--
### Added
- A bullet item for the Added category.
-->
### Changed

- star-expressions in list or dicts where never valid and cause a warning now.
```
other=[2]
assert [5,2]==snapshot([5,*other])
```

<!--
### Deprecated
- A bullet item for the Deprecated category.
-->
### Fixed

- A snapshot which contains an dirty-equals expression can now be compared multiple times.

``` python
def test_something():
greeting = "hello"
for name in ["alex", "bob"]:
assert (name, greeting) == snapshot((IsString(), "hello"))
```

<!--
### Security

- A bullet item for the Security category.

-->
5 changes: 5 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -141,3 +141,8 @@ exclude = "tests/.*_samples"
[tool.pyright]
venv = "test-3-12"
venvPath = ".nox"


[tool.scriv]
format = "md"
version = "literal: pyproject.toml: project.version"
164 changes: 127 additions & 37 deletions src/inline_snapshot/_inline_snapshot.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import copy
import inspect
import tokenize
import warnings
from collections import defaultdict
from pathlib import Path
from typing import Any
Expand All @@ -15,7 +16,6 @@

from ._align import add_x
from ._align import align
from ._change import apply_all
from ._change import CallArg
from ._change import Change
from ._change import Delete
Expand Down Expand Up @@ -46,6 +46,10 @@ class NotImplementedYet(Exception):
_missing_values = 0


class InlineSnapshotSyntaxWarning(Warning):
pass


class Flags:
"""
fix: the value needs to be changed to pass the tests
Expand Down Expand Up @@ -169,27 +173,51 @@ def _new_code(self):
assert False

def _get_changes(self) -> Iterator[Change]:
# generic fallback
new_token = value_to_token(self._old_value)

if (
self._ast_node is not None
and self._token_of_node(self._ast_node) != new_token
):
flag = "update"
else:
return
def handle(node, obj):
if isinstance(obj, list):
if not isinstance(node, ast.List):
return
for node_value, value in zip(node.elts, obj):
yield from handle(node_value, value)
elif isinstance(obj, tuple):
if not isinstance(node, ast.Tuple):
return
for node_value, value in zip(node.elts, obj):
yield from handle(node_value, value)

elif isinstance(obj, dict):
if not isinstance(node, ast.Dict):
return
for value_key, node_key, node_value in zip(
obj.keys(), node.keys, node.values
):
try:
# this is just a sanity check, dicts should be ordered
node_key = ast.literal_eval(node_key)
except Exception:
pass
else:
assert node_key == value_key

new_code = self._token_to_code(new_token)
yield from handle(node_value, obj[value_key])
else:
if update_allowed(obj):
new_token = value_to_token(obj)
if self._token_of_node(node) != new_token:
new_code = self._token_to_code(new_token)

yield Replace(
node=self._ast_node,
source=self._source,
new_code=new_code,
flag="update",
old_value=self._old_value,
new_value=self._old_value,
)

yield Replace(
node=self._ast_node,
source=self._source,
new_code=new_code,
flag=flag,
old_value=self._old_value,
new_value=self._old_value,
)
if self._source is not None:
yield from handle(self._ast_node, self._old_value)

# functions which determine the type

Expand Down Expand Up @@ -252,8 +280,57 @@ def __eq__(self, other):
if self._old_value is undefined:
_missing_values += 1

def use_valid_old_values(old_value, new_value):

if isinstance(new_value, dirty_equals.DirtyEquals):
assert False

if (
isinstance(new_value, list)
and isinstance(old_value, list)
or isinstance(new_value, tuple)
and isinstance(old_value, tuple)
):
diff = add_x(align(old_value, new_value))
old = iter(old_value)
new = iter(new_value)
result = []
for c in diff:
if c in "mx":
old_value_element = next(old)
new_value_element = next(new)
result.append(
use_valid_old_values(old_value_element, new_value_element)
)
elif c == "i":
result.append(next(new))
elif c == "d":
pass
else:
assert False

return type(new_value)(result)

elif isinstance(new_value, dict) and isinstance(old_value, dict):
result = {}

for key, new_value_element in new_value.items():
if key in old_value:
result[key] = use_valid_old_values(
old_value[key], new_value_element
)
else:
result[key] = new_value_element

return result

if new_value == old_value:
return old_value
else:
return new_value

if self._new_value is undefined:
self._new_value = clone(other)
self._new_value = use_valid_old_values(self._old_value, clone(other))

return self._visible_value() == other

Expand All @@ -274,6 +351,15 @@ def check(old_value, old_node, new_value):
and isinstance(new_value, tuple)
and isinstance(old_value, tuple)
):
for e in old_node.elts:
if isinstance(e, ast.Starred):
warnings.warn_explicit(
"star-expressions are not supported inside snapshots",
filename=self._source.filename,
lineno=e.lineno,
category=InlineSnapshotSyntaxWarning,
)
return
diff = add_x(align(old_value, new_value))
old = zip(old_value, old_node.elts)
new = iter(new_value)
Expand Down Expand Up @@ -313,6 +399,17 @@ def check(old_value, old_node, new_value):
and isinstance(old_value, dict)
and len(old_value) == len(old_node.keys)
):

for key, value in zip(old_node.keys, old_node.values):
if key is None:
warnings.warn_explicit(
"star-expressions are not supported inside snapshots",
filename=self._source.filename,
lineno=value.lineno,
category=InlineSnapshotSyntaxWarning,
)
return

for value, node in zip(old_value.keys(), old_node.keys):
assert node is not None

Expand Down Expand Up @@ -371,14 +468,22 @@ def check(old_value, old_node, new_value):
return

# generic fallback
new_token = value_to_token(new_value)

# because IsStr() != IsStr()
if type(old_value) is type(new_value) and not update_allowed(new_value):
return

if old_node is None:
new_token = []
else:
new_token = value_to_token(new_value)

if not old_value == new_value:
flag = "fix"
elif (
self._ast_node is not None
and self._token_of_node(old_node) != new_token
and update_allowed(old_value)
and self._token_of_node(old_node) != new_token
):
flag = "update"
else:
Expand Down Expand Up @@ -751,18 +856,3 @@ def _changes(self):
else:

yield from self._value._get_changes()

def _change(self):
changes = list(self._changes())
apply_all(
[change for change in changes if change.flag in _update_flags.to_set()]
)

@property
def _flags(self):

if self._value._old_value is undefined:
return {"create"}

changes = self._value._get_changes()
return {change.flag for change in changes}
3 changes: 0 additions & 3 deletions src/inline_snapshot/_rewrite_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,9 +229,6 @@ def files(self) -> Iterable[SourceFile]:
def new_change(self):
return Change(self)

def changes(self):
return list(self._changes)

def num_fixes(self):
changes = set()
for file in self._source_files.values():
Expand Down
2 changes: 1 addition & 1 deletion src/inline_snapshot/extra.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ def test_prints():
=== "ignore stdout"
<!-- inline-snapshot: outcome-passed=1 -->
``` python hl_lines="3 9"
``` python hl_lines="3 9 10"
from inline_snapshot import snapshot
from inline_snapshot.extra import prints
from dirty_equals import IsStr
Expand Down
28 changes: 19 additions & 9 deletions src/inline_snapshot/testing/_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,13 @@

import inline_snapshot._external
import inline_snapshot._external as external
from inline_snapshot import _inline_snapshot
from inline_snapshot._inline_snapshot import Flags
from inline_snapshot._rewrite_code import ChangeRecorder
from inline_snapshot._types import Category
from inline_snapshot._types import Snapshot

from .. import _inline_snapshot
from .._change import apply_all
from .._inline_snapshot import Flags
from .._rewrite_code import ChangeRecorder
from .._types import Category
from .._types import Snapshot


@contextlib.contextmanager
Expand Down Expand Up @@ -160,11 +162,19 @@ def run_inline(
finally:
_inline_snapshot._active = False

snapshot_flags = set()

changes = []
for snapshot in _inline_snapshot.snapshots.values():
snapshot_flags |= snapshot._flags
snapshot._change()
changes += snapshot._changes()

snapshot_flags = {change.flag for change in changes}

apply_all(
[
change
for change in changes
if change.flag in _inline_snapshot._update_flags.to_set()
]
)

if reported_categories is not None:
assert sorted(snapshot_flags) == reported_categories
Expand Down
17 changes: 12 additions & 5 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import inline_snapshot._external
import pytest
from inline_snapshot import _inline_snapshot
from inline_snapshot._change import apply_all
from inline_snapshot._format import format_code
from inline_snapshot._inline_snapshot import Flags
from inline_snapshot._rewrite_code import ChangeRecorder
Expand Down Expand Up @@ -106,13 +107,19 @@ def run(self, *flags):

number_snapshots = len(_inline_snapshot.snapshots)

snapshot_flags = set()

changes = []
for snapshot in _inline_snapshot.snapshots.values():
snapshot_flags |= snapshot._flags
snapshot._change()
changes += snapshot._changes()

snapshot_flags = {change.flag for change in changes}

changes = recorder.changes()
apply_all(
[
change
for change in changes
if change.flag in _inline_snapshot._update_flags.to_set()
]
)

recorder.fix_all()

Expand Down
Loading

0 comments on commit 2e3a182

Please sign in to comment.