From cf67a0344582aca3b90f08c2e0b99df77fbc79d3 Mon Sep 17 00:00:00 2001 From: npaulish Date: Wed, 13 Dec 2023 11:50:02 +0100 Subject: [PATCH] Avoid unnecessary errors when launching `Wannier90OptimizeWorkChain` related to `separate_plotting` input. Replace the errors with user warnings and extend the warning messages. 1. Automatically set `separate_plotting=True` if `optimize_disproj == True` and `wannier_plot == True` 2. Replace the error with a warning if `separate_plotting == True` but no plotting input is specified 3. Throw an error if `wannier_plot == True` but `separate_plotting == False` --- .../workflows/optimize.py | 42 ++++++++++++++----- 1 file changed, 31 insertions(+), 11 deletions(-) diff --git a/src/aiida_wannier90_workflows/workflows/optimize.py b/src/aiida_wannier90_workflows/workflows/optimize.py index 9c91ce6..298117c 100644 --- a/src/aiida_wannier90_workflows/workflows/optimize.py +++ b/src/aiida_wannier90_workflows/workflows/optimize.py @@ -50,20 +50,32 @@ def validate_inputs(inputs, ctx=None): # pylint: disable=unused-argument separate_plotting = inputs.get("separate_plotting", False) plot_inputs = [ parameters.get(_, False) - # for _ in Wannier90OptimizeWorkChain._WANNIER90_PLOT_INPUTS # pylint: disable=protected-access - for _ in ["wannier_plot"] + for _ in Wannier90OptimizeWorkChain._WANNIER90_PLOT_INPUTS # pylint: disable=protected-access ] if separate_plotting: if not any(plot_inputs): - return ( + warnings.warn( "Trying to separate plotting routines but no " - f"{'/'.join(Wannier90OptimizeWorkChain._WANNIER90_PLOT_INPUTS)} in wannier90 parameters?" # pylint: disable=protected-access + f"{'/'.join(Wannier90OptimizeWorkChain._WANNIER90_PLOT_INPUTS)} in wannier90 parameters. " # pylint: disable=protected-access + "The unnecessary plotting step can be avoided " + "by setting `builder.separate_plotting = False` and resubmitting the workchain." ) else: + if optimize_disproj and parameters.get("wannier_plot", False): + return ( + "Detected wannier_plot = True in wannier90 parameters, " + "but `separate_plotting = False`. For optimizing projectability " + "disentanglement (`optimize_disproj = True` ), it is highly recommended " + "to plot Wannier functions in a separate step to reduce computational time. " + "To do so, set `builder.separate_plotting = True` before submitting the workchain." + ) if optimize_disproj and any(plot_inputs): warnings.warn( - "`optimize_disproj = True` but `separate_plotting = False`. For optimizing projectability " - "disentanglement, it is highly recommended to run the plotting mode in a separate step." + f"Detected a plotting input ({'/'.join(Wannier90OptimizeWorkChain._WANNIER90_PLOT_INPUTS)}) " + "in wannier90 parameters, but `separate_plotting = False`. For optimizing projectability " + "disentanglement (`optimize_disproj = True` ), it is recommended to run the plotting mode " + "in a separate step to reduce computational time. " + "This can be done by setting `builder.separate_plotting = True` and resubmitting the workchain." ) return None @@ -74,9 +86,12 @@ class Wannier90OptimizeWorkChain(Wannier90BandsWorkChain): # The following keys are for wannier90.x plotting, i.e. they can be restarted from # chk file by setting `restart = plot` in wannier90.win. + # the `bands_plot` is commented out since it is rather cheap to compute, + # and also we want to check the band distance during each iteration of optimizing dis_proj_min/max + # even if `separate_plotting = True` _WANNIER90_PLOT_INPUTS = ( "wannier_plot", - "bands_plot", + # "bands_plot", "write_tb", "write_hr", "write_hhmn", @@ -301,11 +316,19 @@ def get_builder_from_protocol( # pylint: disable=arguments-differ # Inputs for optimizing dis_proj_min/max if reference_bands: - builder.separate_plotting = True builder.optimize_disproj = True builder.optimize_reference_bands = reference_bands builder.optimize_bands_distance_threshold = bands_distance_threshold + if builder.optimize_disproj: + # If optimizing dis_proj_min/max, + # make sure Wannier functions are plotted in a separate step since it is heavy + if ( + builder["wannier90"]["wannier90"]["parameters"] + .get_dict() + .get("wannier_plot", False) + ): + builder.separate_plotting = True return builder def setup(self): @@ -332,9 +355,6 @@ def setup(self): parameters = self.inputs.wannier90.wannier90["parameters"].get_dict() # I convert the tuple to list so it can be changed excluded_inputs = list(Wannier90OptimizeWorkChain._WANNIER90_PLOT_INPUTS) - # I need to calculate bands for comparing bands distance - if "optimize_reference_bands" in self.inputs: - excluded_inputs.remove("bands_plot") for key in excluded_inputs: plot_input = parameters.get(key, False) if plot_input: