Skip to content

Commit

Permalink
feat: handle all unmanaged values equal
Browse files Browse the repository at this point in the history
  • Loading branch information
15r10nk committed Oct 30, 2024
1 parent 9f06f9a commit 0f53a01
Show file tree
Hide file tree
Showing 9 changed files with 125 additions and 11 deletions.
8 changes: 8 additions & 0 deletions src/inline_snapshot/_adapter/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,3 +57,11 @@ def get_adapter(self, old_value, new_value) -> Adapter:

def assign(self, old_value, old_node, new_value):
raise NotImplementedError

@classmethod
def map(cls, value, map_function):
raise NotImplementedError(cls)


def adapter_map(value, map_function):
return get_adapter_type(value).map(value, map_function)
12 changes: 11 additions & 1 deletion src/inline_snapshot/_adapter/dataclass_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,20 @@
from .._change import Delete
from ..syntax_warnings import InlineSnapshotSyntaxWarning
from .adapter import Adapter
from .adapter import adapter_map
from .adapter import Item


class DataclassAdapter(Adapter):

@classmethod
def map(cls, value, map_function):
new_args, new_kwargs = cls.arguments(value)
return type(value)(
*[adapter_map(arg, map_function) for arg in new_args],
**{k: adapter_map(kwarg, map_function) for k, kwarg in new_kwargs.items()},
)

def items(self, value, node):
assert isinstance(node, ast.Call)
assert not node.args
Expand All @@ -27,7 +36,8 @@ def items(self, value, node):
if kw.arg
]

def arguments(self, value):
@classmethod
def arguments(cls, value):

kwargs = {}

Expand Down
6 changes: 6 additions & 0 deletions src/inline_snapshot/_adapter/dict_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,16 @@
from .._change import DictInsert
from ..syntax_warnings import InlineSnapshotSyntaxWarning
from .adapter import Adapter
from .adapter import adapter_map
from .adapter import Item


class DictAdapter(Adapter):

@classmethod
def map(cls, value, map_function):
return {k: adapter_map(v, map_function) for k, v in value.items()}

def items(self, value, node):
if node is None:
return [Item(value=value, node=None) for value in value.values()]
Expand Down
6 changes: 6 additions & 0 deletions src/inline_snapshot/_adapter/sequence_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,19 @@
from .._compare_context import compare_context
from ..syntax_warnings import InlineSnapshotSyntaxWarning
from .adapter import Adapter
from .adapter import adapter_map
from .adapter import Item


class SequenceAdapter(Adapter):
node_type: type
value_type: type

@classmethod
def map(cls, value, map_function):
result = [adapter_map(v, map_function) for v in value]
return cls.value_type(result)

def items(self, value, node):

assert isinstance(node, self.node_type), (node, self)
Expand Down
9 changes: 6 additions & 3 deletions src/inline_snapshot/_adapter/value_adapter.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from inline_snapshot._unmanaged import is_unmanaged
from inline_snapshot._unmanaged import Unmanaged
from inline_snapshot._unmanaged import update_allowed
from inline_snapshot._utils import value_to_token

Expand All @@ -10,12 +10,15 @@

class ValueAdapter(Adapter):

@classmethod
def map(cls, value, map_function):
return map_function(value)

def assign(self, old_value, old_node, new_value):
# generic fallback

# because IsStr() != IsStr()

if is_unmanaged(old_value):
if isinstance(old_value, Unmanaged):
return old_value

if old_node is None:
Expand Down
24 changes: 17 additions & 7 deletions src/inline_snapshot/_inline_snapshot.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

from executing import Source
from inline_snapshot._adapter.adapter import Adapter
from inline_snapshot._adapter.adapter import adapter_map
from inline_snapshot._source_file import SourceFile

from ._adapter import get_adapter_type
Expand All @@ -23,10 +24,11 @@
from ._code_repr import code_repr
from ._compare_context import compare_only
from ._exceptions import UsageError
from ._is import Is
from ._sentinels import undefined
from ._types import Category
from ._types import Snapshot
from ._unmanaged import map_unmanaged
from ._unmanaged import Unmanaged
from ._unmanaged import update_allowed
from ._utils import value_to_token

Expand Down Expand Up @@ -89,6 +91,10 @@ def get_adapter(self, value):
def _re_eval(self, value):

def re_eval(old_value, node, value):
if isinstance(old_value, Unmanaged):
old_value.value = value
return

assert type(old_value) is type(value)

adapter = self.get_adapter(old_value)
Expand All @@ -100,14 +106,16 @@ def re_eval(old_value, node, value):
for old_item, new_item in zip(old_items, new_items):
re_eval(old_item.value, old_item.node, new_item.value)

elif isinstance(old_value, Is):
old_value.value = value.value

else:
if update_allowed(old_value):
assert old_value == value
if not old_value == value:
raise UsageError(
"snapshot value should not change. Use Is(...) for dynamic snapshot parts."
)
else:
assert not update_allowed(value)
assert (
False
), "old_value should already have been converted to Unmanaged"

re_eval(self._old_value, self._ast_node, value)

Expand Down Expand Up @@ -163,6 +171,8 @@ def __getitem__(self, _item):

class UndecidedValue(GenericValue):
def __init__(self, old_value, ast_node, source):

old_value = adapter_map(old_value, map_unmanaged)
self._old_value = old_value
self._new_value = undefined
self._ast_node = ast_node
Expand All @@ -184,7 +194,7 @@ def handle(node, obj):
yield from handle(item.node, item.value)
return

if update_allowed(obj):
if not isinstance(obj, Unmanaged):
new_token = value_to_token(obj)
if self._file._token_of_node(node) != new_token:
new_code = self._file._token_to_code(new_token)
Expand Down
17 changes: 17 additions & 0 deletions src/inline_snapshot/_unmanaged.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,20 @@ def update_allowed(value):

def is_unmanaged(value):
return not update_allowed(value)


class Unmanaged:
def __init__(self, value):
self.value = value

def __eq__(self, other):
assert not isinstance(other, Unmanaged)

return self.value == other


def map_unmanaged(value):
if is_unmanaged(value):
return Unmanaged(value)
else:
return value
32 changes: 32 additions & 0 deletions tests/test_dirty_equals.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,3 +118,35 @@ def equals(self, other):
"""
),
)


def test_dirty_equals_with_changing_args() -> None:

Example(
"""\
from dirty_equals import IsInt
from inline_snapshot import snapshot
def test_number():
for i in range(5):
assert ["a",i] == snapshot(["e",IsInt(gt=i-1,lt=i+1)])
"""
).run_inline(
["--inline-snapshot=fix"],
changed_files=snapshot(
{
"test_something.py": """\
from dirty_equals import IsInt
from inline_snapshot import snapshot
def test_number():
for i in range(5):
assert ["a",i] == snapshot(["a",IsInt(gt=i-1,lt=i+1)])
"""
}
),
)
22 changes: 22 additions & 0 deletions tests/test_is.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
from inline_snapshot._inline_snapshot import snapshot
from inline_snapshot.testing._example import Example


def test_missing_is():

Example(
"""\
from inline_snapshot import snapshot
def test_is():
for i in (1,2):
assert i == snapshot(i)
"""
).run_inline(
raises=snapshot(
"""\
UsageError:
snapshot value should not change. Use Is(...) for dynamic snapshot parts.\
"""
)
)

0 comments on commit 0f53a01

Please sign in to comment.