Skip to content

Commit

Permalink
Major modifications! First of all, significant speedups thanks to int…
Browse files Browse the repository at this point in the history
…ernal function restructuring and less unuseful queries.

Moreover, no validation is done on Nodes on __init__ (was very slow).
Moreover, much better management of reloads in Nodes. No more uuid=... as parameter, but you have to pass the dbnode (in this way, unuseful queries are avoided).
Moreover, all parameters of __init__ are automatically passed to corresponding set_...() methods, so no need anymore to redefine __init__ in subclasses (a few static class attributes also allow some fine tuning).
Another modification is that now the calc.set_resources method does not accept anymore **kwargs, but requires a dictionary.
Everything updated accordingly, tests run correctly and pw submits and parses without problems.
  • Loading branch information
giovannipizzi committed Mar 19, 2014
1 parent ca35c8d commit f1df00d
Show file tree
Hide file tree
Showing 34 changed files with 511 additions and 507 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
*~
*.project
*.pydevproject
.settings
.DS_Store
*/.DS_Store
*/*/.DS_Store
Expand Down
26 changes: 12 additions & 14 deletions aiida/common/folders.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,6 @@

from aiida.common.utils import get_repository_folder

_sandbox_folder = os.path.realpath(os.path.join(get_repository_folder(),'sandbox'))
_perm_repository = os.path.realpath(os.path.join(get_repository_folder(),'repository'))

_valid_sections = ['node', 'workflow']

class Folder(object):
Expand All @@ -22,17 +19,17 @@ class Folder(object):
with os.pardir?)
"""
def __init__(self, abspath, folder_limit=None):
abspath = os.path.realpath(abspath)
abspath = os.path.abspath(abspath)
if folder_limit is None:
folder_limit = abspath
folder_limit = os.path.realpath(folder_limit)
folder_limit = os.path.abspath(folder_limit)

# check that it is a subfolder
if not os.path.commonprefix([abspath,folder_limit]) == folder_limit:
raise ValueError("The absolute path for this folder is not within the folder_limit. "
"abspath={}, folder_limit={}.".format(abspath, folder_limit))

self._abspath = os.path.realpath(abspath)
self._abspath = abspath
self._folder_limit = folder_limit


Expand All @@ -51,7 +48,7 @@ def get_subfolder(self, subfolder, create=False, reset_limit=False):
:Returns: a Folder object pointing to the subfolder.
"""
dest_abs_dir = os.path.realpath(os.path.join(
dest_abs_dir = os.path.abspath(os.path.join(
self.abspath,unicode(subfolder)))


Expand Down Expand Up @@ -362,10 +359,11 @@ def __init__(self):
sandbox.
"""
# First check if the sandbox folder already exists
if not os.path.exists(_sandbox_folder):
os.makedirs(_sandbox_folder)
sandbox = get_repository_folder('sandbox')
if not os.path.exists(sandbox):
os.makedirs(sandbox)

abspath = tempfile.mkdtemp(dir=_sandbox_folder)
abspath = tempfile.mkdtemp(dir=sandbox)
super(SandboxFolder, self).__init__(abspath=abspath)

def __enter__(self):
Expand All @@ -390,7 +388,7 @@ def __init__(self, section, uuid, subfolder=os.curdir):
Initializes the object by pointing it to a folder in the repository.
Pass the uuid as a string.
"""
"""
if section not in _valid_sections:
retstr = ("Repository section '{}' not allowed. "
"Valid sections are: {}".format(section, ",".join(_valid_sections)))
Expand All @@ -405,10 +403,10 @@ def __init__(self, section, uuid, subfolder=os.curdir):
# Note that a similar sharding should probably has to be done
# independently for calculations sent to remote computers in the
# execmanager.
entity_dir=os.path.realpath(os.path.join(
_perm_repository, unicode(section),
entity_dir=os.path.abspath(os.path.join(
get_repository_folder('repository'), unicode(section),
unicode(uuid)[:2], unicode(uuid)[2:4], unicode(uuid)[4:]))
dest = os.path.realpath(os.path.join(entity_dir,unicode(subfolder)))
dest = os.path.abspath(os.path.join(entity_dir,unicode(subfolder)))

# Internal variable of this class
self._subfolder=subfolder
Expand Down
26 changes: 24 additions & 2 deletions aiida/common/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,20 @@

CONFIG_FNAME = 'config.json'

class classproperty(object):
"""
A class that, when used as a decorator, works as if the
two decorators @property and @classmethod where applied together
(i.e., the object works as a property, both for the Class and for any
of its instance; and is called with the class cls rather than with the
instance as its first argument).
"""
def __init__(self, getter):
self.getter= getter
def __get__(self, instance, owner):
return self.getter(owner)


def backup_config():
import shutil
aiida_dir = os.path.expanduser("~/.aiida")
Expand Down Expand Up @@ -57,7 +71,7 @@ def get_new_uuid():
the_uuid = uuid.uuid4()
return force_unicode(the_uuid)

def get_repository_folder():
def get_repository_folder(subfolder=None):
"""
Return the top folder of the local repository.
"""
Expand All @@ -68,7 +82,15 @@ def get_repository_folder():
except ImportError:
raise ConfigurationError(
"The LOCAL_REPOSITORY variable is not set correctly.")
return os.path.realpath(LOCAL_REPOSITORY)
if subfolder is None:
return os.path.abspath(LOCAL_REPOSITORY)
elif subfolder == "sandbox":
return os.path.abspath(os.path.join(LOCAL_REPOSITORY,'sandbox'))
elif subfolder == "repository":
return os.path.abspath(os.path.join(LOCAL_REPOSITORY,'repository'))
else:
raise ValueError("Invalid 'subfolder' passed to "
"get_repository_folder: {}".format(subfolder))

def escape_for_bash(str_to_escape):
"""
Expand Down
34 changes: 2 additions & 32 deletions aiida/djsite/db/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ def get_aiida_class(self):
"will use base Node class".format(self.type,self.pk))
PluginClass = Node

return PluginClass(uuid=self.uuid)
return PluginClass(dbnode=self)

def get_simple_name(self, invalid_result=None):
"""
Expand All @@ -150,29 +150,6 @@ def get_simple_name(self, invalid_result=None):
else:
thistype = thistype[:-1] # Strip final dot
return thistype.rpartition('.')[2]

def increment_version_number(self):
"""
This function increments the version number in the DB.
This should be called every time you need to increment the version
(e.g. on adding a extra or attribute).
:note: Do not manually increment the version number, because if
two different threads are adding/changing an attribute concurrently,
the version number would be incremented only once.
"""
from django.db.models import F

# I increment the node number using a filter
# (this should be the right way of doing it;
# dbnode.nodeversion = F('nodeversion') + 1
# will do weird stuff, returning Django Objects instead of numbers,
# and incrementing at every save; moreover in this way I should do
# the right thing for concurrent writings
# I use self._dbnode because this will not do a query to
# update the node; here I only need to get its pk
self.nodeversion = F('nodeversion') + 1
self.save()

@python_2_unicode_compatible
def __str__(self):
Expand Down Expand Up @@ -348,7 +325,7 @@ def get_all_values_for_node(cls, dbnode):


@classmethod
def set_value_for_node(cls, dbnode, key, value, incrementversion=True):
def set_value_for_node(cls, dbnode, key, value):
"""
This is the raw-level method that accesses the DB. No checks are done
to prevent the user from (re)setting a valid key.
Expand All @@ -360,10 +337,6 @@ def set_value_for_node(cls, dbnode, key, value, incrementversion=True):
:param dbnode: the dbnode for which the attribute should be stored
:param key: the key of the attribute to store
:param value: the value of the attribute to store
:param incrementversion : If incrementversion
is True (default), each attribute set will
udpate the version. This can be set to False during the store() so
that the version does not get increased for each attribute.
:raise ValueError: if the key contains the separator symbol used
internally to unpack dictionaries and lists (defined in cls._sep).
Expand All @@ -375,8 +348,6 @@ def set_value_for_node(cls, dbnode, key, value, incrementversion=True):
try:
sid = transaction.savepoint()

if incrementversion:
dbnode.increment_version_number()
attr, _ = cls.objects.get_or_create(dbnode=dbnode,
key=key)
attr.setvalue(value,with_transaction=False)
Expand All @@ -391,7 +362,6 @@ def del_value_for_node(cls, dbnode, key):
:raise AttributeError: if no key is found for the given dbnode
"""
dbnode.increment_version_number()
try:
# Call the delvalue method, that takes care of recursively deleting
# the subattributes, if this is a list or dictionary.
Expand Down
52 changes: 44 additions & 8 deletions aiida/djsite/db/subtests/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ def test_with_subclasses(self):
a2.set_extra(extra_name, True)
a3 = Data().store()
a3.set_extra(extra_name, True)
a4 = ParameterData({'a':'b'}).store()
a4 = ParameterData(dict={'a':'b'}).store()
a4.set_extra(extra_name, True)
a5 = Node().store()
a5.set_extra(extra_name, True)
Expand Down Expand Up @@ -709,7 +709,7 @@ def test_attr_with_reload(self):

a.store()

b = Node(uuid=a.uuid)
b = Node.get_subclass_from_uuid(a.uuid)
self.assertIsNone(a.get_attr('none'))
self.assertEquals(self.boolval, b.get_attr('bool'))
self.assertEquals(self.intval, b.get_attr('integer'))
Expand All @@ -718,6 +718,17 @@ def test_attr_with_reload(self):
self.assertEquals(self.dictval, b.get_attr('dict'))
self.assertEquals(self.listval, b.get_attr('list'))

# Reload directly
b = Node(dbnode=a.dbnode)
self.assertIsNone(a.get_attr('none'))
self.assertEquals(self.boolval, b.get_attr('bool'))
self.assertEquals(self.intval, b.get_attr('integer'))
self.assertEquals(self.floatval, b.get_attr('float'))
self.assertEquals(self.stringval, b.get_attr('string'))
self.assertEquals(self.dictval, b.get_attr('dict'))
self.assertEquals(self.listval, b.get_attr('list'))


with self.assertRaises(ModificationNotAllowed):
a.set_attr('i',12)

Expand Down Expand Up @@ -1173,7 +1184,7 @@ def test_valid_links(self):
# I create some objects
d1 = Data().store()
with tempfile.NamedTemporaryFile() as f:
d2 = SinglefileData(f.name).store()
d2 = SinglefileData(file=f.name).store()

code = Code(remote_computer_exec=(self.computer,'/bin/true')).store()

Expand Down Expand Up @@ -1407,7 +1418,7 @@ def test_reload_singlefiledata(self):
basename = os.path.split(filename)[1]
f.write(file_content)
f.flush()
a = SinglefileData(filename)
a = SinglefileData(file=filename)

the_uuid = a.uuid

Expand Down Expand Up @@ -2079,7 +2090,23 @@ def test_reload(self):

a.store()

b = StructureData(uuid=a.uuid)
b = StructureData(dbnode=a.dbnode)

for i in range(3):
for j in range(3):
self.assertAlmostEqual(cell[i][j], b.cell[i][j])

self.assertEqual(b.pbc, (False,True,True))
self.assertEqual(len(b.sites), 2)
self.assertEqual(b.kinds[0].symbols[0], 'Ba')
self.assertEqual(b.kinds[1].symbols[0], 'Ti')
for i in range(3):
self.assertAlmostEqual(b.sites[0].position[i], 0.)
for i in range(3):
self.assertAlmostEqual(b.sites[1].position[i], 1.)

# Fully reload from UUID
b = StructureData.get_subclass_from_uuid(a.uuid)

for i in range(3):
for j in range(3):
Expand All @@ -2094,6 +2121,7 @@ def test_reload(self):
for i in range(3):
self.assertAlmostEqual(b.sites[1].position[i], 1.)


def test_copy(self):
"""
Start from a StructureData object, copy it and see if it is preserved
Expand Down Expand Up @@ -2338,15 +2366,23 @@ def test_creation(self):
self.assertEquals(second.shape, n.get_shape('second'))


n2 = ArrayData(uuid=n.uuid)

# Same checks, after reloading
n2 = ArrayData(dbnode=n.dbnode)
self.assertEquals(set(['first', 'second']), set(n2.arraynames()))
self.assertAlmostEquals(abs(first-n2.get_array('first')).max(), 0.)
self.assertAlmostEquals(abs(second-n2.get_array('second')).max(), 0.)
self.assertEquals(first.shape, n2.get_shape('first'))
self.assertEquals(second.shape, n2.get_shape('second'))

# Same checks, after reloading with UUID
n2 = ArrayData.get_subclass_from_uuid(n.uuid)
self.assertEquals(set(['first', 'second']), set(n2.arraynames()))
self.assertAlmostEquals(abs(first-n2.get_array('first')).max(), 0.)
self.assertAlmostEquals(abs(second-n2.get_array('second')).max(), 0.)
self.assertEquals(first.shape, n2.get_shape('first'))
self.assertEquals(second.shape, n2.get_shape('second'))


# Check that I cannot modify the node after storing
with self.assertRaises(ModificationNotAllowed):
n.delete_array('first')
Expand Down Expand Up @@ -2513,7 +2549,7 @@ def test_creation(self):

##############################################################
# Again, but after reloading from uuid
n = TrajectoryData(uuid=n.uuid)
n = TrajectoryData.get_subclass_from_uuid(n.uuid)
# Generic checks
self.assertEqual(n.numsites, 3)
self.assertEqual(n.numsteps, 2)
Expand Down
4 changes: 2 additions & 2 deletions aiida/djsite/db/subtests/quantumespressopw.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,9 +70,9 @@ def test_inputs(self):
s.append_atom(position=(0.,0.,0.),symbols=['Ba'])
s.store()

p = ParameterData(input_params).store()
p = ParameterData(dict=input_params).store()

k = ParameterData(k_points).store()
k = ParameterData(dict=k_points).store()

pseudo_dir = os.path.join(aiida.__file__, os.pardir,os.pardir,
'testdata','qepseudos')
Expand Down
11 changes: 10 additions & 1 deletion aiida/djsite/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,15 @@ def get_last_daemon_run(task):
#id = str(type(TaskMeta.objects.all().order_by('-date_done')[0].task_id))
return last_run_at

# Cache for speed-up
_aiida_autouser_cache = None

def get_automatic_user():
global _aiida_autouser_cache

if _aiida_autouser_cache is not None:
return _aiida_autouser_cache

import getpass
username = getpass.getuser()

Expand All @@ -32,7 +40,8 @@ def get_automatic_user():
from aiida.common.exceptions import ConfigurationError

try:
return User.objects.get(username=username)
_aiida_autouser_cache = User.objects.get(username=username)
return _aiida_autouser_cache
except ObjectDoesNotExist:
raise ConfigurationError("No aiida user with username {}".format(
username))
Expand Down
Loading

0 comments on commit f1df00d

Please sign in to comment.