diff --git a/CHANGES.rst b/CHANGES.rst index 2a7c85c63..b9c3b72cf 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -1,6 +1,10 @@ Changes ======= +tip (unreleased) +---------------- +- History tracking can be inherited by passing `inherit=True`. (gh-63) + 1.7.0 (2015-12-02) ------------------ - Add ability to list history in admin when the object instance is deleted. (gh-72) diff --git a/docs/advanced.rst b/docs/advanced.rst index 3846b1758..3e9b84ea6 100644 --- a/docs/advanced.rst +++ b/docs/advanced.rst @@ -65,6 +65,35 @@ third-party apps you don't have control over. Here's an example of using register(User) +Allow tracking to be inherited +--------------------------------- + +By default history tracking is only added for the model that is passed +to ``register()`` or has the ``HistoricalRecords`` descriptor. By +passing ``inherit=True`` to either way of registering you can change +that behavior so that any child model inheriting from it will have +historical tracking as well. Be careful though, in cases where a model +can be tracked more than once, ``MultipleRegistrationsError`` will be +raised. + +.. code-block:: python + + from django.contrib.auth.models import User + from django.db import models + from simple_history import register + from simple_history.models import HistoricalRecords + + # register() example + register(User, inherit=True) + + # HistoricalRecords example + class Poll(models.Model): + history = HistoricalRecords(inherit=True) + +Both ``User`` and ``Poll`` in the example above will cause any model +inheriting from them to have historical tracking as well. + + .. recording_user: Recording Which User Changed a Model diff --git a/simple_history/__init__.py b/simple_history/__init__.py index 6c959ea2d..ec4e14cff 100755 --- a/simple_history/__init__.py +++ b/simple_history/__init__.py @@ -21,13 +21,14 @@ def register( `HistoricalManager` instance directly to `model`. """ from . import models - if model._meta.db_table not in models.registered_models: - if records_class is None: - records_class = models.HistoricalRecords - records = records_class(**records_config) - records.manager_name = manager_name - records.table_name = table_name - records.module = app and ("%s.models" % app) or model.__module__ - records.add_extra_methods(model) - records.finalize(model) - models.registered_models[model._meta.db_table] = model + + if records_class is None: + records_class = models.HistoricalRecords + + records = records_class(**records_config) + records.manager_name = manager_name + records.table_name = table_name + records.module = app and ("%s.models" % app) or model.__module__ + records.add_extra_methods(model) + records.finalize(model) + models.registered_models[model._meta.db_table] = model diff --git a/simple_history/exceptions.py b/simple_history/exceptions.py new file mode 100644 index 000000000..273dff400 --- /dev/null +++ b/simple_history/exceptions.py @@ -0,0 +1,7 @@ +""" +django-simple-history exceptions and warnings classes. +""" + +class MultipleRegistrationsError(Exception): + """The model has been registered to have history tracking more than once""" + pass diff --git a/simple_history/models.py b/simple_history/models.py index 9d79e4352..e96e8296e 100644 --- a/simple_history/models.py +++ b/simple_history/models.py @@ -26,6 +26,7 @@ add_introspection_rules( [], ["^simple_history.models.CustomForeignKeyField"]) +from . import exceptions from .manager import HistoryDescriptor registered_models = {} @@ -35,10 +36,11 @@ class HistoricalRecords(object): thread = threading.local() def __init__(self, verbose_name=None, bases=(models.Model,), - user_related_name='+', table_name=None): + user_related_name='+', table_name=None, inherit=False): self.user_set_verbose_name = verbose_name self.user_related_name = user_related_name self.table_name = table_name + self.inherit = inherit try: if isinstance(bases, six.string_types): raise TypeError @@ -49,7 +51,8 @@ def __init__(self, verbose_name=None, bases=(models.Model,), def contribute_to_class(self, cls, name): self.manager_name = name self.module = cls.__module__ - models.signals.class_prepared.connect(self.finalize, sender=cls) + self.cls = cls + models.signals.class_prepared.connect(self.finalize, weak=False) self.add_extra_methods(cls) def add_extra_methods(self, cls): @@ -69,6 +72,19 @@ def save_without_historical_record(self, *args, **kwargs): save_without_historical_record) def finalize(self, sender, **kwargs): + try: + hint_class = self.cls + except AttributeError: # called via `register` + pass + else: + if hint_class is not sender: # set in concrete + if not (self.inherit and issubclass(sender, hint_class)): # set in abstract + return + if hasattr(sender._meta, 'simple_history_manager_attribute'): + raise exceptions.MultipleRegistrationsError('{}.{} registered multiple times for history tracking.'.format( + sender._meta.app_label, + sender._meta.object_name, + )) history_model = self.create_history_model(sender) module = importlib.import_module(self.module) setattr(module, history_model.__name__, history_model) diff --git a/simple_history/tests/models.py b/simple_history/tests/models.py index 6295b82c6..7d810640a 100644 --- a/simple_history/tests/models.py +++ b/simple_history/tests/models.py @@ -277,3 +277,37 @@ class ContactRegister(models.Model): email = models.EmailField(max_length=255, unique=True) register(ContactRegister, table_name='contacts_register_history') + + +############################################################################### +# +# Inheritance examples +# +############################################################################### + +class TrackedAbstractBaseA(models.Model): + history = HistoricalRecords(inherit=True) + + class Meta: + abstract = True + + +class TrackedAbstractBaseB(models.Model): + history_b = HistoricalRecords(inherit=True) + + class Meta: + abstract = True + + +class UntrackedAbstractBase(models.Model): + + class Meta: + abstract = True + + +class TrackedConcreteBase(models.Model): + history = HistoricalRecords(inherit=True) + + +class UntrackedConcreteBase(models.Model): + pass diff --git a/simple_history/tests/tests/test_models.py b/simple_history/tests/tests/test_models.py index 61dde71c1..70d371b34 100644 --- a/simple_history/tests/tests/test_models.py +++ b/simple_history/tests/tests/test_models.py @@ -10,8 +10,8 @@ from django.test import TestCase from django.core.files.base import ContentFile +from simple_history import exceptions, register from simple_history.models import HistoricalRecords, convert_auto_field -from simple_history import register from ..models import ( AdminProfile, Bookcase, MultiOneToOne, Poll, Choice, Voter, Restaurant, Person, FileModel, Document, Book, HistoricalPoll, Library, State, @@ -19,7 +19,9 @@ ExternalModel1, ExternalModel3, UnicodeVerboseName, HistoricalChoice, HistoricalState, HistoricalCustomFKError, Series, SeriesWork, PollInfo, UserAccessorDefault, UserAccessorOverride, Employee, Country, Province, - City, Contact, ContactRegister + City, Contact, ContactRegister, + TrackedAbstractBaseA, TrackedAbstractBaseB, UntrackedAbstractBase, + TrackedConcreteBase, UntrackedConcreteBase, ) from ..external.models import ExternalModel2, ExternalModel4 @@ -332,12 +334,8 @@ def test_register_separate_app(self): self.assertEqual(len(user.histories.all()), 1) def test_reregister(self): - register(Restaurant, manager_name='again') - register(User, manager_name='again') - self.assertTrue(hasattr(Restaurant, 'updates')) - self.assertFalse(hasattr(Restaurant, 'again')) - self.assertTrue(hasattr(User, 'histories')) - self.assertFalse(hasattr(User, 'again')) + with self.assertRaises(exceptions.MultipleRegistrationsError): + register(Restaurant, manager_name='again') def test_register_custome_records(self): self.assertEqual(len(Voter.history.all()), 0) @@ -783,3 +781,73 @@ def test_custom_table_name_from_register(self): self.get_table_name(ContactRegister.history), 'contacts_register_history', ) + + +class TestTrackingInheritance(TestCase): + + def test_tracked_abstract_base(self): + class TrackedWithAbstractBase(TrackedAbstractBaseA): + pass + + self.assertEqual( + [f.attname for f in TrackedWithAbstractBase.history.model._meta.fields], + ['id', 'history_id', 'history_date', 'history_user_id', 'history_type'], + ) + + def test_tracked_concrete_base(self): + class TrackedWithConcreteBase(TrackedConcreteBase): + pass + + self.assertEqual( + [f.attname for f in TrackedWithConcreteBase.history.model._meta.fields], + ['id', 'trackedconcretebase_ptr_id', 'history_id', 'history_date', 'history_user_id', 'history_type'], + ) + + def test_multiple_tracked_bases(self): + with self.assertRaises(exceptions.MultipleRegistrationsError): + class TrackedWithMultipleAbstractBases(TrackedAbstractBaseA, TrackedAbstractBaseB): + pass + + def test_tracked_abstract_and_untracked_concrete_base(self): + class TrackedWithTrackedAbstractAndUntrackedConcreteBase(TrackedAbstractBaseA, UntrackedConcreteBase): + pass + + self.assertEqual( + [f.attname for f in TrackedWithTrackedAbstractAndUntrackedConcreteBase.history.model._meta.fields], + ['id', 'untrackedconcretebase_ptr_id', 'history_id', 'history_date', 'history_user_id', 'history_type'], + ) + + def test_indirect_tracked_abstract_base(self): + class BaseTrackedWithIndirectTrackedAbstractBase(TrackedAbstractBaseA): + pass + + class TrackedWithIndirectTrackedAbstractBase(BaseTrackedWithIndirectTrackedAbstractBase): + pass + + self.assertEqual( + [f.attname for f in TrackedWithIndirectTrackedAbstractBase.history.model._meta.fields], + [ + 'id', 'basetrackedwithindirecttrackedabstractbase_ptr_id', + 'history_id', 'history_date', 'history_user_id', 'history_type'], + ) + + def test_indirect_tracked_concrete_base(self): + class BaseTrackedWithIndirectTrackedConcreteBase(TrackedAbstractBaseA): + pass + + class TrackedWithIndirectTrackedConcreteBase(BaseTrackedWithIndirectTrackedConcreteBase): + pass + + self.assertEqual( + [f.attname for f in TrackedWithIndirectTrackedConcreteBase.history.model._meta.fields], + [ + 'id', 'basetrackedwithindirecttrackedconcretebase_ptr_id', + 'history_id', 'history_date', 'history_user_id', 'history_type'], + ) + + def test_registering_with_tracked_abstract_base(self): + class TrackedWithAbstractBaseToRegister(TrackedAbstractBaseA): + pass + + with self.assertRaises(exceptions.MultipleRegistrationsError): + register(TrackedWithAbstractBaseToRegister)