diff --git a/src/aiida_wannier90_workflows/utils/workflows/plot/bands.py b/src/aiida_wannier90_workflows/utils/workflows/plot/bands.py index 738ddcb..f63ad68 100755 --- a/src/aiida_wannier90_workflows/utils/workflows/plot/bands.py +++ b/src/aiida_wannier90_workflows/utils/workflows/plot/bands.py @@ -140,6 +140,7 @@ def get_mpl_code_for_bands( wan_bands, *, fermi_energy=None, + fermi_energy2=None, shift_fermi=False, title=None, save=False, @@ -172,6 +173,14 @@ def get_mpl_code_for_bands( replacement += "p.axhline(y=0, color='blue', linestyle='--', label='Fermi', zorder=-1)\n" else: replacement += "p.axhline(y=fermi_energy, color='blue', linestyle='--', label='Fermi', zorder=-1)\n" + if (fermi_energy2 is not None) and abs(fermi_energy2 - fermi_energy) > 1e-3: + replacement += f"fermi_energy2 = {fermi_energy2}\n" + if shift_fermi: + replacement += "p.axhline(y=0, color='cyan', linestyle='--', label='Fermi2', zorder=-1)\n" + else: + replacement += "p.axhline(y=fermi_energy2, color='cyan', linestyle='--', label='Fermi2', zorder=-1)\n" + else: + replacement += "fermi_energy2 = fermi_energy\n" replacement += "pl.legend()\n\n" replacement += "for path in paths:" dft_mpl_code = dft_mpl_code.replace(b"for path in paths:", replacement.encode()) @@ -188,7 +197,7 @@ def get_mpl_code_for_bands( ) wan_mpl_code = wan_mpl_code.replace( b"p.plot(x, band, label=label,", - b"p.plot(x, [_-fermi_energy for _ in band], label=label,", + b"p.plot(x, [_-fermi_energy2 for _ in band], label=label,", ) dft_mpl_code = dft_mpl_code.replace( b"p.set_ylim([all_data['y_min_lim'], all_data['y_max_lim']])", @@ -196,7 +205,7 @@ def get_mpl_code_for_bands( ) wan_mpl_code = wan_mpl_code.replace( b"p.set_ylim([all_data['y_min_lim'], all_data['y_max_lim']])", - b"p.set_ylim([all_data['y_min_lim']-fermi_energy, all_data['y_max_lim']-fermi_energy])", + b"p.set_ylim([all_data['y_min_lim']-fermi_energy2, all_data['y_max_lim']-fermi_energy2])", ) mpl_code = dft_mpl_code + wan_mpl_code @@ -234,7 +243,12 @@ def get_output_bands(workchain): def get_mpl_code_for_workchains( - workchain0, workchain1, title=None, save=False, filename=None + workchain0, + workchain1, + title=None, + save=False, + filename=None, + shift_fermi=False, ): """Return matplotlib code for comparing band structures of two workchains.""" # assume workchain0 is pw, workchain1 is wannier @@ -252,23 +266,15 @@ def get_mpl_code_for_workchains( if save and (filename is None): filename = f"bandsdiff_{formula}_{workchain0.pk}_{workchain1.pk}.py" - if workchain1.process_class in ( - Wannier90BaseWorkChain, - Wannier90BandsWorkChain, - Wannier90OptimizeWorkChain, - ): - fermi_energy = get_workchain_fermi_energy(workchain1) - else: - if workchain0.process_class in [PwBandsWorkChain, ProjwfcBandsWorkChain]: - fermi_energy = workchain0.outputs["scf_parameters"]["fermi_energy"] - else: - raise ValueError(f"Cannot find fermi energy from {workchain0}") + fermi_energy = get_workchain_fermi_energy(workchain0) + fermi_energy2 = get_workchain_fermi_energy(workchain1) mpl_code = get_mpl_code_for_bands( dft_bands, wan_bands, fermi_energy=fermi_energy, - shift_fermi=False, + fermi_energy2=fermi_energy2, + shift_fermi=shift_fermi, title=title, save=save, filename=filename, @@ -278,7 +284,12 @@ def get_mpl_code_for_workchains( def get_workchain_fermi_energy( - workchain: ty.Union[Wannier90BaseWorkChain, Wannier90BandsWorkChain] + workchain: ty.Union[ + Wannier90BaseWorkChain, + Wannier90BandsWorkChain, + PwBandsWorkChain, + ProjwfcBandsWorkChain, + ] ) -> float: """Get Fermi energy of Wannier90BandsWorkChain.