Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Integration and test of RSS code #84

Merged
merged 40 commits into from
Aug 6, 2024
Merged
Show file tree
Hide file tree
Changes from 19 commits
Commits
Show all changes
40 commits
Select commit Hold shift + click to select a range
0e1bdbc
Merge branch 'JaGeo:main' into main
YuanbinLiu Jun 12, 2024
8a5745a
Update common jobs and utils, and add new RSS jobs
YuanbinLiu Jun 12, 2024
12d88b5
Revert "Merge branch 'JaGeo:main' into main"
YuanbinLiu Jun 12, 2024
df14cd9
Reapply "Merge branch 'JaGeo:main' into main"
YuanbinLiu Jun 12, 2024
406cc82
Revert "add YL to contributors"
YuanbinLiu Jun 12, 2024
c4b7144
added boltzhist code. WIP docstrings still need creating
MorrowChem Jun 13, 2024
9498e76
added docstrings, updated to numpy style
MorrowChem Jun 13, 2024
4718fb7
added airss installation to README
MorrowChem Jun 13, 2024
05ae5f7
removed airss installation from README, will be managed by conda
MorrowChem Jun 13, 2024
bcdf0a4
Merging the RSS code
YuanbinLiu Jun 21, 2024
c267c94
resolve conflict
YuanbinLiu Jul 24, 2024
f5efad2
Merge remote-tracking branch 'origin/main' into rss
YuanbinLiu Jul 24, 2024
11a7305
Resolved merge conflicts
YuanbinLiu Jul 24, 2024
d4fb6df
passed unit tests
YuanbinLiu Jul 29, 2024
baae101
add testing files
YuanbinLiu Jul 29, 2024
dac446b
merging rss code
YuanbinLiu Jul 29, 2024
4a69c6a
Remove redundant test files
YuanbinLiu Jul 29, 2024
a34fd0c
fix linting errors
YuanbinLiu Jul 30, 2024
0978476
adopt logging package
YuanbinLiu Jul 30, 2024
0efee39
fix lint errors
YuanbinLiu Jul 31, 2024
1e27fe3
fix conflicts between two workflows
YuanbinLiu Jul 31, 2024
874889a
minor revisions
YuanbinLiu Aug 1, 2024
4869604
add dgl version for installation
YuanbinLiu Aug 1, 2024
1002e1f
show installed packages
YuanbinLiu Aug 1, 2024
8371824
torch==2.2.1 for pytest
YuanbinLiu Aug 1, 2024
e04a09f
torch==2.2.1 for pytest
YuanbinLiu Aug 1, 2024
8f3c3be
torch==2.2.1 for pytest
YuanbinLiu Aug 1, 2024
cacc9b2
torchdata==0.7.1
YuanbinLiu Aug 1, 2024
daff41f
delete duplicate unit tests
YuanbinLiu Aug 1, 2024
193695c
set up installation of torchdata
YuanbinLiu Aug 1, 2024
5fdb278
modify regularization test
YuanbinLiu Aug 1, 2024
7ed1787
Add buildcell to path on github
YuanbinLiu Aug 1, 2024
67e538e
species_list is needed for the analysis plots for GAP
QuantumChemist Aug 2, 2024
091af95
species_list is needed for the analysis plots for GAP
QuantumChemist Aug 2, 2024
5d3d163
added comment
QuantumChemist Aug 2, 2024
26c609e
added checks for checking if sigma regularization is active
QuantumChemist Aug 2, 2024
26ca71b
add docstrings
QuantumChemist Aug 2, 2024
094d2d7
ignore airss
QuantumChemist Aug 2, 2024
4a09114
reduce the GAP unit test run time where accuracy isn't needed
QuantumChemist Aug 2, 2024
ba5f891
update
YuanbinLiu Aug 6, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
616 changes: 615 additions & 1 deletion autoplex/data/common/jobs.py

Large diffs are not rendered by default.

625 changes: 625 additions & 0 deletions autoplex/data/common/utils.py

Large diffs are not rendered by default.

350 changes: 350 additions & 0 deletions autoplex/data/rss/flows.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,350 @@
"""Flows for running RSS."""

from __future__ import annotations

from jobflow import Flow, Response, job

from autoplex.data.common.jobs import (
Data_preprocessing,
Sampling,
VASP_collect_data,
VASP_static,
)
from autoplex.data.rss.jobs import RandomizedStructure, do_rss
from autoplex.fitting.common.flows import MLIPFitMaker


@job
def initial_RSS(
    struct_number: int = 10000,
    tag: str = "GeSb2Te4",
    selection_method: str = "cur",
    num_of_selection: int = 3,
    bcur_params: dict | None = None,
    random_seed: int | None = None,
    e0_spin: bool = False,
    isolated_atom: bool = True,
    dimer: bool = True,
    dimer_range: list | None = None,
    dimer_num: int | None = None,
    custom_set: dict | None = None,
    vasp_ref_file: str = "vasp_ref.extxyz",
    rss_group: str = "initial",
    test_ratio: float = 0.1,
    regularization: bool = True,
    distillation: bool = True,
    f_max: float = 200,
    pre_database_dir: str | None = None,
    mlip_type: str = "GAP",
    mlip_hyper: dict | None = None,
    ref_energy_name: str = "REF_energy",
    ref_force_name: str = "REF_forces",
    ref_virial_name: str = "REF_virial",
    num_processes_fit: int | None = None,
    kt: float | None = None,
    **fit_kwargs,
):
    """
    Run initial Random Structure Searching (RSS) workflow from scratch.

    The workflow consists of the following jobs:
    job1 - RandomizedStructure: Generates randomized structures
    job2 - Sampling: Samples a subset of the generated structures using CUR
    job3 - VASP_static: Runs single-point calculations on the sampled structures
    job4 - VASP_collect_data: Collects VASP calculation data
    job5 - Data_preprocessing: Preprocesses the data for fitting ML models
    job6 - MLIPFitMaker: Fits a ML interatomic potential (MLIP)

    Parameters
    ----------
    struct_number : int, optional
        Number of structures to generate. Default is 10000.
    tag : str, optional
        Tag for the generated structures. Default is 'GeSb2Te4'.
    selection_method : str, optional
        Method for selecting structures. Default is 'cur'.
    num_of_selection : int, optional
        Number of structures to select. Default is 3.
    bcur_params : dict, optional
        Parameters for the CUR method. Default is None.
    random_seed : int, optional
        Seed for random number generator. Default is None.
    e0_spin : bool, optional
        Whether to include spin polarization in the static calculations of
        isolated atoms and dimers. Default is False.
    isolated_atom : bool, optional
        Whether to include isolated atom calculations. Default is True.
    dimer : bool, optional
        Whether to include dimer calculations. Default is True.
    dimer_range : list, optional
        Distance range for dimer calculations. Default is None.
    dimer_num : int, optional
        Number of dimers generated for calculations. Default is None.
    custom_set : dict, optional
        Custom set of parameters for VASP. Default is None.
    vasp_ref_file : str, optional
        File name of collected VASP data. Default is 'vasp_ref.extxyz'.
    rss_group : str, optional
        Group name of structures for RSS. Default is 'initial'.
    test_ratio : float, optional
        The proportion of the test set after splitting the data.
        Default is 0.1.
    regularization : bool, optional
        Whether to apply regularization. This only works for GAP.
        Default is True.
    distillation : bool, optional
        Whether to apply distillation of structures. Default is True.
    f_max : float, optional
        Maximum force value to exclude structures. Default is 200.
    pre_database_dir : str, optional
        Directory for the preprocessed database. Default is None.
    mlip_type : str, optional
        Type of MLIP to fit. Default is 'GAP'.
    mlip_hyper : dict, optional
        Hyperparameters for the MLIP. Default is None.
    ref_energy_name : str, optional
        Reference energy name. Default is "REF_energy".
    ref_force_name : str, optional
        Reference force name. Default is "REF_forces".
    ref_virial_name : str, optional
        Reference virial name. Default is "REF_virial".
    num_processes_fit : int, optional
        Number of processes for fitting. Default is None.
    kt : float, optional
        Value of kT reported in the output. If None (the default), 0.6 is
        reported.
    fit_kwargs : dict, optional
        Additional keyword arguments passed on to ``MLIPFitMaker.make``.

    Output
    ------
    - test_error: float
        The test error of the fitted MLIP.
    - pre_database_dir: str
        The directory of the preprocessed database.
    - mlip_path: str
        The path to the fitted MLIP.
    - isol_es: dict
        The isolated energy values.
    - current_iter: int
        The current iteration index, set to 0.
    - kt: float
        The value of kT.
    """
    # Honour a user-supplied kT; previously the argument was silently ignored
    # and 0.6 was always reported.
    initial_kt = 0.6 if kt is None else kt

    generate = RandomizedStructure(struct_number=struct_number, tag=tag).make()
    sample = Sampling(
        selection_method=selection_method,
        num_of_selection=num_of_selection,
        bcur_params=bcur_params,
        dir=generate.output,
        random_seed=random_seed,
    )
    static = VASP_static(
        structures=sample.output,
        e0_spin=e0_spin,
        isolated_atom=isolated_atom,
        dimer=dimer,
        dimer_range=dimer_range,
        dimer_num=dimer_num,
        custom_set=custom_set,
    )
    collect = VASP_collect_data(
        vasp_ref_file=vasp_ref_file,
        rss_group=rss_group,
        vasp_dirs=static.output,
    )
    preprocess = Data_preprocessing(
        test_ratio=test_ratio,
        regularization=regularization,
        distillation=distillation,
        f_max=f_max,
        vasp_ref_dir=collect.output["vasp_ref_dir"],
        pre_database_dir=pre_database_dir,
    )
    fit = MLIPFitMaker(
        mlip_type=mlip_type,
        mlip_hyper=mlip_hyper,
        ref_energy_name=ref_energy_name,
        ref_force_name=ref_force_name,
        ref_virial_name=ref_virial_name,
    ).make(
        database_dir=preprocess.output,
        isol_es=collect.output["isol_es"],
        num_processes_fit=num_processes_fit,
        # The database has already been prepared by Data_preprocessing above.
        preprocessing_data=False,
        **fit_kwargs,
    )

    job_list = [generate, sample, static, collect, preprocess, fit]

    # Replace this job by the flow; the output dict holds references that are
    # resolved once the corresponding jobs have run.
    return Response(
        replace=Flow(job_list),
        output={
            "test_error": fit.output["test_error"],
            "pre_database_dir": preprocess.output,
            "mlip_path": fit.output["mlip_path"],
            "isol_es": collect.output["isol_es"],
            "current_iter": 0,
            "kt": initial_kt,
        },
    )


@job
def do_RSS_iterations(
    input: dict | None = None,
    struct_number: int = 10000,
    tag: str = "GeSb2Te4",
    selection_method1: str = "cur",
    selection_method2: str = "bcur",
    num_of_selection1: int = 3,
    num_of_selection2: int = 5,
    bcur_params: dict | None = None,
    random_seed: int | None = None,
    e0_spin: bool = False,
    isolated_atom: bool = True,
    dimer: bool = True,
    dimer_range: list | None = None,
    dimer_num: int | None = None,
    custom_set: dict | None = None,
    vasp_ref_file: str = "vasp_ref.extxyz",
    rss_group: str = "initial",
    test_ratio: float = 0.1,
    regularization: bool = True,
    distillation: bool = True,
    f_max: float = 200,
    mlip_type: str = "GAP",
    mlip_hyper: dict | None = None,
    ref_energy_name: str = "REF_energy",
    ref_force_name: str = "REF_forces",
    ref_virial_name: str = "REF_virial",
    num_processes_fit: int | None = None,
    scalar_pressure_method: str = "exp",
    scalar_exp_pressure: float = 100,
    scalar_pressure_exponential_width: float = 0.2,
    scalar_pressure_low: float = 0,
    scalar_pressure_high: float = 50,
    max_steps: int = 10,
    force_tol: float = 0.1,
    stress_tol: float = 0.1,
    Hookean_repul: bool = False,
    write_traj: bool = True,
    num_processes_rss: int = 4,
    device: str = "cpu",
    stop_criterion: float = 0.01,
    max_iteration_number: int = 9,
    **fit_kwargs,
):
    """
    Perform iterative RSS to improve the accuracy of a MLIP.

    Each iteration involves generating new structures, sampling, running
    VASP calculations, collecting data, preprocessing data, and fitting a new MLIP.
    The job then schedules itself again with the results of that fit and the
    same settings, until the test error no longer exceeds ``stop_criterion``
    or ``max_iteration_number`` iterations have been reached.

    Parameters
    ----------
    input : dict, optional
        State of the previous step with the keys 'test_error',
        'pre_database_dir', 'mlip_path', 'isol_es', 'current_iter' and 'kt'.
        If None, a placeholder with ``test_error=None`` is used, in which
        case no iteration is run and the placeholder is returned.
    struct_number : int, optional
        Number of structures to generate. Default is 10000.
    tag : str, optional
        Tag for the generated structures. Default is 'GeSb2Te4'.
    selection_method1 : str, optional
        Selection method for the first Sampling job (on the generated
        structures). Default is 'cur'.
    selection_method2 : str, optional
        Selection method for the second Sampling job (on the output of
        do_rss). Default is 'bcur'.
    num_of_selection1 : int, optional
        Number of structures selected by the first Sampling job.
        Default is 3.
    num_of_selection2 : int, optional
        Number of structures selected by the second Sampling job.
        Default is 5.
    bcur_params : dict, optional
        Parameters passed to both Sampling jobs. The key 'kT' is always
        overwritten with the kT of the current iteration; the caller's dict
        is not modified. Default is None.
    random_seed : int, optional
        Seed passed to both Sampling jobs. Default is None.
    e0_spin, isolated_atom, dimer, dimer_range, dimer_num, custom_set
        Passed unchanged to VASP_static.
    vasp_ref_file, rss_group
        Passed unchanged to VASP_collect_data.
    test_ratio, regularization, distillation, f_max
        Passed unchanged to Data_preprocessing.
    mlip_type : str, optional
        Type of MLIP; passed to both do_rss and MLIPFitMaker.
        Default is 'GAP'.
    mlip_hyper, ref_energy_name, ref_force_name, ref_virial_name
        Passed unchanged to MLIPFitMaker.
    num_processes_fit : int, optional
        Passed to ``MLIPFitMaker.make``. Default is None.
    scalar_pressure_method, scalar_exp_pressure,
    scalar_pressure_exponential_width, scalar_pressure_low,
    scalar_pressure_high, max_steps, force_tol, stress_tol, Hookean_repul,
    write_traj, num_processes_rss, device
        Passed unchanged to do_rss.
    stop_criterion : float, optional
        Iterations continue only while the test error is greater than this
        value. Default is 0.01.
    max_iteration_number : int, optional
        Iterations continue only while the iteration index is smaller than
        this value. Default is 9.
    fit_kwargs : dict, optional
        Additional keyword arguments passed on to ``MLIPFitMaker.make``.

    Returns
    -------
    Response
        If another iteration is required, a response that detours to the
        jobs of this iteration followed by the next ``do_RSS_iterations``
        job, whose output becomes the output of this job. Otherwise a
        response whose output is ``input`` unchanged.
    """
    if input is None:
        input = {
            "test_error": None,
            "pre_database_dir": None,
            "mlip_path": None,
            "isol_es": None,
            "current_iter": 0,
            "kt": 0.6,
        }

    test_error = input.get("test_error")
    current_iter = input.get("current_iter")

    needs_another_iteration = (
        test_error is not None
        and test_error > stop_criterion
        and current_iter is not None
        and current_iter < max_iteration_number
    )
    if not needs_another_iteration:
        return Response(output=input)

    # Lower kT by 0.1 per iteration, but never below 0.1.
    kt = input["kt"] - 0.1 if input["kt"] > 0.15 else 0.1
    print("kt:", kt)
    current_iter += 1
    print("Current iter index:", current_iter)
    # NOTE: test_error comes from ``input``, i.e. from the fit that preceded
    # this iteration.
    print(f"The error of {current_iter}th iteration:", test_error)

    # Work on a copy so the dict supplied by the caller is not mutated.
    bcur_params = {**(bcur_params or {}), "kT": kt}

    generate = RandomizedStructure(struct_number=struct_number, tag=tag).make()
    presample = Sampling(
        selection_method=selection_method1,
        num_of_selection=num_of_selection1,
        bcur_params=bcur_params,
        dir=generate.output,
        random_seed=random_seed,
    )
    rss = do_rss(
        mlip_type=mlip_type,
        iteration_index=f"{current_iter}th",
        mlip_path=input["mlip_path"],
        structure=presample.output,
        scalar_pressure_method=scalar_pressure_method,
        scalar_exp_pressure=scalar_exp_pressure,
        scalar_pressure_exponential_width=scalar_pressure_exponential_width,
        scalar_pressure_low=scalar_pressure_low,
        scalar_pressure_high=scalar_pressure_high,
        max_steps=max_steps,
        force_tol=force_tol,
        stress_tol=stress_tol,
        Hookean_repul=Hookean_repul,
        write_traj=write_traj,
        num_processes_rss=num_processes_rss,
        device=device,
    )
    sample = Sampling(
        selection_method=selection_method2,
        num_of_selection=num_of_selection2,
        bcur_params=bcur_params,
        traj_info=rss.output,
        random_seed=random_seed,
        isol_es=input["isol_es"],
    )
    static = VASP_static(
        structures=sample.output,
        e0_spin=e0_spin,
        isolated_atom=isolated_atom,
        dimer=dimer,
        dimer_range=dimer_range,
        dimer_num=dimer_num,
        custom_set=custom_set,
    )
    collect = VASP_collect_data(
        vasp_ref_file=vasp_ref_file,
        rss_group=rss_group,
        vasp_dirs=static.output,
    )
    preprocess = Data_preprocessing(
        test_ratio=test_ratio,
        regularization=regularization,
        distillation=distillation,
        f_max=f_max,
        vasp_ref_dir=collect.output["vasp_ref_dir"],
        pre_database_dir=input["pre_database_dir"],
    )
    fit = MLIPFitMaker(
        mlip_type=mlip_type,
        mlip_hyper=mlip_hyper,
        ref_energy_name=ref_energy_name,
        ref_force_name=ref_force_name,
        ref_virial_name=ref_virial_name,
    ).make(
        database_dir=preprocess.output,
        isol_es=input["isol_es"],
        num_processes_fit=num_processes_fit,
        preprocessing_data=False,
        **fit_kwargs,
    )

    # Forward every setting to the next iteration. Passing only ``input``
    # would make all later iterations fall back to the default arguments
    # (tag, stop_criterion, max_iteration_number, ...).
    next_iteration = do_RSS_iterations(
        input={
            "test_error": fit.output["test_error"],
            "pre_database_dir": preprocess.output,
            "mlip_path": fit.output["mlip_path"],
            "isol_es": input["isol_es"],
            "current_iter": current_iter,
            "kt": kt,
        },
        struct_number=struct_number,
        tag=tag,
        selection_method1=selection_method1,
        selection_method2=selection_method2,
        num_of_selection1=num_of_selection1,
        num_of_selection2=num_of_selection2,
        bcur_params=bcur_params,
        random_seed=random_seed,
        e0_spin=e0_spin,
        isolated_atom=isolated_atom,
        dimer=dimer,
        dimer_range=dimer_range,
        dimer_num=dimer_num,
        custom_set=custom_set,
        vasp_ref_file=vasp_ref_file,
        rss_group=rss_group,
        test_ratio=test_ratio,
        regularization=regularization,
        distillation=distillation,
        f_max=f_max,
        mlip_type=mlip_type,
        mlip_hyper=mlip_hyper,
        ref_energy_name=ref_energy_name,
        ref_force_name=ref_force_name,
        ref_virial_name=ref_virial_name,
        num_processes_fit=num_processes_fit,
        scalar_pressure_method=scalar_pressure_method,
        scalar_exp_pressure=scalar_exp_pressure,
        scalar_pressure_exponential_width=scalar_pressure_exponential_width,
        scalar_pressure_low=scalar_pressure_low,
        scalar_pressure_high=scalar_pressure_high,
        max_steps=max_steps,
        force_tol=force_tol,
        stress_tol=stress_tol,
        Hookean_repul=Hookean_repul,
        write_traj=write_traj,
        num_processes_rss=num_processes_rss,
        device=device,
        stop_criterion=stop_criterion,
        max_iteration_number=max_iteration_number,
        **fit_kwargs,
    )

    job_list = [
        generate,
        presample,
        rss,
        sample,
        static,
        collect,
        preprocess,
        fit,
        next_iteration,
    ]

    return Response(detour=job_list, output=next_iteration.output)
Loading