Skip to content

Commit 53e1602

Browse files
JCoxwellbckohan
authored andcommitted
Add create_from_super method and test
1 parent c541020 commit 53e1602

File tree

2 files changed

+56
-0
lines changed

2 files changed

+56
-0
lines changed

src/polymorphic/query.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -531,3 +531,35 @@ def get_real_instances(self, base_result_objects=None):
531531
return olist
532532
clist = PolymorphicQuerySet._p_list_class(olist)
533533
return clist
534+
535+
def create_from_super(self, obj, **kwargs):
536+
"""Creates an instance of self.model (cls) from existing super class.
537+
The new subclass will be the same object with same database id
538+
and data as obj, but will be an instance of cls.
539+
540+
obj must be an instance of the direct superclass of cls.
541+
kwargs should contain all required fields of the subclass (cls).
542+
543+
returns obj as an instance of cls.
544+
"""
545+
cls = self.model
546+
import inspect
547+
548+
scls = inspect.getmro(cls)[1]
549+
if scls is not type(obj):
550+
raise Exception(
551+
"create_from_super can only be used if obj is one level of inheritance up from cls"
552+
)
553+
ptr = "{}_ptr_id".format(scls.__name__.lower())
554+
kwargs[ptr] = obj.id
555+
# create the new base class with only fields that apply to it.
556+
nobj = cls(**kwargs)
557+
nobj.save_base(raw=True)
558+
# force update the content type, but first we need to
559+
# retrieve a clean copy from the db to fill in the null
560+
# fields otherwise they would be overwritten.
561+
nobj = cls.objects.get(pk=obj.pk)
562+
nobj.polymorphic_ctype = ContentType.objects.get_for_model(cls)
563+
nobj.save()
564+
565+
return nobj.get_real_instance() # cast to cls
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
from django.test import TransactionTestCase
2+
from polymorphic.tests.models import Model2A, Model2B, Model2C, Model2D
3+
4+
5+
class PolymorphicTests(TransactionTestCase):
6+
def test_create_from_super(self):
7+
# run create test 3 times because initial implementation
8+
# would fail after first success.
9+
for i in range(3):
10+
mc = Model2C.objects.create(
11+
field1="C1{}".format(i), field2="C2{}".format(i), field3="C3{}".format(i)
12+
)
13+
mc.save()
14+
field4 = "D4{}".format(i)
15+
md = Model2D.objects.create_from_super(mc, field4=field4)
16+
self.assertEqual(mc.id, md.id)
17+
self.assertEqual(mc.field1, md.field1)
18+
self.assertEqual(mc.field2, md.field2)
19+
self.assertEqual(mc.field3, md.field3)
20+
self.assertEqual(md.field4, field4)
21+
ma = Model2A.objects.create(field1="A1e")
22+
self.assertRaises(Exception, Model2D.objects.create_from_super, ma, field4="D4e")
23+
mb = Model2B.objects.create(field1="B1e", field2="B2e")
24+
self.assertRaises(Exception, Model2D.objects.create_from_super, mb, field4="D4e")

0 commit comments

Comments
 (0)