diff --git a/djenerator/generate_test_data.py b/djenerator/generate_test_data.py index ff262d5..de34a89 100644 --- a/djenerator/generate_test_data.py +++ b/djenerator/generate_test_data.py @@ -12,6 +12,7 @@ from model_reader import is_auto_field from model_reader import is_related from model_reader import is_required +from model_reader import is_reverse_related from model_reader import list_of_fields from model_reader import list_of_models from model_reader import module_import @@ -33,7 +34,10 @@ def field_sample_values(field): """ list_field_values = [] if not is_auto_field(field): - if is_related(field): + if is_reverse_related(field): + # TODO(mostafa-mahmoud): Check if this case needs to be handled. + pass + elif is_related(field): model = field.rel.to list_field_values = list(model.objects.all()) if 'ManyToMany' in relation_type(field) and list_field_values: @@ -50,7 +54,9 @@ def field_sample_values(field): found = True input_method = model.TestData.__dict__[field.name] if isinstance(input_method, str): - input_file = open('TestTemplates/' + input_method, 'r') + app_name = field.model._meta.app_label + path = '%s/TestTemplates/%s' % (app_name, input_method) + input_file = open(path, 'r') list_field_values = [word[:-1] for word in input_file] elif (isinstance(input_method, list) or isinstance(input_method, tuple)): @@ -59,8 +65,9 @@ def field_sample_values(field): if inspect.isfunction(input_method): list_field_values = input_method() if not found: - path = 'TestTemplates/sample__%s__%s' % (field.model.__name__, - field.name) + app_name = field.model._meta.app_label + path = '%s/TestTemplates/sample__%s__%s' % (app_name, + field.model.__name__, field.name) input_file = open(path, 'r') list_field_values = [word[:-1] for word in input_file] # TODO(mostafa-mahmoud) : Generate totally randomized @@ -158,7 +165,8 @@ def generate_model(model, size, shuffle=None): and list of field that's not computed. """ unique_fields = [(field.name,) for field in list_of_fields(model) - if field.unique and not is_auto_field(field)] + if (hasattr(field, 'unique') and field.unique + and not is_auto_field(field))] unique_together = [] if hasattr(model._meta, 'unique_together'): unique_together = list(model._meta.unique_together) @@ -191,9 +199,9 @@ def create_model(model, val): A model with the values given. """ vals_dictionary = dict(val) - have_many_to_many_relation = any([x for x in list_of_fields(model) - if is_related(x) - and 'ManyToMany' in relation_type(x)]) + have_many_to_many_relation = any(x for x in list_of_fields(model) + if (is_related(x) and + 'ManyToMany' in relation_type(x))) if not have_many_to_many_relation: mdl = model(**vals_dictionary) mdl.save() @@ -283,8 +291,8 @@ def recompute(model, field): n = len(list_field_values) for index, mdl in enumerate(models): if ('ManyToMany' in relation_type(field) and - not getattr(mdl, field.name).exists() or - not is_required(field) and not getattr(mdl, field.name)): + not getattr(mdl, field.name).exists() or + not is_required(field) and not getattr(mdl, field.name)): setattr(mdl, field.name, list_field_values[index % n]) mdl.save() diff --git a/djenerator/model_reader.py b/djenerator/model_reader.py index 3ad8800..44dade5 100644 --- a/djenerator/model_reader.py +++ b/djenerator/model_reader.py @@ -54,6 +54,20 @@ def is_related(field): return 'django.db.models.fields.related' in field.__module__ +def is_reverse_related(field): + """ Is a reverse-related field + + Test if a given field is a reverse related field. + + Args: + field : A reference to the class of a given field. + + Returns: + A boolean value that is true only if the field is reverse related. + """ + return 'django.db.models.fields.reverse_related' in field.__module__ + + def field_type(field): """ Field Type @@ -138,11 +152,15 @@ def list_of_fields(model): Returns: A list of references to the fields of the given model. """ + fields = list(model._meta._get_fields()) + fields = list(filter(lambda field: not is_reverse_related(field), fields)) + """ if (hasattr(model._meta, '_fields') and hasattr(model._meta, '_many_to_many')): fields = model._meta._fields() + model._meta._many_to_many() else: fields = model._fields + """ # If the inheritance is multi-table inheritence, the fields of # the super class(that should be inherited) will not appear # in fields, and they will be replaced by a OneToOneField to the @@ -152,9 +170,10 @@ def list_of_fields(model): if Model != model.__base__: clone = [fld for fld in fields] for field in clone: - if (is_related(field) and ('OneToOne' in relation_type(field) - or 'ManyToMany' in relation_type(field)) and (field.rel.to in - model.__bases__) and field.rel.to != model): + if (is_related(field) and + ('OneToOne' in relation_type(field) or + 'ManyToMany' in relation_type(field)) and + (field.rel.to in model.__bases__) and field.rel.to != model): fields.remove(field) fields += filter(lambda x: not is_auto_field(x), list_of_fields(field.rel.to)) diff --git a/tests/tests.py b/tests/tests.py index 1e6dc6e..cef003b 100644 --- a/tests/tests.py +++ b/tests/tests.py @@ -40,6 +40,7 @@ from djenerator.model_reader import is_instance_of_django_model from djenerator.model_reader import is_related from djenerator.model_reader import is_required +from djenerator.model_reader import is_reverse_related from djenerator.model_reader import list_of_fields from djenerator.model_reader import list_of_models from djenerator.model_reader import module_import @@ -483,7 +484,8 @@ def test(self): fields = list_of_fields(model.__class__) nodes += 1 for field in fields: - if not is_auto_field(field): + if (not is_auto_field(field) and + not is_reverse_related(field)): val = getattr(model, field.name) if is_related(field): if 'ManyToMany' in relation_type(field): @@ -496,10 +498,19 @@ def test(self): self.assertTrue(val in r) edges += 1 else: - sample_values = map(lambda x: str(x), - field_sample_values(field)) - val = str(val) - self.assertTrue(val in sample_values) + if (field.__class__.__name__ == 'DecimalField' or + field.__class__.__name__ == 'FloatField'): + sample_values = map(float, + field_sample_values(field)) + val = float(val) + self.assertTrue(any(abs(val - fld_value) < 1e-5 + for fld_value in + sample_values)) + else: + sample_values = map(str, + field_sample_values(field)) + val = str(val) + self.assertTrue(val in sample_values) if model.__class__ == TestModelFields: pr = (model.fieldC, model.fieldA) self.assertFalse(pr in pairs)