Skip to content

Fix atom-structure pickle #26

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Apr 18, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 39 additions & 21 deletions src/diffpy/structure/structure.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,15 @@
"""This module defines class Structure.
"""

import collections
import copy
import copy as copymod
import numpy
import codecs
import six

from diffpy.structure.lattice import Lattice
from diffpy.structure.atom import Atom
from diffpy.structure.utils import _linkAtomAttribute, atomBareSymbol
from diffpy.structure.utils import isiterable

# ----------------------------------------------------------------------------

Expand Down Expand Up @@ -95,9 +95,9 @@ def __init__(self, atoms=None, lattice=None, title=None,


def copy(self):
'''Return a deep copy of this Structure object.
'''Return a copy of this Structure object.
'''
return copy.copy(self)
return copymod.copy(self)


def __copy__(self, target=None):
Expand All @@ -116,7 +116,7 @@ def __copy__(self, target=None):
# copy attributes as appropriate:
target.title = self.title
target.lattice = Lattice(self.lattice)
target.pdffit = copy.deepcopy(self.pdffit)
target.pdffit = copymod.deepcopy(self.pdffit)
# copy all atoms to the target
target[:] = self
return target
Expand Down Expand Up @@ -332,25 +332,43 @@ def insert(self, idx, a, copy=True):

No return value.
"""
adup = copy and Atom(a) or a
adup = copy and copymod.copy(a) or a
adup.lattice = self.lattice
super(Structure, self).insert(idx, adup)
return


def extend(self, atoms, copy=True):
"""Extend Structure by appending copies from a list of atoms.
def extend(self, atoms, copy=None):
"""Extend Structure with an iterable of atoms.

atoms -- list of Atom instances
copy -- flag for extending with copies of Atom instances.
When False extend with atoms and update their lattice
attributes.
Update the `lattice` attribute of all added atoms.

No return value.
Parameters
----------
atoms : iterable
The `Atom` objects to be appended to this Structure.
copy : bool, optional
Flag for adding copies of Atom objects.
Make copies when `True`, append `atoms` unchanged when ``False``.
The default behavior is to make copies when `atoms` are of
`Structure` type or if new atoms introduce repeated objects.
"""
adups = map(Atom, atoms) if copy else atoms
adups = (copymod.copy(a) for a in atoms)
if copy is None:
if isinstance(atoms, Structure):
newatoms = adups
else:
memo = set(id(a) for a in self)
nextatom = lambda a: (a if id(a) not in memo
else copymod.copy(a))
mark = lambda a: (memo.add(id(a)), a)[-1]
newatoms = (mark(nextatom(a)) for a in atoms)
elif copy:
newatoms = adups
else:
newatoms = atoms
setlat = lambda a: (setattr(a, 'lattice', self.lattice), a)[-1]
super(Structure, self).extend(setlat(a) for a in adups)
super(Structure, self).extend(setlat(a) for a in newatoms)
return


Expand Down Expand Up @@ -388,7 +406,7 @@ def __getitem__(self, idx):
# check if there is any string label that should be resolved
scalarstringlabel = isinstance(idx, six.string_types)
hasstringlabel = scalarstringlabel or (
isinstance(idx, collections.Iterable) and
isiterable(idx) and
any(isinstance(ii, six.string_types) for ii in idx))
# if not, use numpy indexing to resolve idx
if not hasstringlabel:
Expand Down Expand Up @@ -464,7 +482,7 @@ def __add__(self, other):

Return new Structure with a copy of Atom instances.
'''
rv = copy.copy(self)
rv = copymod.copy(self)
rv += other
return rv

Expand All @@ -476,7 +494,7 @@ def __iadd__(self, other):

Return self.
'''
self.extend(other)
self.extend(other, copy=True)
return self


Expand All @@ -489,7 +507,7 @@ def __sub__(self, other):
'''
otherset = set(other)
keepindices = [i for i, a in enumerate(self) if not a in otherset]
rv = copy.copy(self[keepindices])
rv = copymod.copy(self[keepindices])
return rv


Expand All @@ -513,7 +531,7 @@ def __mul__(self, n):

Return new Structure.
'''
rv = copy.copy(self[:0])
rv = copymod.copy(self[:0])
rv += n * self.tolist()
return rv

Expand All @@ -533,7 +551,7 @@ def __imul__(self, n):
if n <= 0:
self[:] = []
else:
self.extend((n - 1) * self.tolist())
self.extend((n - 1) * self.tolist(), copy=True)
return self

# Properties -------------------------------------------------------------
Expand Down
14 changes: 13 additions & 1 deletion src/diffpy/structure/tests/teststructure.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@


import copy
import pickle
import unittest
import numpy

Expand Down Expand Up @@ -240,7 +241,7 @@ def test_extend(self):
self.assertEqual(6, len(stru))
self.assertTrue(all(a.lattice is stru.lattice for a in stru))
self.assertEqual(lst, stru.tolist()[:2])
self.assertNotEqual(stru[-1], cdse[-1])
self.assertFalse(stru[-1] is cdse[-1])
return


Expand Down Expand Up @@ -656,6 +657,17 @@ def test_Bij(self):
self.assertFalse(numpy.any(stru.U != 0.0))
return


def test_pickling(self):
"""Make sure Atom in Structure can be consistently pickled.
"""
stru = self.stru
a = stru[0]
self.assertTrue(a is stru[0])
a1, stru1 = pickle.loads(pickle.dumps((a, stru)))
self.assertTrue(a1 is stru1[0])
return

# End of class TestStructure

# ----------------------------------------------------------------------------
Expand Down
12 changes: 12 additions & 0 deletions src/diffpy/structure/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,20 @@
"""Small shared functions.
"""

import six
import numpy

if six.PY2:
from collections import Iterable as _Iterable
else:
from collections.abc import Iterable as _Iterable


def isiterable(obj):
"""True if argument is iterable."""
rv = isinstance(obj, _Iterable)
return rv


def isfloat(s):
"""True if argument can be converted to float"""
Expand Down