Skip to content

Commit

Permalink
Shift bands with different Fermi energies
Browse files Browse the repository at this point in the history
When plotting the band structure comparison.
  • Loading branch information
qiaojunfeng committed May 1, 2024
1 parent 97bbe16 commit acdaadf
Showing 1 changed file with 27 additions and 16 deletions.
43 changes: 27 additions & 16 deletions src/aiida_wannier90_workflows/utils/workflows/plot/bands.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,7 @@ def get_mpl_code_for_bands(
wan_bands,
*,
fermi_energy=None,
fermi_energy2=None,
shift_fermi=False,
title=None,
save=False,
Expand Down Expand Up @@ -172,6 +173,14 @@ def get_mpl_code_for_bands(
replacement += "p.axhline(y=0, color='blue', linestyle='--', label='Fermi', zorder=-1)\n"
else:
replacement += "p.axhline(y=fermi_energy, color='blue', linestyle='--', label='Fermi', zorder=-1)\n"
if (fermi_energy2 is not None) and abs(fermi_energy2 - fermi_energy) > 1e-3:
replacement += f"fermi_energy2 = {fermi_energy2}\n"
if shift_fermi:
replacement += "p.axhline(y=0, color='cyan', linestyle='--', label='Fermi2', zorder=-1)\n"
else:
replacement += "p.axhline(y=fermi_energy2, color='cyan', linestyle='--', label='Fermi2', zorder=-1)\n"
else:
replacement += "fermi_energy2 = fermi_energy\n"
replacement += "pl.legend()\n\n"
replacement += "for path in paths:"
dft_mpl_code = dft_mpl_code.replace(b"for path in paths:", replacement.encode())
Expand All @@ -188,15 +197,15 @@ def get_mpl_code_for_bands(
)
wan_mpl_code = wan_mpl_code.replace(
b"p.plot(x, band, label=label,",
b"p.plot(x, [_-fermi_energy for _ in band], label=label,",
b"p.plot(x, [_-fermi_energy2 for _ in band], label=label,",
)
dft_mpl_code = dft_mpl_code.replace(
b"p.set_ylim([all_data['y_min_lim'], all_data['y_max_lim']])",
b"p.set_ylim([all_data['y_min_lim']-fermi_energy, all_data['y_max_lim']-fermi_energy])",
)
wan_mpl_code = wan_mpl_code.replace(
b"p.set_ylim([all_data['y_min_lim'], all_data['y_max_lim']])",
b"p.set_ylim([all_data['y_min_lim']-fermi_energy, all_data['y_max_lim']-fermi_energy])",
b"p.set_ylim([all_data['y_min_lim']-fermi_energy2, all_data['y_max_lim']-fermi_energy2])",
)

mpl_code = dft_mpl_code + wan_mpl_code
Expand Down Expand Up @@ -234,7 +243,12 @@ def get_output_bands(workchain):


def get_mpl_code_for_workchains(
workchain0, workchain1, title=None, save=False, filename=None
workchain0,
workchain1,
title=None,
save=False,
filename=None,
shift_fermi=False,
):
"""Return matplotlib code for comparing band structures of two workchains."""
# assume workchain0 is pw, workchain1 is wannier
Expand All @@ -252,23 +266,15 @@ def get_mpl_code_for_workchains(
if save and (filename is None):
filename = f"bandsdiff_{formula}_{workchain0.pk}_{workchain1.pk}.py"

if workchain1.process_class in (
Wannier90BaseWorkChain,
Wannier90BandsWorkChain,
Wannier90OptimizeWorkChain,
):
fermi_energy = get_workchain_fermi_energy(workchain1)
else:
if workchain0.process_class in [PwBandsWorkChain, ProjwfcBandsWorkChain]:
fermi_energy = workchain0.outputs["scf_parameters"]["fermi_energy"]
else:
raise ValueError(f"Cannot find fermi energy from {workchain0}")
fermi_energy = get_workchain_fermi_energy(workchain0)
fermi_energy2 = get_workchain_fermi_energy(workchain1)

mpl_code = get_mpl_code_for_bands(
dft_bands,
wan_bands,
fermi_energy=fermi_energy,
shift_fermi=False,
fermi_energy2=fermi_energy2,
shift_fermi=shift_fermi,
title=title,
save=save,
filename=filename,
Expand All @@ -278,7 +284,12 @@ def get_mpl_code_for_workchains(


def get_workchain_fermi_energy(
workchain: ty.Union[Wannier90BaseWorkChain, Wannier90BandsWorkChain]
workchain: ty.Union[
Wannier90BaseWorkChain,
Wannier90BandsWorkChain,
PwBandsWorkChain,
ProjwfcBandsWorkChain,
]
) -> float:
"""Get Fermi energy of Wannier90BandsWorkChain.
Expand Down

0 comments on commit acdaadf

Please sign in to comment.