@@ -23,12 +23,18 @@ using executorch::extension::training::optimizer::SGDOptions;
23
23
using executorch::runtime::Error;
24
24
using executorch::runtime::Result;
25
25
DEFINE_string (model_path, " xor.pte" , " Model serialized in flatbuffer format." );
26
+ DEFINE_string (ptd_path, " " , " Model weights serialized in flatbuffer format." );
26
27
27
28
int main (int argc, char ** argv) {
28
29
gflags::ParseCommandLineFlags (&argc, &argv, true );
29
- if (argc != 1 ) {
30
+ if (argc == 0 ) {
31
+ ET_LOG (Error, " Please provide a model path." );
32
+ return 1 ;
33
+ } else if (argc > 2 ) {
30
34
std::string msg = " Extra commandline args: " ;
31
- for (int i = 1 /* skip argv[0] (program name) */ ; i < argc; i++) {
35
+ for (int i = 2 /* skip argv[0] (pte path) and argv[1] (ptd path) */ ;
36
+ i < argc;
37
+ i++) {
32
38
msg += argv[i];
33
39
}
34
40
ET_LOG (Error, " %s" , msg.c_str ());
@@ -46,7 +52,21 @@ int main(int argc, char** argv) {
46
52
auto loader = std::make_unique<executorch::extension::FileDataLoader>(
47
53
std::move (loader_res.get ()));
48
54
49
- auto mod = executorch::extension::training::TrainingModule (std::move (loader));
55
+ std::unique_ptr<executorch::extension::FileDataLoader> ptd_loader = nullptr ;
56
+ if (!FLAGS_ptd_path.empty ()) {
57
+ executorch::runtime::Result<executorch::extension::FileDataLoader>
58
+ ptd_loader_res =
59
+ executorch::extension::FileDataLoader::from (FLAGS_ptd_path.c_str ());
60
+ if (ptd_loader_res.error () != Error::Ok) {
61
+ ET_LOG (Error, " Failed to open ptd file: %s" , FLAGS_ptd_path.c_str ());
62
+ return 1 ;
63
+ }
64
+ ptd_loader = std::make_unique<executorch::extension::FileDataLoader>(
65
+ std::move (ptd_loader_res.get ()));
66
+ }
67
+
68
+ auto mod = executorch::extension::training::TrainingModule (
69
+ std::move (loader), nullptr , nullptr , nullptr , std::move (ptd_loader));
50
70
51
71
// Create full data set of input and labels.
52
72
std::vector<std::pair<
@@ -70,7 +90,10 @@ int main(int argc, char** argv) {
70
90
// Get the params and names
71
91
auto param_res = mod.named_parameters (" forward" );
72
92
if (param_res.error () != Error::Ok) {
73
- ET_LOG (Error, " Failed to get named parameters" );
93
+ ET_LOG (
94
+ Error,
95
+ " Failed to get named parameters, error: %d" ,
96
+ static_cast <int >(param_res.error ()));
74
97
return 1 ;
75
98
}
76
99
@@ -112,5 +135,6 @@ int main(int argc, char** argv) {
112
135
std::string (param.first .data ()), param.second });
113
136
}
114
137
115
- executorch::extension::flat_tensor::save_ptd (" xor.ptd" , param_map, 16 );
138
+ executorch::extension::flat_tensor::save_ptd (
139
+ " trained_xor.ptd" , param_map, 16 );
116
140
}
0 commit comments