Skip to content

Commit 3e951fc

Browse files
committed
performance_model update
1 parent f50f0a3 commit 3e951fc

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

41 files changed

+2095
-56
lines changed

bench/model_trainer.cxx

Lines changed: 40 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -208,7 +208,7 @@ void frize(std::set<int> & ps, int p){
208208
}
209209
}
210210

211-
void train_all(double time, World & dw){
211+
void train_all(double time, World & dw, bool write_coeff, bool dump_data, std::string coeff_file, std::string data_dir){
212212
std::set<int> ps;
213213
frize(ps, dw.np);
214214

@@ -226,6 +226,12 @@ void train_all(double time, World & dw){
226226
train_world(dtime, w);
227227
CTF_int::update_all_models(w.cdt.cm);
228228
}
229+
if(write_coeff)
230+
CTF_int::write_all_models(coeff_file);
231+
if(dump_data){
232+
CTF_int::dump_all_models(data_dir);
233+
}
234+
229235
}
230236

231237
char* getCmdOption(char ** begin,
@@ -254,16 +260,42 @@ int main(int argc, char ** argv){
254260
if (time < 0) time = 5.0;
255261
} else time = 5.0;
256262

257-
if(std::find(input_str, input_str+in_num,"-load")){
258-
CTF_int::load_all_models("./src/shared/model_coeff_record");
263+
// Get the environment variable FILE_PATH
264+
char * file_path = getenv("FILE_PATH");
265+
std::string coeff_file;
266+
267+
if(!file_path){
268+
// If the enviroment variable is not defined, use the default path
269+
coeff_file = std::string("../src/shared/model_coeff_record");
270+
}else{
271+
// Else, use the file path specified by the environment variable
272+
coeff_file = std::string(file_path);
273+
}
274+
275+
// If the user specifies -load, read the model coefficients from the file specified by the FILE_PATH environment variable
276+
if(std::find(input_str, input_str+in_num, std::string("-load")) != input_str + in_num){
277+
CTF_int::load_all_models(coeff_file);
278+
}
279+
280+
// Boolean expression that are used to pass command line argument to function train_all
281+
bool write_coeff = false;
282+
bool dump_data = false;
283+
284+
if(std::find(input_str, input_str+in_num, std::string("-write")) != input_str + in_num){
285+
write_coeff = true;
259286
}
260287

261-
if(std::find(input_str, input_str+in_num,"-write")){
262-
CTF_int::write_all_models("./src/shared/model_coeff_record");
288+
char * data_dir = getenv("MODEL_DATA_DIR");
289+
std::string data_dir_str;
290+
if(!data_dir){
291+
data_dir_str = std::string("../src/shared/data");
263292
}
293+
else{
294+
data_dir_str = std::string(data_dir);
295+
}
264296

265-
if(std::find(input_str, input_str+in_num,"-dump")){
266-
CTF_int::dump_all_models("./src/shared/data");
297+
if(std::find(input_str, input_str+in_num, std::string("-dump")) != input_str + in_num){
298+
dump_data = true;
267299
}
268300

269301
{
@@ -272,7 +304,7 @@ int main(int argc, char ** argv){
272304
if (rank == 0){
273305
printf("Executing a wide set of contractions to train model with time budget of %lf sec\n", time);
274306
}
275-
train_all(time, dw);
307+
train_all(time, dw, write_coeff, dump_data, coeff_file, data_dir_str);
276308
}
277309

278310

0 commit comments

Comments
 (0)