Skip to content

Commit

Permalink
Fixed reverse-fields that was added after Django 1.10
Browse files Browse the repository at this point in the history
  • Loading branch information
mostafa-mahmoud committed Nov 4, 2016
1 parent d60fa4d commit 61c1d26
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 18 deletions.
28 changes: 18 additions & 10 deletions djenerator/generate_test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from model_reader import is_auto_field
from model_reader import is_related
from model_reader import is_required
from model_reader import is_reverse_related
from model_reader import list_of_fields
from model_reader import list_of_models
from model_reader import module_import
Expand All @@ -33,7 +34,10 @@ def field_sample_values(field):
"""
list_field_values = []
if not is_auto_field(field):
if is_related(field):
if is_reverse_related(field):
# TODO(mostafa-mahmoud): Check if this case needs to be handled.
pass
elif is_related(field):
model = field.rel.to
list_field_values = list(model.objects.all())
if 'ManyToMany' in relation_type(field) and list_field_values:
Expand All @@ -50,7 +54,9 @@ def field_sample_values(field):
found = True
input_method = model.TestData.__dict__[field.name]
if isinstance(input_method, str):
input_file = open('TestTemplates/' + input_method, 'r')
app_name = field.model._meta.app_label
path = '%s/TestTemplates/%s' % (app_name, input_method)
input_file = open(path, 'r')
list_field_values = [word[:-1] for word in input_file]
elif (isinstance(input_method, list)
or isinstance(input_method, tuple)):
Expand All @@ -59,8 +65,9 @@ def field_sample_values(field):
if inspect.isfunction(input_method):
list_field_values = input_method()
if not found:
path = 'TestTemplates/sample__%s__%s' % (field.model.__name__,
field.name)
app_name = field.model._meta.app_label
path = '%s/TestTemplates/sample__%s__%s' % (app_name,
field.model.__name__, field.name)
input_file = open(path, 'r')
list_field_values = [word[:-1] for word in input_file]
# TODO(mostafa-mahmoud) : Generate totally randomized
Expand Down Expand Up @@ -158,7 +165,8 @@ def generate_model(model, size, shuffle=None):
and list of field that's not computed.
"""
unique_fields = [(field.name,) for field in list_of_fields(model)
if field.unique and not is_auto_field(field)]
if (hasattr(field, 'unique') and field.unique
and not is_auto_field(field))]
unique_together = []
if hasattr(model._meta, 'unique_together'):
unique_together = list(model._meta.unique_together)
Expand Down Expand Up @@ -191,9 +199,9 @@ def create_model(model, val):
A model with the values given.
"""
vals_dictionary = dict(val)
have_many_to_many_relation = any([x for x in list_of_fields(model)
if is_related(x)
and 'ManyToMany' in relation_type(x)])
have_many_to_many_relation = any(x for x in list_of_fields(model)
if (is_related(x) and
'ManyToMany' in relation_type(x)))
if not have_many_to_many_relation:
mdl = model(**vals_dictionary)
mdl.save()
Expand Down Expand Up @@ -283,8 +291,8 @@ def recompute(model, field):
n = len(list_field_values)
for index, mdl in enumerate(models):
if ('ManyToMany' in relation_type(field) and
not getattr(mdl, field.name).exists() or
not is_required(field) and not getattr(mdl, field.name)):
not getattr(mdl, field.name).exists() or
not is_required(field) and not getattr(mdl, field.name)):
setattr(mdl, field.name, list_field_values[index % n])
mdl.save()

Expand Down
25 changes: 22 additions & 3 deletions djenerator/model_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,20 @@ def is_related(field):
return 'django.db.models.fields.related' in field.__module__


def is_reverse_related(field):
""" Is a reverse-related field
Test if a given field is a reverse related field.
Args:
field : A reference to the class of a given field.
Returns:
A boolean value that is true only if the field is reverse related.
"""
return 'django.db.models.fields.reverse_related' in field.__module__


def field_type(field):
""" Field Type
Expand Down Expand Up @@ -138,11 +152,15 @@ def list_of_fields(model):
Returns:
A list of references to the fields of the given model.
"""
fields = list(model._meta._get_fields())
fields = list(filter(lambda field: not is_reverse_related(field), fields))
"""
if (hasattr(model._meta, '_fields')
and hasattr(model._meta, '_many_to_many')):
fields = model._meta._fields() + model._meta._many_to_many()
else:
fields = model._fields
"""
# If the inheritance is multi-table inheritence, the fields of
# the super class(that should be inherited) will not appear
# in fields, and they will be replaced by a OneToOneField to the
Expand All @@ -152,9 +170,10 @@ def list_of_fields(model):
if Model != model.__base__:
clone = [fld for fld in fields]
for field in clone:
if (is_related(field) and ('OneToOne' in relation_type(field)
or 'ManyToMany' in relation_type(field)) and (field.rel.to in
model.__bases__) and field.rel.to != model):
if (is_related(field) and
('OneToOne' in relation_type(field) or
'ManyToMany' in relation_type(field)) and
(field.rel.to in model.__bases__) and field.rel.to != model):
fields.remove(field)
fields += filter(lambda x: not is_auto_field(x),
list_of_fields(field.rel.to))
Expand Down
21 changes: 16 additions & 5 deletions tests/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
from djenerator.model_reader import is_instance_of_django_model
from djenerator.model_reader import is_related
from djenerator.model_reader import is_required
from djenerator.model_reader import is_reverse_related
from djenerator.model_reader import list_of_fields
from djenerator.model_reader import list_of_models
from djenerator.model_reader import module_import
Expand Down Expand Up @@ -483,7 +484,8 @@ def test(self):
fields = list_of_fields(model.__class__)
nodes += 1
for field in fields:
if not is_auto_field(field):
if (not is_auto_field(field) and
not is_reverse_related(field)):
val = getattr(model, field.name)
if is_related(field):
if 'ManyToMany' in relation_type(field):
Expand All @@ -496,10 +498,19 @@ def test(self):
self.assertTrue(val in r)
edges += 1
else:
sample_values = map(lambda x: str(x),
field_sample_values(field))
val = str(val)
self.assertTrue(val in sample_values)
if (field.__class__.__name__ == 'DecimalField' or
field.__class__.__name__ == 'FloatField'):
sample_values = map(float,
field_sample_values(field))
val = float(val)
self.assertTrue(any(abs(val - fld_value) < 1e-5
for fld_value in
sample_values))
else:
sample_values = map(str,
field_sample_values(field))
val = str(val)
self.assertTrue(val in sample_values)
if model.__class__ == TestModelFields:
pr = (model.fieldC, model.fieldA)
self.assertFalse(pr in pairs)
Expand Down

0 comments on commit 61c1d26

Please sign in to comment.