Skip to content

Commit f50f0a3

Browse files
committed
Further improvement on model facility functions
1 parent b9eb06f commit f50f0a3

File tree

5 files changed

+136
-71
lines changed

5 files changed

+136
-71
lines changed

bench/model_trainer.cxx

Lines changed: 30 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
/** Copyright (c) 2011, Edgar Solomonik, all rights reserved.
22
* \addtogroup benchmarks
3-
* @{
3+
* @{
44
* \addtogroup model_trainer
5-
* @{
5+
* @{
66
* \brief Executes a set of different contractions on different processor counts to train model parameters
77
*/
88

@@ -47,9 +47,9 @@ void train_dns_vec_mat(int64_t n, int64_t m, World & dw){
4747
Function<> f1([](double a){ return a*a; });
4848

4949
A2["ij"] = f1(A["ij"]);
50-
50+
5151
c["i"] += f1(A["ij"]);
52-
52+
5353
Function<> f2([](double a, double b){ return a*a+b*b; });
5454

5555
A1["ij"] -= f2(A["kj"], F["ki"]);
@@ -60,7 +60,7 @@ void train_dns_vec_mat(int64_t n, int64_t m, World & dw){
6060
t1(A["ij"]);
6161

6262
Transform<> t2([](double a, double & b){ b-=b/a; });
63-
63+
6464
t2(b["i"],b["i"]);
6565
t2(A["ij"],A2["ij"]);
6666

@@ -81,7 +81,7 @@ void train_sps_vec_mat(int64_t n, int64_t m, World & dw, bool sp_A, bool sp_B, b
8181
Matrix<> A2(m, n, dw);
8282
Matrix<> G(n, n, NS, dw);
8383
Matrix<> F(m, m, NS, dw);
84-
84+
8585
srand48(dw.rank);
8686
b.fill_random(-.5, .5);
8787
c.fill_random(-.5, .5);
@@ -102,36 +102,36 @@ void train_sps_vec_mat(int64_t n, int64_t m, World & dw, bool sp_A, bool sp_B, b
102102
B.sparsify([=](double a){ return fabs(a)<=.5*sp; });
103103
c.sparsify([=](double a){ return fabs(a)<=.5*sp; });
104104
}
105-
105+
106106
B["ij"] += A["ik"]*G["kj"];
107107
if (!sp_C) B["ij"] += A["ij"]*A1["ij"];
108108
B["ij"] += F["ik"]*A["kj"];
109109
c["i"] += A["ij"]*b["j"];
110110
b["j"] += .2*A["ij"]*c["i"];
111111
if (!sp_C) b["i"] += b["i"]*b["i"];
112-
112+
113113
Function<> f1([](double a){ return a*a; });
114-
114+
115115
A2["ij"] = f1(A["ij"]);
116-
116+
117117
c["i"] += f1(A["ij"]);
118-
118+
119119
Function<> f2([](double a, double b){ return a*a+b*b; });
120-
120+
121121
A2["ji"] -= f2(A1["ki"], F["kj"]);
122-
122+
123123
Transform<> t1([](double & a){ a*=a; });
124-
124+
125125
t1(b["i"]);
126126
t1(A["ij"]);
127-
127+
128128
Transform<> t2([](double a, double & b){ b-=b/a; });
129-
129+
130130
t2(b["i"],b["i"]);
131131
t2(A["ij"],A2["ij"]);
132-
132+
133133
/*Transform<> t3([](double a, double b, double & c){ c=c*c-b*a; });
134-
134+
135135
t3(c["i"],b["i"],b["i"]);
136136
t3(A["ij"],G["ij"],F["ij"]);*/
137137
}
@@ -254,6 +254,17 @@ int main(int argc, char ** argv){
254254
if (time < 0) time = 5.0;
255255
} else time = 5.0;
256256

257+
if(std::find(input_str, input_str+in_num,"-load")){
258+
CTF_int::load_all_models("./src/shared/model_coeff_record");
259+
}
260+
261+
if(std::find(input_str, input_str+in_num,"-write")){
262+
CTF_int::write_all_models("./src/shared/model_coeff_record");
263+
}
264+
265+
if(std::find(input_str, input_str+in_num,"-dump")){
266+
CTF_int::dump_all_models("./src/shared/data");
267+
}
257268

258269
{
259270
World dw(MPI_COMM_WORLD, argc, argv);
@@ -270,7 +281,6 @@ int main(int argc, char ** argv){
270281
}
271282

272283
/**
273-
* @}
284+
* @}
274285
* @}
275286
*/
276-

model_trainer

Whitespace-only changes.

src/shared/model.cxx

Lines changed: 85 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,30 @@ namespace CTF_int {
3030
#endif
3131
}
3232

33+
void load_all_models(std::string file_name){
34+
#ifdef TUNE
35+
for (int i=0; i<get_all_models().size(); i++){
36+
get_all_models()[i]->load_coeff(file_name);
37+
}
38+
#endif
39+
}
40+
41+
void write_all_models(std::string file_name){
42+
#ifdef TUNE
43+
for (int i=0; i<get_all_models().size(); i++){
44+
get_all_models()[i]->write_coeff(file_name);
45+
}
46+
#endif
47+
}
48+
49+
void dump_all_models(std::string path){
50+
#ifdef TUNE
51+
for (int i=0; i<get_all_models().size(); i++){
52+
get_all_models()[i]->dump_data(path);
53+
}
54+
#endif
55+
}
56+
3357

3458
#define SPLINE_CHUNK_SZ = 8
3559

@@ -463,26 +487,26 @@ namespace CTF_int {
463487
}
464488

465489
template <int nparam>
466-
void LinModel<nparam>::write_coeff(){
490+
void LinModel<nparam>::write_coeff(std::string file_name){
467491

468492
// Generate the model name
469-
std::string model_name = std::string(name)+"_init[]";
493+
std::string model_name = std::string(name);
470494
// Generate the new line in the file
471-
std::string new_coeff_str = "double " + model_name+ " = {";
495+
std::string new_coeff_str = model_name+" ";
472496
char buffer[64];
473497
for(int i =0; i<nparam; i++){
474498
buffer[0] = '\0';
475499
std::sprintf(buffer,"%1.4E", coeff_guess[i]);
476500
std::string s(buffer);
477501
new_coeff_str += s;
478502
if (i != nparam - 1){
479-
new_coeff_str += ", ";
503+
new_coeff_str += " ";
480504
}
481505
}
482-
new_coeff_str += "};";
506+
483507
// Open the file that stores the model info
484508
std::vector<std::string> file_content;
485-
std::ifstream infile("model_init.cxx");
509+
std::ifstream infile(file_name);
486510
if(!infile){
487511
std::cout<<"Error opening file"<<std::endl;
488512
return;
@@ -496,7 +520,7 @@ namespace CTF_int {
496520
// Get the model name from the line
497521
std::string s;
498522
std::getline(f,s,' ');
499-
std::getline(f,s,' ');
523+
std::cout<<s<<", "<<model_name<<std::endl;
500524
if (s == model_name){
501525
line = new_coeff_str;
502526
found_line = true;
@@ -505,82 +529,93 @@ namespace CTF_int {
505529
file_content.push_back(line);
506530
}
507531

532+
// Append the string to the file if no match is found
533+
if(!found_line){
534+
new_coeff_str += "\n";
535+
file_content.push_back(new_coeff_str);
536+
}
508537
std::ofstream ofs;
509-
ofs.open("model_init.cxx", std::ofstream::out | std::ofstream::trunc);
538+
ofs.open(file_name, std::ofstream::out | std::ofstream::trunc);
510539
for(int i=0; i<file_content.size(); i++){
511540
ofs<<file_content[i];
512541
}
513542
ofs.close();
514-
515-
// Check if the file is successfully updated
516-
if(!found_line){
517-
std::cout<<"Error! No model declared in model_init.cxx and model_init.h. Please declare the model first!"<<std::endl;
518-
}
519543
}
520544

521545

546+
522547
template <int nparam>
523-
void LinModel<nparam>::load_coeff(){
548+
void LinModel<nparam>::load_coeff(std::string file_name){
524549
// Generate the model name
525-
std::string model_name = std::string(name)+"_init[]";
550+
std::string model_name = std::string(name);
526551

527552
// Open the file that stores the model info
528553
std::vector<std::string> file_content;
529-
std::ifstream infile("init_models.cxx");
554+
std::ifstream infile(file_name);
530555
if(!infile){
531556
std::cout<<"Error opening file"<<std::endl;
532557
return;
533558
}
534559

535-
// Scan the file to find the line and replace with the new model coeffs
536-
std::string line;
560+
// Flag boolean denotes whether the model is found in the file
537561
bool found_line = false;
538-
bool succeed = false;
562+
// Flag boolean denotes whether the number of coefficients in the file matches with what the model expects
563+
bool right_num_coeff = true;
564+
565+
// Scan the file to find the model coefficients
566+
std::string line;
539567
while(std::getline(infile,line)){
540568
std::istringstream f(line);
541569
// Get the model name from the line
542570
std::string s;
543571
std::getline(f,s,' ');
544-
std::getline(f,s,' ');
545572
if (s == model_name){
546573
found_line = true;
547-
// Get rid of the '='
548-
std::getline(f,s,' ');
549-
// Get the n coeffs
574+
575+
// Get the nparam coeffs
576+
double coeff_from_file [nparam];
550577
for(int i=0; i<nparam; i++){
551578
if(!std::getline(f,s,' ')){
552-
break;
579+
right_num_coeff = false;
580+
break;
553581
}
554-
char buffer[64];
555-
int index = 0;
556-
for(int i = 0; i<s.size(); i++){
557-
if(s[i] != '{' && s[i] != '}' && s[i] != ',' && s[i] != ';'){
558-
buffer[index] = s[i];
559-
index++;
560-
}
582+
583+
// Convert the string to char* and update the model coefficients
584+
char buf[s.length()+1];
585+
for(int i=0;i<s.length();i++){
586+
buf[i] = s[i];
561587
}
562-
buffer[index] = '\0';
563-
coeff_guess[i] = std::atof(buffer);
588+
buf[s.length()] = '\0';
589+
coeff_guess[i] = std::atof(buf);
590+
}
591+
// Check if there are more coefficients in the file
592+
if(right_num_coeff && std::getline(f,s,' ')){
593+
right_num_coeff = false;
564594
}
565-
succeed = true;
566595
break;
567596
}
568597
}
569598
// If the model is not found
570599
if(!found_line){
571-
std::cout<<"Error! Not model found in the file!"<<std::endl;
600+
std::cout<<"Error! No model found in the file!"<<std::endl;
572601
}
573-
// If there is not enough parameters
574-
if(!succeed){
575-
std::cout<<"Error! Not enough number of paramters in file!"<<std::endl;
602+
else if (!right_num_coeff){
603+
std::cout<<"Error! Number of coefficients in file does not match with the model"<<std::endl;
604+
// Initialize model coeff to be all 0s
605+
for(int i = 0; i < nparam;i++){
606+
coeff_guess[i] = 0.0;
607+
}
576608
}
577609
}
578610

611+
579612
template <int nparam>
580-
void LinModel<nparam>::dump_data(std::string file_name){
613+
void LinModel<nparam>::dump_data(std::string path){
581614
// Open the file
615+
std::string model_name = std::string(name);
582616
std::ofstream ofs;
583-
ofs.open("./data/"+file_name, std::ofstream::out | std::ofstream::trunc);
617+
ofs.open("path/"+model_name, std::ofstream::out | std::ofstream::trunc);
618+
584619
// Dump the model coeffs
585620
for(int i=0; i<nparam; i++){
586621
ofs<<coeff_guess[i]<<" ";
@@ -682,14 +717,20 @@ namespace CTF_int {
682717
}
683718

684719
template <int nparam>
685-
void CubicModel<nparam>::load_coeff(){
686-
lmdl.load_coeff();
720+
void CubicModel<nparam>::load_coeff(std::string file_name){
721+
lmdl.load_coeff(file_name);
687722
}
688723

689724
template <int nparam>
690-
void CubicModel<nparam>::write_coeff(){
691-
lmdl.write_coeff();
725+
void CubicModel<nparam>::write_coeff(std::string file_name){
726+
lmdl.write_coeff(file_name);
692727
}
728+
729+
template <int nparam>
730+
void CubicModel<nparam>::dump_data(std::string path){
731+
lmdl.dump_data(path);
732+
}
733+
693734
template class CubicModel<1>;
694735
template class CubicModel<2>;
695736
template class CubicModel<3>;

0 commit comments

Comments
 (0)