Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 18 additions & 26 deletions formtools/wizard/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def __repr__(self):
@property
def all(self):
"Returns the names of all steps/forms."
return list(self._wizard.get_form_list())
return list(self._wizard.form_list)

@property
def count(self):
Expand Down Expand Up @@ -201,28 +201,21 @@ def get_prefix(self, request, *args, **kwargs):
# TODO: Add some kind of unique id to prefix
return normalize_name(self.__class__.__name__)

def get_form_list(self):
def process_condition_dict(self):
"""
This method returns a form_list based on the initial form list but
checks if there is a condition method/value in the condition_list.
If an entry exists in the condition list, it will call/read the value
and respect the result. (True means add the form, False means ignore
the form)

The form_list is always generated on the fly because condition methods
could use data from other (maybe previous forms).
This method prunes `self.form_list` by checking if there is a condition method/value in `condition_list`.
If an entry exists, it will call/read the value and respect the result. If the condition returns False, the
form will be removed from `form_list`.
"""
form_list = OrderedDict()
for form_key, form_class in self.form_list.items():
for form_key in list(self.form_list.keys()):
# try to fetch the value from condition list, by default, the form
# gets passed to the new list.
condition = self.condition_dict.get(form_key, True)
if callable(condition):
# call the value if needed, passes the current instance.
condition = condition(self)
if condition:
form_list[form_key] = form_class
return form_list
if not condition:
del self.form_list[form_key]

def dispatch(self, request, *args, **kwargs):
"""
Expand All @@ -241,6 +234,7 @@ def dispatch(self, request, *args, **kwargs):
getattr(self, 'file_storage', None),
)
self.steps = StepsHelper(self)
self.process_condition_dict()
response = super().dispatch(request, *args, **kwargs)

# update the response (e.g. adding cookies)
Expand Down Expand Up @@ -273,7 +267,7 @@ def post(self, *args, **kwargs):
# contains a valid step name. If one was found, render the requested
# form. (This makes stepping back a lot easier).
wizard_goto_step = self.request.POST.get('wizard_goto_step', None)
if wizard_goto_step and wizard_goto_step in self.get_form_list():
if wizard_goto_step and wizard_goto_step in self.form_list:
return self.render_goto_step(wizard_goto_step)

# Check if form was refreshed
Expand Down Expand Up @@ -342,7 +336,7 @@ def render_done(self, form, **kwargs):
"""
final_forms = OrderedDict()
# walk through the form list and try to validate the data again.
for form_key in self.get_form_list():
for form_key in self.form_list.keys():
form_obj = self.get_form(
step=form_key,
data=self.storage.get_step_data(form_key),
Expand Down Expand Up @@ -406,7 +400,7 @@ def get_form(self, step=None, data=None, files=None):
"""
if step is None:
step = self.steps.current
form_class = self.get_form_list()[step]
form_class = self.form_list[step]
# prepare the kwargs for the form instance.
kwargs = self.get_form_kwargs(step)
kwargs.update({
Expand Down Expand Up @@ -469,7 +463,7 @@ def get_all_cleaned_data(self):
'formset-' and contain a list of the formset cleaned_data dictionaries.
"""
cleaned_data = {}
for form_key in self.get_form_list():
for form_key in self.form_list.keys():
form_obj = self.get_form(
step=form_key,
data=self.storage.get_step_data(form_key),
Expand Down Expand Up @@ -510,8 +504,7 @@ def get_next_step(self, step=None):
"""
if step is None:
step = self.steps.current
form_list = self.get_form_list()
keys = list(form_list.keys())
keys = list(self.form_list.keys())
if step not in keys:
return self.steps.first
key = keys.index(step) + 1
Expand All @@ -529,8 +522,7 @@ def get_prev_step(self, step=None):
"""
if step is None:
step = self.steps.current
form_list = self.get_form_list()
keys = list(form_list.keys())
keys = list(self.form_list.keys())
if step not in keys:
return None
key = keys.index(step) - 1
Expand All @@ -547,7 +539,7 @@ def get_step_index(self, step=None):
"""
if step is None:
step = self.steps.current
keys = list(self.get_form_list().keys())
keys = list(self.form_list.keys())
if step in keys:
return keys.index(step)
return None
Expand Down Expand Up @@ -678,7 +670,7 @@ def get(self, *args, **kwargs):
)
return self.render(form, **kwargs)

elif step_url in self.get_form_list():
elif step_url in self.form_list:
self.storage.current_step = step_url
return self.render(
self.get_form(
Expand All @@ -699,7 +691,7 @@ def post(self, *args, **kwargs):
is super'd from WizardView.
"""
wizard_goto_step = self.request.POST.get('wizard_goto_step', None)
if wizard_goto_step and wizard_goto_step in self.get_form_list():
if wizard_goto_step and wizard_goto_step in self.form_list:
return self.render_goto_step(wizard_goto_step)
return super().post(*args, **kwargs)

Expand Down
34 changes: 10 additions & 24 deletions tests/wizard/test_forms.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,11 +92,12 @@ def done(self, form_list, **kwargs):


class TestWizardWithCustomGetFormList(TestWizard):
form_list = [('start', Step1)]

form_list = [Step1]

def get_form_list(self):
return {'start': Step1, 'step2': Step2}
def process_condition_dict(self):
super().process_condition_dict()
# Modify the `form_list` using any criteria (e.g. whether the user is logged in, etc.) or none at all
self.form_list['step2'] = Step2


class FormTests(TestCase):
Expand Down Expand Up @@ -158,19 +159,6 @@ def test_form_condition(self):
response, instance = testform(request)
self.assertEqual(instance.get_next_step(), 'step2')

def test_form_condition_unstable(self):
request = get_request()
testform = TestWizard.as_view(
[('start', Step1), ('step2', Step2), ('step3', Step3)],
condition_dict={'step2': True}
)
response, instance = testform(request)
self.assertEqual(instance.get_step_index('step2'), 1)
self.assertEqual(instance.get_next_step('step2'), 'step3')
instance.condition_dict['step2'] = False
self.assertEqual(instance.get_step_index('step2'), None)
self.assertEqual(instance.get_next_step('step2'), 'start')

def test_form_kwargs(self):
request = get_request()
testform = TestWizard.as_view([
Expand Down Expand Up @@ -265,23 +253,21 @@ def test_form_list_type(self):
response, instance = testform(request)
self.assertEqual(response.status_code, 200)

def test_get_form_list_default(self):
def test_form_list_default(self):
request = get_request()
testform = TestWizard.as_view([('start', Step1)])
response, instance = testform(request)

form_list = instance.get_form_list()
self.assertEqual(form_list, {'start': Step1})
self.assertEqual(instance.form_list, {'start': Step1})
with self.assertRaises(KeyError):
instance.get_form('step2')

def test_get_form_list_custom(self):
def test_form_list_custom(self):
request = get_request()
testform = TestWizardWithCustomGetFormList.as_view([('start', Step1)])
testform = TestWizardWithCustomGetFormList.as_view()
response, instance = testform(request)

form_list = instance.get_form_list()
self.assertEqual(form_list, {'start': Step1, 'step2': Step2})
self.assertEqual(instance.form_list, {'start': Step1, 'step2': Step2})
self.assertIsInstance(instance.get_form('step2'), Step2)


Expand Down