Skip to content

Commit 3161c19

Browse files
authored
Corrected finetuning of NN via residual
1 parent efda7f5 commit 3161c19

File tree

1 file changed

+35
-23
lines changed

1 file changed

+35
-23
lines changed

rom_application/RomManager_cantilever_NN_residual/run_rom_manager.py

Lines changed: 35 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ def UpdateProjectParameters(parameters, mu=None):
4141
"""
4242
Customize ProjectParameters here for imposing different conditions to the simulations as needed
4343
"""
44-
steps = 300
44+
steps = 400
4545
parameters["processes"]["loads_process_list"][0]["Parameters"]["modulus"].SetString(str(mu[0]/steps)+"*t")
4646
parameters["processes"]["loads_process_list"][1]["Parameters"]["modulus"].SetString(str(mu[1]/steps)+"*t")
4747
parameters["problem_data"]["end_time"].SetDouble(steps)
@@ -76,7 +76,7 @@ def GetRomManagerParameters():
7676
"rom_basis_output_format": "numpy",
7777
"rom_basis_output_name": "RomParameters",
7878
"snapshots_control_type": "time", // "step", "time"
79-
"snapshots_interval": 300,
79+
"snapshots_interval": 400,
8080
"snapshots_control_is_periodic": false,
8181
"print_singular_values": true,
8282
"galerkin_rom_bns_settings": {
@@ -94,7 +94,7 @@ def GetRomManagerParameters():
9494
"modes":[14,60],
9595
"layers_size":[200,200],
9696
"batch_size":16,
97-
"epochs":800,
97+
"epochs":6,
9898
"NN_gradient_regularisation_weight": 0.0,
9999
"lr_strategy":{
100100
"scheduler": "sgdr",
@@ -134,33 +134,45 @@ def get_multiple_params(num_of_samples, seed):
134134

135135
mu_train = np.array(get_multiple_params(500, 824)).tolist()
136136
mu_validation = np.array(get_multiple_params(100, 235)).tolist()
137-
mu_test = np.array(get_multiple_params(100, 539)).tolist()
137+
mu_test = np.array(get_multiple_params(10, 539)).tolist()
138+
# mu_test = np.array(get_multiple_params(100, 539)).tolist()
138139

139140
general_rom_manager_parameters = GetRomManagerParameters()
140141
project_parameters_name = "datasets_rubber_hyperelastic_cantilever_big_range/ProjectParameters_FOM.json"
141142

142143
rom_manager = RomManager(project_parameters_name,general_rom_manager_parameters,CustomizeSimulation,UpdateProjectParameters, UpdateMaterialParametersFile)
143144

144-
# First, train the snapshot-based model
145+
# First, train the ANN-PROM model via snapshot-based loss
145146
rom_manager.Fit(mu_train=mu_train, mu_validation=mu_validation)
146-
147-
# rom_manager.Test(mu_test=mu_test, mu_train=mu_train, start_from_closest_mu=False, filter_nan=True)
148-
# rom_manager.PrintErrors()
149-
# rom_manager.TestNeuralNetworkReconstruction(mu_train, mu_validation, mu_test)
150-
# snapshots_matrix = rom_manager.GenerateOrderedFOMSnapshotsMatrix(mu_train)
151-
# print(snapshots_matrix.shape)
152-
# np.save('mu_train.npy', mu_train)
153-
# np.save('snapshots_train.npy', snapshots_matrix.T)
154-
155-
156-
# mu_run_raw = np.array(get_multiple_params(1, 97))
157-
# mu_run = mu_run_raw
158-
# mu_run = mu_run.tolist()
159-
# print(mu_run)
160-
# rom_manager.Test(mu_test=mu_run, mu_train=mu_train, start_from_closest_mu=False, filter_nan=True)
161-
# rom_manager.PrintErrors()
162-
# rom_manager.RunFOM(mu_run)
163-
# rom_manager.RunROM(mu_run, mu_train)
164147

148+
# Then, finetune the ANN-PROM model via residual-based loss, on top of the one we just trained.
149+
in_database, sloss_model_name = rom_manager.data_base.check_if_in_database("Neural_Network", mu_train)
150+
151+
assert in_database
152+
print(sloss_model_name)
153+
154+
# Update the ROM Manager parameters to include the recently-trained ANN's directory
155+
pretrained_model_path = str(rom_manager.data_base.database_root_directory)+'/saved_nn_models/' + sloss_model_name
156+
rom_manager.general_rom_manager_parameters["ROM"]["ann_enhanced_settings"]["online"]["custom_model_path"].SetString(pretrained_model_path)
157+
rom_manager.general_rom_manager_parameters["ROM"]["ann_enhanced_settings"]["lr_strategy"]["base_lr"].SetDouble(0.0001)
158+
print(rom_manager.general_rom_manager_parameters)
159+
160+
# Finetune the NN using the residual loss
161+
rom_manager.FinetuneANNOnResidual(mu_train, mu_validation)
162+
163+
# Test the accuracy of the model trained on the snapshot-based loss
164+
rom_manager.general_rom_manager_parameters["ROM"]["ann_enhanced_settings"]["lr_strategy"]["base_lr"].SetDouble(0.001)
165+
rom_manager.Test(mu_test=mu_test, mu_train=mu_train, filter_nan=True)
166+
rom_manager.PrintErrors()
167+
168+
# Test the accuracy of the model finetuned on the residual-based loss
169+
in_database, rloss_model_name = rom_manager.data_base.check_if_in_database("Neural_Network_Residual", mu_train)
170+
assert in_database
171+
print(rloss_model_name)
172+
finetuned_model_path = str(rom_manager.data_base.database_root_directory)+'/saved_nn_models_residual/' + rloss_model_name
173+
rom_manager.general_rom_manager_parameters["ROM"]["ann_enhanced_settings"]["online"]["custom_model_path"].SetString(finetuned_model_path)
174+
rom_manager.general_rom_manager_parameters["ROM"]["ann_enhanced_settings"]["lr_strategy"]["base_lr"].SetDouble(0.0001)
175+
rom_manager.Test(mu_test=mu_test, mu_train=mu_train, filter_nan=True)
176+
rom_manager.PrintErrors()
165177

166178

0 commit comments

Comments
 (0)