Skip to content

Commit

Permalink
Avoid unnecessary errors when launching Wannier90OptimizeWorkChain
Browse files Browse the repository at this point in the history
…related to `separate_plotting` input.

Replace the errors with user warnings and extend the warning messages.

1. Automatically set `separate_plotting=True` if `optimize_disproj == True` and `wannier_plot == True`
2. Replace the error with a warning if `separate_plotting == True` but no plotting input is specified
3. Throw an error if `wannier_plot == True` but `separate_plotting == False`
  • Loading branch information
npaulish authored and qiaojunfeng committed Dec 13, 2023
1 parent c354a7f commit cf67a03
Showing 1 changed file with 31 additions and 11 deletions.
42 changes: 31 additions & 11 deletions src/aiida_wannier90_workflows/workflows/optimize.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,20 +50,32 @@ def validate_inputs(inputs, ctx=None): # pylint: disable=unused-argument
separate_plotting = inputs.get("separate_plotting", False)
plot_inputs = [
parameters.get(_, False)
# for _ in Wannier90OptimizeWorkChain._WANNIER90_PLOT_INPUTS # pylint: disable=protected-access
for _ in ["wannier_plot"]
for _ in Wannier90OptimizeWorkChain._WANNIER90_PLOT_INPUTS # pylint: disable=protected-access
]
if separate_plotting:
if not any(plot_inputs):
return (
warnings.warn(
"Trying to separate plotting routines but no "
f"{'/'.join(Wannier90OptimizeWorkChain._WANNIER90_PLOT_INPUTS)} in wannier90 parameters?" # pylint: disable=protected-access
f"{'/'.join(Wannier90OptimizeWorkChain._WANNIER90_PLOT_INPUTS)} in wannier90 parameters. " # pylint: disable=protected-access
"The unnecessary plotting step can be avoided "
"by setting `builder.separate_plotting = False` and resubmitting the workchain."
)
else:
if optimize_disproj and parameters.get("wannier_plot", False):
return (
"Detected wannier_plot = True in wannier90 parameters, "
"but `separate_plotting = False`. For optimizing projectability "
"disentanglement (`optimize_disproj = True` ), it is highly recommended "
"to plot Wannier functions in a separate step to reduce computational time. "
"To do so, set `builder.separate_plotting = True` before submitting the workchain."
)
if optimize_disproj and any(plot_inputs):
warnings.warn(
"`optimize_disproj = True` but `separate_plotting = False`. For optimizing projectability "
"disentanglement, it is highly recommended to run the plotting mode in a separate step."
f"Detected a plotting input ({'/'.join(Wannier90OptimizeWorkChain._WANNIER90_PLOT_INPUTS)}) "
"in wannier90 parameters, but `separate_plotting = False`. For optimizing projectability "
"disentanglement (`optimize_disproj = True` ), it is recommended to run the plotting mode "
"in a separate step to reduce computational time. "
"This can be done by setting `builder.separate_plotting = True` and resubmitting the workchain."
)

return None
Expand All @@ -74,9 +86,12 @@ class Wannier90OptimizeWorkChain(Wannier90BandsWorkChain):

# The following keys are for wannier90.x plotting, i.e. they can be restarted from
# chk file by setting `restart = plot` in wannier90.win.
# the `bands_plot` is commented out since it is rather cheap to compute,
# and also we want to check the band distance during each iteration of optimizing dis_proj_min/max
# even if `separate_plotting = True`
_WANNIER90_PLOT_INPUTS = (
"wannier_plot",
"bands_plot",
# "bands_plot",
"write_tb",
"write_hr",
"write_hhmn",
Expand Down Expand Up @@ -301,11 +316,19 @@ def get_builder_from_protocol( # pylint: disable=arguments-differ

# Inputs for optimizing dis_proj_min/max
if reference_bands:
builder.separate_plotting = True
builder.optimize_disproj = True
builder.optimize_reference_bands = reference_bands
builder.optimize_bands_distance_threshold = bands_distance_threshold

if builder.optimize_disproj:
# If optimizing dis_proj_min/max,
# make sure Wannier functions are plotted in a separate step since it is heavy
if (
builder["wannier90"]["wannier90"]["parameters"]
.get_dict()
.get("wannier_plot", False)
):
builder.separate_plotting = True
return builder

def setup(self):
Expand All @@ -332,9 +355,6 @@ def setup(self):
parameters = self.inputs.wannier90.wannier90["parameters"].get_dict()
# I convert the tuple to list so it can be changed
excluded_inputs = list(Wannier90OptimizeWorkChain._WANNIER90_PLOT_INPUTS)
# I need to calculate bands for comparing bands distance
if "optimize_reference_bands" in self.inputs:
excluded_inputs.remove("bands_plot")
for key in excluded_inputs:
plot_input = parameters.get(key, False)
if plot_input:
Expand Down

0 comments on commit cf67a03

Please sign in to comment.