Skip to content

Update Demo Scripts To Use .ptd (retry) #8886

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Mar 3, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -248,14 +248,15 @@ cmake_dependent_option(
"NOT EXECUTORCH_BUILD_ARM_BAREMETAL" OFF
)

if(EXECUTORCH_BUILD_EXTENSION_FLAT_TENSOR)
if(EXECUTORCH_BUILD_EXTENSION_TRAINING)
set(EXECUTORCH_BUILD_EXTENSION_DATA_LOADER ON)
set(EXECUTORCH_BUILD_EXTENSION_FLAT_TENSOR ON)
set(EXECUTORCH_BUILD_EXTENSION_MODULE ON)
set(EXECUTORCH_BUILD_EXTENSION_TENSOR ON)
endif()

if(EXECUTORCH_BUILD_EXTENSION_TRAINING)
set(EXECUTORCH_BUILD_EXTENSION_TENSOR ON)
if(EXECUTORCH_BUILD_EXTENSION_FLAT_TENSOR)
set(EXECUTORCH_BUILD_EXTENSION_DATA_LOADER ON)
set(EXECUTORCH_BUILD_EXTENSION_MODULE ON)
endif()

if(EXECUTORCH_BUILD_EXTENSION_MODULE)
Expand Down
2 changes: 1 addition & 1 deletion extension/training/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ target_include_directories(
target_include_directories(extension_training PUBLIC ${EXECUTORCH_ROOT}/..)
target_compile_options(extension_training PUBLIC ${_common_compile_options})
target_link_libraries(extension_training executorch_core
extension_data_loader extension_module extension_tensor)
extension_data_loader extension_module extension_tensor extension_flat_tensor)


list(TRANSFORM _train_xor__srcs PREPEND "${EXECUTORCH_ROOT}/")
Expand Down
28 changes: 20 additions & 8 deletions extension/training/examples/XOR/export_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,14 @@
import os

import torch
from executorch.exir import to_edge
from executorch.exir import ExecutorchBackendConfig, to_edge

from executorch.extension.training.examples.XOR.model import Net, TrainingNet
from torch.export import export
from torch.export.experimental import _export_forward_backward


def _export_model():
def _export_model(external_mutable_weights: bool = False):
net = TrainingNet(Net())
x = torch.randn(1, 2)

Expand All @@ -30,7 +30,11 @@ def _export_model():
# Lower the graph to edge dialect.
ep = to_edge(ep)
# Lower the graph to executorch.
ep = ep.to_executorch()
ep = ep.to_executorch(
config=ExecutorchBackendConfig(
external_mutable_weights=external_mutable_weights
)
)
return ep


Expand All @@ -44,19 +48,27 @@ def main() -> None:
"--outdir",
type=str,
required=True,
help="Path to the directory to write xor.pte files to",
help="Path to the directory to write xor.pte and xor.ptd files to",
)
parser.add_argument(
"--external",
action="store_true",
help="Export the model with external weights",
)
args = parser.parse_args()

ep = _export_model()
ep = _export_model(args.external)

# Write out the .pte file.
os.makedirs(args.outdir, exist_ok=True)
outfile = os.path.join(args.outdir, "xor.pte")
with open(outfile, "wb") as fp:
fp.write(
ep.buffer,
)
ep.write_to_file(fp)

if args.external:
        # current infra doesn't easily allow renaming this file, so just hackily do it here.
ep._tensor_data["xor"] = ep._tensor_data.pop("_default_external_constant")
ep.write_tensor_data_to_file(args.outdir)


if __name__ == "__main__":
Expand Down
34 changes: 29 additions & 5 deletions extension/training/examples/XOR/train.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,18 @@ using executorch::extension::training::optimizer::SGDOptions;
using executorch::runtime::Error;
using executorch::runtime::Result;
DEFINE_string(model_path, "xor.pte", "Model serialized in flatbuffer format.");
DEFINE_string(ptd_path, "", "Model weights serialized in flatbuffer format.");

int main(int argc, char** argv) {
gflags::ParseCommandLineFlags(&argc, &argv, true);
if (argc != 1) {
if (argc == 0) {
ET_LOG(Error, "Please provide a model path.");
return 1;
} else if (argc > 2) {
std::string msg = "Extra commandline args: ";
for (int i = 1 /* skip argv[0] (program name) */; i < argc; i++) {
for (int i = 2 /* skip argv[0] (pte path) and argv[1] (ptd path) */;
i < argc;
i++) {
msg += argv[i];
}
ET_LOG(Error, "%s", msg.c_str());
Expand All @@ -46,7 +52,21 @@ int main(int argc, char** argv) {
auto loader = std::make_unique<executorch::extension::FileDataLoader>(
std::move(loader_res.get()));

auto mod = executorch::extension::training::TrainingModule(std::move(loader));
std::unique_ptr<executorch::extension::FileDataLoader> ptd_loader = nullptr;
if (!FLAGS_ptd_path.empty()) {
executorch::runtime::Result<executorch::extension::FileDataLoader>
ptd_loader_res =
executorch::extension::FileDataLoader::from(FLAGS_ptd_path.c_str());
if (ptd_loader_res.error() != Error::Ok) {
ET_LOG(Error, "Failed to open ptd file: %s", FLAGS_ptd_path.c_str());
return 1;
}
ptd_loader = std::make_unique<executorch::extension::FileDataLoader>(
std::move(ptd_loader_res.get()));
}

auto mod = executorch::extension::training::TrainingModule(
std::move(loader), nullptr, nullptr, nullptr, std::move(ptd_loader));

// Create full data set of input and labels.
std::vector<std::pair<
Expand All @@ -70,7 +90,10 @@ int main(int argc, char** argv) {
// Get the params and names
auto param_res = mod.named_parameters("forward");
if (param_res.error() != Error::Ok) {
ET_LOG(Error, "Failed to get named parameters");
ET_LOG(
Error,
"Failed to get named parameters, error: %d",
static_cast<int>(param_res.error()));
return 1;
}

Expand Down Expand Up @@ -112,5 +135,6 @@ int main(int argc, char** argv) {
std::string(param.first.data()), param.second});
}

executorch::extension::flat_tensor::save_ptd("xor.ptd", param_map, 16);
executorch::extension::flat_tensor::save_ptd(
"trained_xor.ptd", param_map, 16);
}
Loading