@@ -208,7 +208,7 @@ void frize(std::set<int> & ps, int p){
208
208
}
209
209
}
210
210
211
- void train_all (double time, World & dw){
211
+ void train_all (double time, World & dw, bool write_coeff, bool dump_data, std::string coeff_file, std::string data_dir ){
212
212
std::set<int > ps;
213
213
frize (ps, dw.np );
214
214
@@ -226,6 +226,12 @@ void train_all(double time, World & dw){
226
226
train_world (dtime, w);
227
227
CTF_int::update_all_models (w.cdt .cm );
228
228
}
229
+ if (write_coeff)
230
+ CTF_int::write_all_models (coeff_file);
231
+ if (dump_data){
232
+ CTF_int::dump_all_models (data_dir);
233
+ }
234
+
229
235
}
230
236
231
237
char * getCmdOption (char ** begin,
@@ -254,16 +260,42 @@ int main(int argc, char ** argv){
254
260
if (time < 0 ) time = 5.0 ;
255
261
} else time = 5.0 ;
256
262
257
- if (std::find (input_str, input_str+in_num," -load" )){
258
- CTF_int::load_all_models (" ./src/shared/model_coeff_record" );
263
+ // Get the environment variable FILE_PATH
264
+ char * file_path = getenv (" FILE_PATH" );
265
+ std::string coeff_file;
266
+
267
+ if (!file_path){
268
+ // If the enviroment variable is not defined, use the default path
269
+ coeff_file = std::string (" ../src/shared/model_coeff_record" );
270
+ }else {
271
+ // Else, use the file path specified by the environment variable
272
+ coeff_file = std::string (file_path);
273
+ }
274
+
275
+ // If the user specifies -load, read the model coefficients from the file specified by the FILE_PATH environment variable
276
+ if (std::find (input_str, input_str+in_num, std::string (" -load" )) != input_str + in_num){
277
+ CTF_int::load_all_models (coeff_file);
278
+ }
279
+
280
+ // Boolean expression that are used to pass command line argument to function train_all
281
+ bool write_coeff = false ;
282
+ bool dump_data = false ;
283
+
284
+ if (std::find (input_str, input_str+in_num, std::string (" -write" )) != input_str + in_num){
285
+ write_coeff = true ;
259
286
}
260
287
261
- if (std::find (input_str, input_str+in_num," -write" )){
262
- CTF_int::write_all_models (" ./src/shared/model_coeff_record" );
288
+ char * data_dir = getenv (" MODEL_DATA_DIR" );
289
+ std::string data_dir_str;
290
+ if (!data_dir){
291
+ data_dir_str = std::string (" ../src/shared/data" );
263
292
}
293
+ else {
294
+ data_dir_str = std::string (data_dir);
295
+ }
264
296
265
- if (std::find (input_str, input_str+in_num," -dump" )){
266
- CTF_int::dump_all_models ( " ./src/shared/data " ) ;
297
+ if (std::find (input_str, input_str+in_num, std::string ( " -dump" )) != input_str + in_num ){
298
+ dump_data = true ;
267
299
}
268
300
269
301
{
@@ -272,7 +304,7 @@ int main(int argc, char ** argv){
272
304
if (rank == 0 ){
273
305
printf (" Executing a wide set of contractions to train model with time budget of %lf sec\n " , time);
274
306
}
275
- train_all (time, dw);
307
+ train_all (time, dw, write_coeff, dump_data, coeff_file, data_dir_str );
276
308
}
277
309
278
310
0 commit comments