diff --git a/CHANGELOG.md b/CHANGELOG.md index 2047af8..ad8397e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,9 @@ # `edn_format` Changelog +## v0.5.14 (2018/08/22) + +* Fix vector parser to use ImmutableList + ## v0.5.13 (2017/10/08) * Convert requirements from exact to minimum diff --git a/edn_format/__init__.py b/edn_format/__init__.py index c15737e..e1e438f 100644 --- a/edn_format/__init__.py +++ b/edn_format/__init__.py @@ -7,8 +7,10 @@ from .edn_dump import dump as dumps from .exceptions import EDNDecodeError from .immutable_dict import ImmutableDict +from .immutable_list import ImmutableList __all__ = ( + 'ImmutableList', 'EDNDecodeError', 'ImmutableDict', 'Keyword', diff --git a/edn_format/edn_dump.py b/edn_format/edn_dump.py index a52d6c8..0f06153 100644 --- a/edn_format/edn_dump.py +++ b/edn_format/edn_dump.py @@ -11,6 +11,7 @@ import pyrfc3339 from .immutable_dict import ImmutableDict +from .immutable_list import ImmutableList from .edn_lex import Keyword, Symbol from .edn_parse import TaggedElement @@ -81,7 +82,7 @@ def udump(obj, return '{}M'.format(obj) elif isinstance(obj, (Keyword, Symbol)): return unicode(obj) - # CAVEAT EMPTOR! In Python 3 'basestring' is alised to 'str' above. + # CAVEAT LECTOR! In Python 3 'basestring' is alised to 'str' above. # Furthermore, in Python 2 bytes is an instance of 'str'/'basestring' while # in Python 3 it is not. elif isinstance(obj, bytes): @@ -90,7 +91,7 @@ def udump(obj, return unicode_escape(obj) elif isinstance(obj, tuple): return '({})'.format(seq(obj, **kwargs)) - elif isinstance(obj, list): + elif isinstance(obj, (list, ImmutableList)): return '[{}]'.format(seq(obj, **kwargs)) elif isinstance(obj, set) or isinstance(obj, frozenset): if sort_sets: diff --git a/edn_format/edn_parse.py b/edn_format/edn_parse.py index 006ddd7..c2be09d 100644 --- a/edn_format/edn_parse.py +++ b/edn_format/edn_parse.py @@ -11,6 +11,7 @@ from .edn_lex import tokens, lex from .exceptions import EDNDecodeError from .immutable_dict import ImmutableDict +from .immutable_list import ImmutableList if sys.version_info[0] == 3: @@ -57,12 +58,12 @@ def p_term_leaf(p): def p_empty_vector(p): """vector : VECTOR_START VECTOR_END""" - p[0] = [] + p[0] = ImmutableList([]) def p_vector(p): """vector : VECTOR_START expressions VECTOR_END""" - p[0] = p[2] + p[0] = ImmutableList(p[2]) def p_empty_list(p): diff --git a/edn_format/immutable_list.py b/edn_format/immutable_list.py new file mode 100644 index 0000000..5686847 --- /dev/null +++ b/edn_format/immutable_list.py @@ -0,0 +1,51 @@ +# -*- coding: utf-8 -*- +from __future__ import absolute_import, division, print_function, unicode_literals + +import collections +import copy as _copy + + +class ImmutableList(collections.Sequence, collections.Hashable): + def __init__(self, wrapped_list, copy=True): + """Returns an immutable version of the given list. Optionally creates a shallow copy.""" + self._list = _copy.copy(wrapped_list) if copy else wrapped_list + self._hash = None + + def __repr__(self): + return self._list.__repr__() + + def __eq__(self, other): + if isinstance(other, ImmutableList): + return self._list == other._list + else: + return self._list == other + + def _call_wrapped_list_method(self, method, *args): + new_list = _copy.copy(self._list) + getattr(new_list, method)(*args) + return ImmutableList(new_list, copy=False) + + # collection.Sequence methods + # https://docs.python.org/2/library/collections.html#collections-abstract-base-classes + + def __getitem__(self, index): + return self._list[index] + + def __len__(self): + return len(self._list) + + # collection.Hashable methods + # https://docs.python.org/2/library/collections.html#collections-abstract-base-classes + + def __hash__(self): + if self._hash is None: + self._hash = hash(tuple(self._list)) + return self._hash + + # Other list methods https://docs.python.org/2/tutorial/datastructures.html#more-on-lists + + def insert(self, *args): + return self._call_wrapped_list_method("insert", *args) + + def sort(self, *args): + return self._call_wrapped_list_method("sort", *args) diff --git a/setup.py b/setup.py index 77e9864..9a0c2fe 100644 --- a/setup.py +++ b/setup.py @@ -5,7 +5,7 @@ from distutils.core import setup setup(name="edn_format", - version="0.5.13", + version="0.5.14", author="Swaroop C H", author_email="swaroop@swaroopch.com", description="EDN format reader and writer in Python", diff --git a/tests.py b/tests.py index 938a181..8f7fcc1 100644 --- a/tests.py +++ b/tests.py @@ -11,7 +11,7 @@ import pytz from edn_format import edn_lex, edn_parse, \ - loads, dumps, Keyword, Symbol, TaggedElement, ImmutableDict, add_tag, \ + loads, dumps, Keyword, Symbol, TaggedElement, ImmutableDict, ImmutableList, add_tag, \ EDNDecodeError @@ -126,6 +126,9 @@ def test_parser(self): self.check_parse('\\', r'"\\"') self.check_parse(["abc", "123"], '["abc", "123"]') self.check_parse({"key": "value"}, '{"key" "value"}') + self.check_parse(frozenset({ImmutableList([u"ab", u"cd"]), + ImmutableList([u"ef"])}), + '#{["ab", "cd"], ["ef"]}') def check_roundtrip(self, data_input, **kw): self.assertEqual(data_input, loads(dumps(data_input, **kw))) @@ -354,5 +357,18 @@ def test_equality(self): self.assertTrue(Symbol("db/id") == Symbol("db/id")) +class ImmutableListTest(unittest.TestCase): + def test_list(self): + x = ImmutableList([1, 2, 3]) + self.assertTrue(x == [1, 2, 3]) + + self.assertTrue(x.index(1) == 0) + self.assertTrue(x.count(3) == 1) + self.assertTrue(x.insert(0, 0) == [0, 1, 2, 3]) + + y = ImmutableList([3, 1, 4]) + self.assertTrue(y.sort() == [1, 3, 4]) + + if __name__ == "__main__": unittest.main()