Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Avoid unnecessary errors when launching Wannier90OptimizeWorkChain #40

Merged
merged 1 commit into from
Dec 13, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 31 additions & 11 deletions src/aiida_wannier90_workflows/workflows/optimize.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,20 +50,32 @@ def validate_inputs(inputs, ctx=None): # pylint: disable=unused-argument
separate_plotting = inputs.get("separate_plotting", False)
plot_inputs = [
parameters.get(_, False)
# for _ in Wannier90OptimizeWorkChain._WANNIER90_PLOT_INPUTS # pylint: disable=protected-access
for _ in ["wannier_plot"]
for _ in Wannier90OptimizeWorkChain._WANNIER90_PLOT_INPUTS # pylint: disable=protected-access
]
if separate_plotting:
if not any(plot_inputs):
return (
warnings.warn(
"Trying to separate plotting routines but no "
f"{'/'.join(Wannier90OptimizeWorkChain._WANNIER90_PLOT_INPUTS)} in wannier90 parameters?" # pylint: disable=protected-access
f"{'/'.join(Wannier90OptimizeWorkChain._WANNIER90_PLOT_INPUTS)} in wannier90 parameters. " # pylint: disable=protected-access
"The unnecessary plotting step can be avoided "
"by setting `builder.separate_plotting = False` and resubmitting the workchain."
)
else:
if optimize_disproj and parameters.get("wannier_plot", False):
return (
"Detected wannier_plot = True in wannier90 parameters, "
"but `separate_plotting = False`. For optimizing projectability "
"disentanglement (`optimize_disproj = True` ), it is highly recommended "
"to plot Wannier functions in a separate step to reduce computational time. "
"To do so, set `builder.separate_plotting = True` before submitting the workchain."
)
if optimize_disproj and any(plot_inputs):
warnings.warn(
"`optimize_disproj = True` but `separate_plotting = False`. For optimizing projectability "
"disentanglement, it is highly recommended to run the plotting mode in a separate step."
f"Detected a plotting input ({'/'.join(Wannier90OptimizeWorkChain._WANNIER90_PLOT_INPUTS)}) "
"in wannier90 parameters, but `separate_plotting = False`. For optimizing projectability "
"disentanglement (`optimize_disproj = True` ), it is recommended to run the plotting mode "
"in a separate step to reduce computational time. "
"This can be done by setting `builder.separate_plotting = True` and resubmitting the workchain."
)

return None
Expand All @@ -74,9 +86,12 @@ class Wannier90OptimizeWorkChain(Wannier90BandsWorkChain):

# The following keys are for wannier90.x plotting, i.e. they can be restarted from
# chk file by setting `restart = plot` in wannier90.win.
# the `bands_plot` is commented out since it is rather cheap to compute,
# and also we want to check the band distance during each iteration of optimizing dis_proj_min/max
# even if `separate_plotting = True`
_WANNIER90_PLOT_INPUTS = (
"wannier_plot",
"bands_plot",
# "bands_plot",
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would be great if you could add some explanations before the commented bands_plot, sth like

the bands_plot is commented out since it is rather cheap to compute, and also we want to check the band distance during each iteration of optimizing dis_proj_min/max even if we set `separate_plotting = True`

"write_tb",
"write_hr",
"write_hhmn",
Expand Down Expand Up @@ -301,11 +316,19 @@ def get_builder_from_protocol( # pylint: disable=arguments-differ

# Inputs for optimizing dis_proj_min/max
if reference_bands:
builder.separate_plotting = True
builder.optimize_disproj = True
builder.optimize_reference_bands = reference_bands
builder.optimize_bands_distance_threshold = bands_distance_threshold

if builder.optimize_disproj:
# If optimizing dis_proj_min/max,
# make sure Wannier functions are plotted in a separate step since it is heavy
if (
builder["wannier90"]["wannier90"]["parameters"]
.get_dict()
.get("wannier_plot", False)
):
builder.separate_plotting = True
return builder

def setup(self):
Expand All @@ -332,9 +355,6 @@ def setup(self):
parameters = self.inputs.wannier90.wannier90["parameters"].get_dict()
# I convert the tuple to list so it can be changed
excluded_inputs = list(Wannier90OptimizeWorkChain._WANNIER90_PLOT_INPUTS)
# I need to calculate bands for comparing bands distance
if "optimize_reference_bands" in self.inputs:
excluded_inputs.remove("bands_plot")
for key in excluded_inputs:
plot_input = parameters.get(key, False)
if plot_input:
Expand Down