Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Experimental API for setting model name #10518

Merged
merged 8 commits into from
Feb 25, 2022
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
Add experimental API for editing model name
  • Loading branch information
nums11 committed Feb 10, 2022
commit fcc7e23c5374324d5c2584f0e6438478aaa5d089
3 changes: 3 additions & 0 deletions winml/api/Microsoft.AI.MachineLearning.Experimental.idl
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,9 @@ namespace ROOT_NS.AI.MachineLearning.Experimental {

//! The JoinModel fuses two models by linking outputs from the first model, to inupts of the second.
ROOT_NS.AI.MachineLearning.LearningModel JoinModel(ROOT_NS.AI.MachineLearning.LearningModel other, LearningModelJoinOptions options);

//! The EditModelName function changes the model name to the specified string
nums11 marked this conversation as resolved.
Show resolved Hide resolved
void EditModelName(String model_name);
}

} // namespace Microsoft.AI.MachineLearning.Experimental
5 changes: 5 additions & 0 deletions winml/lib/Api.Experimental/LearningModelExperimental.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,4 +29,9 @@ void LearningModelExperimental::Save(hstring const& file_name) {
modelp->SaveToFile(file_name);
}

void LearningModelExperimental::EditModelName(hstring const& model_name) {
auto modelp = model_.as<winmlp::LearningModel>();
modelp->EditModelName(model_name);
}

}
2 changes: 2 additions & 0 deletions winml/lib/Api.Experimental/LearningModelExperimental.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ struct LearningModelExperimental : LearningModelExperimentalT<LearningModelExper

void Save(hstring const& file_name);

void EditModelName(hstring const& model_name);

private:
Microsoft::AI::MachineLearning::LearningModel model_;
};
Expand Down
5 changes: 5 additions & 0 deletions winml/lib/Api.Ort/OnnxruntimeModel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,11 @@ STDMETHODIMP ModelInfo::GetName(const char** out, size_t* len) {
return S_OK;
}

STDMETHODIMP ModelInfo::EditModelName(std::string name) {
name_ = name;
return S_OK;
}

STDMETHODIMP ModelInfo::GetDomain(const char** out, size_t* len) {
*out = domain_.c_str();
*len = domain_.size();
Expand Down
2 changes: 2 additions & 0 deletions winml/lib/Api.Ort/OnnxruntimeModel.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ class ModelInfo : public Microsoft::WRL::RuntimeClass<
(const char** out, size_t* len);
STDMETHOD(GetName)
(const char** out, size_t* len);
STDMETHOD(EditModelName)
(std::string name);
STDMETHOD(GetDomain)
(const char** out, size_t* len);
STDMETHOD(GetDescription)
Expand Down
6 changes: 6 additions & 0 deletions winml/lib/Api/LearningModel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,12 @@ LearningModel::OutputFeatures() try {
}
WINML_CATCH_ALL

void LearningModel::EditModelName(const hstring& name) try {
auto name_str = _winml::Strings::UTF8FromHString(name);
WINML_THROW_IF_FAILED(model_info_->EditModelName(name_str));
}
WINML_CATCH_ALL

void LearningModel::Close() try {
// close the model
model_ = nullptr;
Expand Down
2 changes: 2 additions & 0 deletions winml/lib/Api/LearningModel.h
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,8 @@ struct LearningModel : LearningModelT<LearningModel> {
wfc::IVectorView<winml::ILearningModelFeatureDescriptor>
OutputFeatures();

void EditModelName(const hstring& name);

/* IClosable methods. */
void Close();

Expand Down
3 changes: 3 additions & 0 deletions winml/lib/Common/inc/iengine.h
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,9 @@ IModelInfo : IUnknown {
STDMETHOD(GetName)
(const char** out, size_t* len) PURE;

STDMETHOD(EditModelName)
(std::string name) PURE;

STDMETHOD(GetDomain)
(const char** out, size_t* len) PURE;

Expand Down
15 changes: 15 additions & 0 deletions winml/test/api/LearningModelSessionAPITest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1089,6 +1089,20 @@ static void SetIntraOpThreadSpinning() {
WINML_EXPECT_TRUE(allowSpinning);
}

static void EditModelName() {
LearningModel model = nullptr;
WINML_EXPECT_NO_THROW(APITest::LoadModel(L"model.onnx", model));
auto model_name = model.Name();
auto squeezenet_old = to_hstring("squeezenet_old");
WINML_EXPECT_EQUAL(model_name, squeezenet_old);

auto experimental_model = winml_experimental::LearningModelExperimental(model);
auto new_name = to_hstring("new name");
experimental_model.EditModelName(new_name);
model_name = model.Name();
WINML_EXPECT_EQUAL(model_name, new_name);
}


const LearningModelSessionAPITestsApi& getapi() {
static LearningModelSessionAPITestsApi api =
Expand Down Expand Up @@ -1123,6 +1137,7 @@ const LearningModelSessionAPITestsApi& getapi() {
ModelBuilding_STFT,
ModelBuilding_MelSpectrogramOnThreeToneSignal,
ModelBuilding_MelWeightMatrix,
EditModelName
};

if (SkipGpuTests()) {
Expand Down
2 changes: 2 additions & 0 deletions winml/test/api/LearningModelSessionAPITest.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ struct LearningModelSessionAPITestsApi {
VoidTest ModelBuilding_STFT;
VoidTest ModelBuilding_MelSpectrogramOnThreeToneSignal;
VoidTest ModelBuilding_MelWeightMatrix;
VoidTest EditModelName;
};
const LearningModelSessionAPITestsApi& getapi();

Expand Down Expand Up @@ -69,4 +70,5 @@ WINML_TEST(LearningModelSessionAPITests, ModelBuilding_BlackmanWindow)
WINML_TEST(LearningModelSessionAPITests, ModelBuilding_STFT)
WINML_TEST(LearningModelSessionAPITests, ModelBuilding_MelSpectrogramOnThreeToneSignal)
WINML_TEST(LearningModelSessionAPITests, ModelBuilding_MelWeightMatrix)
WINML_TEST(LearningModelSessionAPITests, EditModelName)
WINML_TEST_CLASS_END()