diff --git a/src/core/infer.h b/src/core/infer.h
index 035c3115e0..6655510606 100644
--- a/src/core/infer.h
+++ b/src/core/infer.h
@@ -132,7 +132,6 @@ class HTTPInferRequestProvider : public InferRequestProvider {
   std::vector> contiguous_buffers_;
 };
 
-
 // Provide inference request outputs
 class InferResponseProvider {
  public:
@@ -236,7 +235,6 @@ class HTTPInferResponseProvider : public InferResponseProvider {
   size_t total_raw_byte_size_;
 };
 
-
 // Interface for servables that handle generic inference requests.
 class InferenceServable {
  public:
diff --git a/src/core/request_status.cc b/src/core/request_status.cc
index d9839094f8..8f1fd08125 100644
--- a/src/core/request_status.cc
+++ b/src/core/request_status.cc
@@ -26,7 +26,6 @@
 
 #include "src/core/request_status.h"
 
-
 namespace nvidia { namespace inferenceserver {
 
 namespace {
diff --git a/src/core/server.cc b/src/core/server.cc
index 3c1c560bc1..c5d57348a9 100644
--- a/src/core/server.cc
+++ b/src/core/server.cc
@@ -65,6 +65,7 @@
 #include "src/servables/tensorflow/graphdef_bundle.h"
 #include "src/servables/tensorflow/graphdef_bundle.pb.h"
 #include "src/servables/tensorflow/savedmodel_bundle.h"
+#include "src/servables/tensorflow/savedmodel_bundle.pb.h"
 #include "src/servables/tensorrt/plan_bundle.h"
 #include "src/servables/tensorrt/plan_bundle.pb.h"
 #include "tensorflow/core/lib/core/status.h"
@@ -82,7 +83,6 @@
 #include "tensorflow_serving/core/availability_preserving_policy.h"
 #include "tensorflow_serving/core/servable_handle.h"
 #include "tensorflow_serving/model_servers/server_core.h"
-#include "tensorflow_serving/servables/tensorflow/saved_model_bundle_source_adapter.pb.h"
 #include "tensorflow_serving/util/net_http/server/public/httpserver.h"
 #include "tensorflow_serving/util/net_http/server/public/response_code_enum.h"
 #include "tensorflow_serving/util/net_http/server/public/server_request_interface.h"
@@ -535,6 +535,7 @@ InferenceServer::InferenceServer()
   grpc_port_ = 8001;
   metrics_port_ = 8002;
   http_thread_cnt_ = 8;
+  strict_model_config_ = false;
   strict_readiness_ = true;
   model_load_unload_enabled_ = true;
   profiling_enabled_ = false;
@@ -567,6 +568,7 @@ InferenceServer::Init(int argc, char** argv)
   // On error, the init process will stop.
   // The difference is if the server will be terminated.
   bool exit_on_error = true;
+  bool strict_model_config = strict_model_config_;
   bool strict_readiness = strict_readiness_;
   bool allow_model_load_unload = model_load_unload_enabled_;
   bool allow_profiling = profiling_enabled_;
@@ -604,13 +606,20 @@ InferenceServer::Init(int argc, char** argv)
           "config instead of the default platform."),
       tensorflow::Flag(
           "exit-on-error", &exit_on_error,
-          "Exit the inference server if an error occurs during initialization."),
+          "Exit the inference server if an error occurs during "
+          "initialization."),
+      tensorflow::Flag(
+          "strict-model-config", &strict_model_config,
+          "If true model configuration files must be provided and all required "
+          "configuration settings must be specified. If false the model "
+          "configuration may be absent or only partially specified and the "
+          "server will attempt to derive the missing required configuration."),
       tensorflow::Flag(
           "strict-readiness", &strict_readiness,
           "If true /api/health/ready endpoint indicates ready if the server "
-          "is responsive and all models are available. If false /api/health/ready "
-          "endpoint indicates ready if server is responsive even if some/all "
-          "models are unavailable."),
If false " + "/api/health/ready endpoint indicates ready if server is responsive even " + "if some/all models are unavailable."), tensorflow::Flag( "allow-model-load-unload", &allow_model_load_unload, "Allow models to be loaded and unloaded dynamically based on changes " @@ -639,13 +648,13 @@ InferenceServer::Init(int argc, char** argv) "Number of threads handling HTTP requests."), tensorflow::Flag( "file-system-poll-secs", &file_system_poll_secs, - "Interval in seconds between each poll of the file " - "system for changes to the model store."), + "Interval in seconds between each poll of the file system for changes to " + "the model store."), tensorflow::Flag( "exit-timeout-secs", &exit_timeout_secs, - "Timeout (in seconds) when exiting to wait for in-flight inferences " - "to finish. After the timeout expires the server exits even if " - "inferences are still in flight."), + "Timeout (in seconds) when exiting to wait for in-flight inferences to " + "finish. After the timeout expires the server exits even if inferences " + "are still in flight."), tensorflow::Flag( "tf-allow-soft-placement", &tf_allow_soft_placement, "Instruct TensorFlow to use CPU implementation of an operation when a " @@ -654,8 +663,8 @@ InferenceServer::Init(int argc, char** argv) "tf-gpu-memory-fraction", &tf_gpu_memory_fraction, "Reserve a portion of GPU memory for TensorFlow models. Default value " "0.0 indicates that TensorFlow should dynamically allocate memory as " - "needed. Value of 1.0 indicates that TensorFlow should allocate all " - "of GPU memory."), + "needed. Value of 1.0 indicates that TensorFlow should allocate all of " + "GPU memory."), }; std::string usage = tensorflow::Flags::Usage(argv[0], flag_list); @@ -681,6 +690,7 @@ InferenceServer::Init(int argc, char** argv) metrics_port_ = allow_metrics ? 
   metrics_port_ = allow_metrics ? metrics_port : -1;
   model_store_path_ = model_store_path;
   http_thread_cnt_ = http_thread_cnt;
+  strict_model_config_ = strict_model_config;
   strict_readiness_ = strict_readiness;
   model_load_unload_enabled_ = allow_model_load_unload;
   profiling_enabled_ = allow_profiling;
@@ -1315,6 +1325,8 @@ InferenceServer::BuildPlatformConfigMap(
   {
     GraphDefBundleSourceAdapterConfig graphdef_config;
 
+    graphdef_config.set_autofill(!strict_model_config_);
+
     // Tensorflow session config
     if (tf_gpu_memory_fraction == 0.0) {
       graphdef_config.mutable_session_config()
@@ -1333,35 +1345,36 @@
   //// Tensorflow SavedModel
   {
-    tfs::SavedModelBundleSourceAdapterConfig saved_model_config;
+    SavedModelBundleSourceAdapterConfig saved_model_config;
+
+    saved_model_config.set_autofill(!strict_model_config_);
 
     if (tf_gpu_memory_fraction == 0.0) {
-      saved_model_config.mutable_legacy_config()
-          ->mutable_session_config()
+      saved_model_config.mutable_session_config()
           ->mutable_gpu_options()
           ->set_allow_growth(true);
     } else {
-      saved_model_config.mutable_legacy_config()
-          ->mutable_session_config()
+      saved_model_config.mutable_session_config()
           ->mutable_gpu_options()
           ->set_per_process_gpu_memory_fraction(tf_gpu_memory_fraction);
     }
 
-    saved_model_config.mutable_legacy_config()
-        ->mutable_session_config()
-        ->set_allow_soft_placement(tf_allow_soft_placement);
+    saved_model_config.mutable_session_config()->set_allow_soft_placement(
+        tf_allow_soft_placement);
 
     saved_model_source_adapter_config.PackFrom(saved_model_config);
   }
 
   //// Caffe NetDef
   {
     NetDefBundleSourceAdapterConfig netdef_config;
+    netdef_config.set_autofill(!strict_model_config_);
     netdef_source_adapter_config.PackFrom(netdef_config);
   }
 
   //// TensorRT
   {
     PlanBundleSourceAdapterConfig plan_config;
+    plan_config.set_autofill(!strict_model_config_);
     plan_source_adapter_config.PackFrom(plan_config);
   }
@@ -1404,7 +1417,12 @@ InferenceServer::BuildModelConfig(
   for (const auto& child : real_children) {
     const auto full_path = tensorflow::io::JoinPath(model_store_path_, child);
     ModelConfig* model_config = model_configs->add_config();
-    TF_RETURN_IF_ERROR(GetNormalizedModelConfig(full_path, model_config));
+
+    // If enabled, try to automatically generate missing parts of the
+    // model configuration from the model definition. In all cases
+    // normalize and validate the config.
+    TF_RETURN_IF_ERROR(
+        GetNormalizedModelConfig(full_path, !strict_model_config_, model_config));
     TF_RETURN_IF_ERROR(ValidateModelConfig(*model_config, std::string()));
 
     // Make sure the name of the model matches the name of the
diff --git a/src/core/server.h b/src/core/server.h
index 74d37722c2..e3c7f7304f 100644
--- a/src/core/server.h
+++ b/src/core/server.h
@@ -157,6 +157,7 @@ class InferenceServer {
   std::string model_store_path_;
   int http_thread_cnt_;
 
+  bool strict_model_config_;
   bool strict_readiness_;
   bool model_load_unload_enabled_;
   bool profiling_enabled_;
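Editorial note (not part of the diff): the autofill value set in BuildPlatformConfigMap() above travels to each backend inside the google::protobuf::Any that PackFrom() fills, so a source adapter can recover it with the standard Any unpack API. A minimal sketch, assuming the generated header for graphdef_bundle.proto is available; the helper name is hypothetical.

#include <google/protobuf/any.pb.h>

// Sketch only: recover the autofill flag from a packed adapter config.
bool ExampleAutofillEnabled(const google::protobuf::Any& packed)
{
  nvidia::inferenceserver::GraphDefBundleSourceAdapterConfig config;
  if (!packed.UnpackTo(&config)) {
    return false;  // 'packed' holds a different adapter's config type.
  }
  // set_autofill(!strict_model_config_) above means this is true unless
  // --strict-model-config was given on the command line.
  return config.autofill();
}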
diff --git a/src/core/utils.cc b/src/core/utils.cc
index 506bae0851..8197313b00 100644
--- a/src/core/utils.cc
+++ b/src/core/utils.cc
@@ -34,6 +34,81 @@
 
 namespace nvidia { namespace inferenceserver {
 
+namespace {
+
+tensorflow::Status
+GetAutoFillPlatform(
+    const tensorflow::StringPiece& model_name,
+    const tensorflow::StringPiece& path, std::string* platform)
+{
+  const std::string model_path(path);
+
+  // Find version subdirectories...
+  std::vector<std::string> versions;
+  TF_RETURN_IF_ERROR(
+      tensorflow::Env::Default()->GetChildren(model_path, &versions));
+
+  // GetChildren() returns all descendants instead for cloud storage
+  // like GCS. In such case we should filter out all non-direct
+  // descendants.
+  std::set<std::string> real_versions;
+  for (size_t i = 0; i < versions.size(); ++i) {
+    const std::string& version = versions[i];
+    real_versions.insert(version.substr(0, version.find_first_of('/')));
+  }
+
+  if (real_versions.empty()) {
+    return tensorflow::errors::NotFound(
+        "no version sub-directories for model '", model_name, "'");
+  }
+
+  // If a default named file/directory exists in a version
+  // sub-directory then assume the corresponding platform.
+  for (const auto& version : real_versions) {
+    const auto vp = tensorflow::io::JoinPath(model_path, version);
+
+    // TensorRT
+    if (tensorflow::Env::Default()
+            ->FileExists(tensorflow::io::JoinPath(vp, kTensorRTPlanFilename))
+            .ok()) {
+      *platform = kTensorRTPlanPlatform;
+      return tensorflow::Status::OK();
+    }
+
+    // TensorFlow
+    if (tensorflow::Env::Default()
+            ->FileExists(
+                tensorflow::io::JoinPath(vp, kTensorFlowSavedModelFilename))
+            .ok()) {
+      *platform = kTensorFlowSavedModelPlatform;
+      return tensorflow::Status::OK();
+    }
+    if (tensorflow::Env::Default()
+            ->FileExists(
+                tensorflow::io::JoinPath(vp, kTensorFlowGraphDefFilename))
+            .ok()) {
+      *platform = kTensorFlowGraphDefPlatform;
+      return tensorflow::Status::OK();
+    }
+
+    // Caffe2
+    if (tensorflow::Env::Default()
+            ->FileExists(tensorflow::io::JoinPath(vp, kCaffe2NetDefFilename))
+            .ok()) {
+      *platform = kCaffe2NetDefPlatform;
+      return tensorflow::Status::OK();
+    }
+  }
+
+  return tensorflow::errors::NotFound(
+      "unable to derive platform for model '", model_name, "', the model ",
+      "definition file must be named '", kTensorRTPlanFilename, "', '",
+      kTensorFlowGraphDefFilename, "', '", kTensorFlowSavedModelFilename,
+      "', or '", kCaffe2NetDefFilename, "'");
+}
+
+}  // namespace
+
 tensorflow::Status
 GetModelVersionFromPath(const tensorflow::StringPiece& path, uint32_t* version)
 {
@@ -50,11 +125,29 @@ GetModelVersionFromPath(const tensorflow::StringPiece& path, uint32_t* version)
 
 tensorflow::Status
 GetNormalizedModelConfig(
-    const tensorflow::StringPiece& path, ModelConfig* config)
+    const tensorflow::StringPiece& path, const bool autofill, ModelConfig* config)
 {
+  // If 'autofill' then the configuration file can be empty.
   const auto config_path = tensorflow::io::JoinPath(path, kModelConfigPbTxt);
-  TF_RETURN_IF_ERROR(
-      ReadTextProto(tensorflow::Env::Default(), config_path, config));
+  if (autofill && !tensorflow::Env::Default()->FileExists(config_path).ok()) {
+    config->Clear();
+  } else {
+    TF_RETURN_IF_ERROR(
+        ReadTextProto(tensorflow::Env::Default(), config_path, config));
+  }
+
+  // Autofill the platform and name...
+  if (autofill) {
+    const std::string model_name(tensorflow::io::Basename(path));
+    if (config->name().empty()) {
+      config->set_name(model_name);
+    }
+
+    if (config->platform().empty()) {
+      TF_RETURN_IF_ERROR(
+          GetAutoFillPlatform(model_name, path, config->mutable_platform()));
+    }
+  }
 
   // If 'default_model_filename' is not specified set it appropriately
   // based upon 'platform'.
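Editorial usage sketch (not part of the diff) mirroring the BuildModelConfig() call site shown earlier; the model-store path is hypothetical and the fragment belongs inside a function returning tensorflow::Status.

// With autofill enabled an absent or partial config.pbtxt is tolerated:
// the name defaults to the directory basename and the platform is derived
// by GetAutoFillPlatform() from the model definition file it finds.
nvidia::inferenceserver::ModelConfig config;
TF_RETURN_IF_ERROR(nvidia::inferenceserver::GetNormalizedModelConfig(
    "/model-store/example_model", /*autofill=*/true, &config));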
diff --git a/src/core/utils.h b/src/core/utils.h
index 7479202a94..803089279d 100644
--- a/src/core/utils.h
+++ b/src/core/utils.h
@@ -36,9 +36,11 @@ tensorflow::Status GetModelVersionFromPath(
 
 // Read a ModelConfig and normalize it as expected by model servables.
 // 'path' should be the full-path to the directory containing the
-// model configuration.
+// model configuration. If 'autofill' then attempt to determine any
+// missing required configuration from the model definition.
 tensorflow::Status GetNormalizedModelConfig(
-    const tensorflow::StringPiece& path, ModelConfig* config);
+    const tensorflow::StringPiece& path, const bool autofill,
+    ModelConfig* config);
 
 // Validate that a model is specified correctly (excluding inputs and
 // outputs which are validated via ValidateModelInput() and
diff --git a/src/servables/caffe2/netdef_bundle.proto b/src/servables/caffe2/netdef_bundle.proto
index d56f39ea4f..efc741a6c2 100644
--- a/src/servables/caffe2/netdef_bundle.proto
+++ b/src/servables/caffe2/netdef_bundle.proto
@@ -29,4 +29,9 @@ syntax = "proto3";
 package nvidia.inferenceserver;
 
 // Config proto for NetDefBundleSourceAdapter.
-message NetDefBundleSourceAdapterConfig {}
+message NetDefBundleSourceAdapterConfig
+{
+  // Autofill missing required model configuration settings based on
+  // model definition file.
+  bool autofill = 1;
+}
diff --git a/src/servables/caffe2/netdef_bundle_source_adapter.cc b/src/servables/caffe2/netdef_bundle_source_adapter.cc
index 5dbe5a4708..f91ac6a1a3 100644
--- a/src/servables/caffe2/netdef_bundle_source_adapter.cc
+++ b/src/servables/caffe2/netdef_bundle_source_adapter.cc
@@ -43,12 +43,15 @@ namespace {
 
 tensorflow::Status
 CreateNetDefBundle(
+    const NetDefBundleSourceAdapterConfig& adapter_config,
     const std::string& path, std::unique_ptr<NetDefBundle>* bundle)
 {
   const auto model_path = tensorflow::io::Dirname(path);
 
-  ModelConfig config;
-  TF_RETURN_IF_ERROR(GetNormalizedModelConfig(model_path, &config));
+  ModelConfig model_config;
+  model_config.set_platform(kCaffe2NetDefPlatform);
+  TF_RETURN_IF_ERROR(GetNormalizedModelConfig(
+      model_path, adapter_config.autofill(), &model_config));
 
   // Read all the netdef files in 'path'. GetChildren() returns all
   // descendants instead for cloud storage like GCS, so filter out all
@@ -74,7 +77,7 @@ CreateNetDefBundle(
   // Create the bundle for the model and all the execution contexts
   // requested for this model.
   bundle->reset(new NetDefBundle);
-  tensorflow::Status status = (*bundle)->Init(path, config);
+  tensorflow::Status status = (*bundle)->Init(path, model_config);
   if (status.ok()) {
     status = (*bundle)->CreateExecutionContexts(models);
   }
@@ -97,8 +100,11 @@ NetDefBundleSourceAdapter::Create(
   LOG_VERBOSE(1) << "Create NetDefBundleSourceAdaptor for config \""
                  << config.DebugString() << "\"";
 
+  Creator creator = std::bind(
+      &CreateNetDefBundle, config, std::placeholders::_1, std::placeholders::_2);
+
   adapter->reset(new NetDefBundleSourceAdapter(
-      config, CreateNetDefBundle, SimpleSourceAdapter::EstimateNoResources()));
+      config, creator, SimpleSourceAdapter::EstimateNoResources()));
 
   return tensorflow::Status::OK();
 }
@@ -114,5 +120,4 @@ namespace tensorflow { namespace serving {
 REGISTER_STORAGE_PATH_SOURCE_ADAPTER(
     nvidia::inferenceserver::NetDefBundleSourceAdapter,
     nvidia::inferenceserver::NetDefBundleSourceAdapterConfig);
-
 }}  // namespace tensorflow::serving
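Editorial sketch (not part of the diff) of the std::bind pattern used in NetDefBundleSourceAdapter::Create above: binding the adapter config as the first argument turns the three-argument factory into the two-argument callable that the adapter's Creator slot expects. The types below are illustrative stand-ins, not the real bundle or adapter classes; an equivalent lambda capturing the config would work just as well.

#include <functional>
#include <memory>
#include <string>

struct ExampleConfig { bool autofill; };
struct ExampleBundle {};

// Three-argument factory, analogous to CreateNetDefBundle().
int CreateExampleBundle(
    const ExampleConfig& config, const std::string& path,
    std::unique_ptr<ExampleBundle>* bundle)
{
  bundle->reset(new ExampleBundle);
  return config.autofill ? 0 : 1;  // stand-in for tensorflow::Status
}

int main()
{
  // Two-argument callable, analogous to the adapter's Creator.
  using Creator =
      std::function<int(const std::string&, std::unique_ptr<ExampleBundle>*)>;

  ExampleConfig config{true};
  Creator creator = std::bind(
      &CreateExampleBundle, config, std::placeholders::_1,
      std::placeholders::_2);

  std::unique_ptr<ExampleBundle> bundle;
  return creator("/models/example", &bundle);
}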
diff --git a/src/servables/tensorflow/BUILD b/src/servables/tensorflow/BUILD
index 98d9a635a6..ea8133a467 100644
--- a/src/servables/tensorflow/BUILD
+++ b/src/servables/tensorflow/BUILD
@@ -40,6 +40,16 @@ serving_proto_library(
     ],
 )
 
+serving_proto_library(
+    name = "savedmodel_bundle_proto",
+    srcs = ["savedmodel_bundle.proto"],
+    cc_api_version = 2,
+    deps = [
+        "@org_tensorflow//tensorflow/core:protos_all_cc",
+        "@protobuf_archive//:cc_wkt_protos",
+    ],
+)
+
 cc_library(
     name = "base_bundle",
     srcs = ["base_bundle.cc"],
@@ -103,12 +113,12 @@ cc_library(
     hdrs = ["savedmodel_bundle_source_adapter.h"],
     deps = [
         ":savedmodel_bundle",
+        ":savedmodel_bundle_proto",
         "//src/core:utils",
         "@tf_serving//tensorflow_serving/core:loader",
         "@tf_serving//tensorflow_serving/core:simple_loader",
         "@tf_serving//tensorflow_serving/core:source_adapter",
         "@tf_serving//tensorflow_serving/core:storage_path",
-        "@tf_serving//tensorflow_serving/servables/tensorflow:saved_model_bundle_source_adapter_proto",
         "@tf_serving//tensorflow_serving/util:optional",
         "@org_tensorflow//tensorflow/core:core_cpu",
         "@org_tensorflow//tensorflow/core:lib",
diff --git a/src/servables/tensorflow/graphdef_bundle.proto b/src/servables/tensorflow/graphdef_bundle.proto
index e9b25f245b..64cf9be17c 100644
--- a/src/servables/tensorflow/graphdef_bundle.proto
+++ b/src/servables/tensorflow/graphdef_bundle.proto
@@ -36,4 +36,8 @@ message GraphDefBundleSourceAdapterConfig
   // TensorFlow Session configuration options.
   // See details at tensorflow/core/protobuf/config.proto.
   tensorflow.ConfigProto session_config = 1;
+
+  // Autofill missing required model configuration settings based on
+  // model definition file.
+  bool autofill = 2;
 }
diff --git a/src/servables/tensorflow/graphdef_bundle_source_adapter.cc b/src/servables/tensorflow/graphdef_bundle_source_adapter.cc
index 7194ad1b28..8885826169 100644
--- a/src/servables/tensorflow/graphdef_bundle_source_adapter.cc
+++ b/src/servables/tensorflow/graphdef_bundle_source_adapter.cc
@@ -49,7 +49,8 @@ CreateGraphDefBundle(
   const auto model_path = tensorflow::io::Dirname(path);
 
   ModelConfig model_config;
-  TF_RETURN_IF_ERROR(GetNormalizedModelConfig(model_path, &model_config));
+  TF_RETURN_IF_ERROR(GetNormalizedModelConfig(
+      model_path, adapter_config.autofill(), &model_config));
 
   // Read all the graphdef files in 'path'. GetChildren() returns all
   // descendants instead for cloud storage like GCS, so filter out all
diff --git a/src/servables/tensorflow/savedmodel_bundle.proto b/src/servables/tensorflow/savedmodel_bundle.proto
new file mode 100644
index 0000000000..27b0f5f404
--- /dev/null
+++ b/src/servables/tensorflow/savedmodel_bundle.proto
@@ -0,0 +1,43 @@
+// Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
+//
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions
+// are met:
+//  * Redistributions of source code must retain the above copyright
+//    notice, this list of conditions and the following disclaimer.
+//  * Redistributions in binary form must reproduce the above copyright
+//    notice, this list of conditions and the following disclaimer in the
+//    documentation and/or other materials provided with the distribution.
+//  * Neither the name of NVIDIA CORPORATION nor the names of its
+//    contributors may be used to endorse or promote products derived
+//    from this software without specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
+// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
+// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
+// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
+// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+syntax = "proto3";
+
+import "tensorflow/core/protobuf/config.proto";
+
+package nvidia.inferenceserver;
+
+// Config proto for SavedModelBundleSourceAdapter.
+message SavedModelBundleSourceAdapterConfig
+{
+  // TensorFlow Session configuration options.
+  // See details at tensorflow/core/protobuf/config.proto.
+  tensorflow.ConfigProto session_config = 1;
+
+  // Autofill missing required model configuration settings based on
+  // model definition file.
+  bool autofill = 2;
+}
diff --git a/src/servables/tensorflow/savedmodel_bundle_source_adapter.cc b/src/servables/tensorflow/savedmodel_bundle_source_adapter.cc
index 35fde1b361..05a97eed01 100644
--- a/src/servables/tensorflow/savedmodel_bundle_source_adapter.cc
+++ b/src/servables/tensorflow/savedmodel_bundle_source_adapter.cc
@@ -43,13 +43,15 @@ namespace {
 
 tensorflow::Status
 CreateSavedModelBundle(
-    const tfs::SavedModelBundleSourceAdapterConfig& adapter_config,
+    const SavedModelBundleSourceAdapterConfig& adapter_config,
     const std::string& path, std::unique_ptr<SavedModelBundle>* bundle)
 {
   const auto model_path = tensorflow::io::Dirname(path);
 
   ModelConfig model_config;
-  TF_RETURN_IF_ERROR(GetNormalizedModelConfig(model_path, &model_config));
+  model_config.set_platform(kTensorFlowSavedModelPlatform);
+  TF_RETURN_IF_ERROR(GetNormalizedModelConfig(
+      model_path, adapter_config.autofill(), &model_config));
 
   // Read all the savedmodel directories in 'path'. GetChildren()
   // returns all descendants instead for cloud storage like GCS, so
@@ -74,7 +76,7 @@ CreateSavedModelBundle(
   tensorflow::Status status = (*bundle)->Init(path, model_config);
   if (status.ok()) {
     status = (*bundle)->CreateExecutionContexts(
-        adapter_config.legacy_config().session_config(), savedmodel_paths);
+        adapter_config.session_config(), savedmodel_paths);
   }
   if (!status.ok()) {
     bundle->reset();
@@ -85,10 +87,9 @@
 
 }  // namespace
 
-
 tensorflow::Status
 SavedModelBundleSourceAdapter::Create(
-    const tfs::SavedModelBundleSourceAdapterConfig& config,
+    const SavedModelBundleSourceAdapterConfig& config,
     std::unique_ptr<
         tfs::SourceAdapter>>* adapter)
@@ -116,6 +117,5 @@
 namespace tensorflow { namespace serving {
 REGISTER_STORAGE_PATH_SOURCE_ADAPTER(
     nvidia::inferenceserver::SavedModelBundleSourceAdapter,
-    SavedModelBundleSourceAdapterConfig);
-
+    nvidia::inferenceserver::SavedModelBundleSourceAdapterConfig);
 }}  // namespace tensorflow::serving
diff --git a/src/servables/tensorflow/savedmodel_bundle_source_adapter.h b/src/servables/tensorflow/savedmodel_bundle_source_adapter.h
index be301588ea..3bba952747 100644
--- a/src/servables/tensorflow/savedmodel_bundle_source_adapter.h
+++ b/src/servables/tensorflow/savedmodel_bundle_source_adapter.h
@@ -26,12 +26,12 @@
 #pragma once
 
 #include "src/servables/tensorflow/savedmodel_bundle.h"
+#include "src/servables/tensorflow/savedmodel_bundle.pb.h"
 #include "tensorflow/core/lib/core/status.h"
 #include "tensorflow/core/platform/macros.h"
 #include "tensorflow_serving/core/loader.h"
 #include "tensorflow_serving/core/simple_loader.h"
 #include "tensorflow_serving/core/storage_path.h"
-#include "tensorflow_serving/servables/tensorflow/saved_model_bundle_source_adapter.pb.h"
 
 namespace tfs = tensorflow::serving;
 
@@ -44,7 +44,7 @@ class SavedModelBundleSourceAdapter final
         tfs::StoragePath, SavedModelBundle> {
  public:
   static tensorflow::Status Create(
-      const tfs::SavedModelBundleSourceAdapterConfig& config,
+      const SavedModelBundleSourceAdapterConfig& config,
       std::unique_ptr<
           SourceAdapter>>* adapter);
 
@@ -56,14 +56,14 @@ class SavedModelBundleSourceAdapter final
       tfs::SimpleLoaderSourceAdapter;
 
   SavedModelBundleSourceAdapter(
-      const tfs::SavedModelBundleSourceAdapterConfig& config,
+      const SavedModelBundleSourceAdapterConfig& config,
       typename SimpleSourceAdapter::Creator creator,
       typename SimpleSourceAdapter::ResourceEstimator resource_estimator)
       : SimpleSourceAdapter(creator, resource_estimator), config_(config)
   {
   }
 
-  const tfs::SavedModelBundleSourceAdapterConfig config_;
+  const SavedModelBundleSourceAdapterConfig config_;
 };
 
 }}  // namespace nvidia::inferenceserver
diff --git a/src/servables/tensorrt/plan_bundle.proto b/src/servables/tensorrt/plan_bundle.proto
index a32245bd33..59d1d37bbd 100644
--- a/src/servables/tensorrt/plan_bundle.proto
+++ b/src/servables/tensorrt/plan_bundle.proto
@@ -29,4 +29,9 @@ syntax = "proto3";
 package nvidia.inferenceserver;
 
 // Config proto for PlanBundleSourceAdapter.
-message PlanBundleSourceAdapterConfig {}
+message PlanBundleSourceAdapterConfig
+{
+  // Autofill missing required model configuration settings based on
+  // model definition file.
+  bool autofill = 1;
+}
diff --git a/src/servables/tensorrt/plan_bundle_source_adapter.cc b/src/servables/tensorrt/plan_bundle_source_adapter.cc
index fdc096197d..e8bb533957 100644
--- a/src/servables/tensorrt/plan_bundle_source_adapter.cc
+++ b/src/servables/tensorrt/plan_bundle_source_adapter.cc
@@ -42,12 +42,16 @@ namespace nvidia { namespace inferenceserver {
 namespace {
 
 tensorflow::Status
-CreatePlanBundle(const std::string& path, std::unique_ptr<PlanBundle>* bundle)
+CreatePlanBundle(
+    const PlanBundleSourceAdapterConfig& adapter_config, const std::string& path,
+    std::unique_ptr<PlanBundle>* bundle)
 {
   const auto model_path = tensorflow::io::Dirname(path);
 
-  ModelConfig config;
-  TF_RETURN_IF_ERROR(GetNormalizedModelConfig(model_path, &config));
+  ModelConfig model_config;
+  model_config.set_platform(kTensorRTPlanPlatform);
+  TF_RETURN_IF_ERROR(GetNormalizedModelConfig(
+      model_path, adapter_config.autofill(), &model_config));
 
   // Read all the plan files in 'path'. GetChildren() returns all
   // descendants instead for cloud storage like GCS, so filter out all
@@ -73,7 +77,7 @@ CreatePlanBundle(const std::string& path, std::unique_ptr<PlanBundle>* bundle)
   // Create the bundle for the model and all the execution contexts
   // requested for this model.
   bundle->reset(new PlanBundle);
-  tensorflow::Status status = (*bundle)->Init(path, config);
+  tensorflow::Status status = (*bundle)->Init(path, model_config);
   if (status.ok()) {
     status = (*bundle)->CreateExecutionContexts(models);
   }
@@ -86,7 +90,6 @@ CreatePlanBundle(const std::string& path, std::unique_ptr<PlanBundle>* bundle)
 
 }  // namespace
 
-
 tensorflow::Status
 PlanBundleSourceAdapter::Create(
     const PlanBundleSourceAdapterConfig& config,
@@ -96,8 +99,11 @@ PlanBundleSourceAdapter::Create(
   LOG_VERBOSE(1) << "Create PlanBundleSourceAdaptor for config \""
                  << config.DebugString() << "\"";
 
+  Creator creator = std::bind(
+      &CreatePlanBundle, config, std::placeholders::_1, std::placeholders::_2);
+
   adapter->reset(new PlanBundleSourceAdapter(
-      config, CreatePlanBundle, SimpleSourceAdapter::EstimateNoResources()));
+      config, creator, SimpleSourceAdapter::EstimateNoResources()));
 
   return tensorflow::Status::OK();
 }
@@ -113,5 +119,4 @@ namespace tensorflow { namespace serving {
 REGISTER_STORAGE_PATH_SOURCE_ADAPTER(
     nvidia::inferenceserver::PlanBundleSourceAdapter,
     nvidia::inferenceserver::PlanBundleSourceAdapterConfig);
-
 }}  // namespace tensorflow::serving
diff --git a/src/test/model_config_test_base.cc b/src/test/model_config_test_base.cc
index f883ed2572..74aeffcde0 100644
--- a/src/test/model_config_test_base.cc
+++ b/src/test/model_config_test_base.cc
@@ -43,7 +43,7 @@ ModelConfigTestBase::ValidateInit(
   result->clear();
 
   ModelConfig config;
-  tensorflow::Status status = GetNormalizedModelConfig(path, &config);
+  tensorflow::Status status = GetNormalizedModelConfig(path, autofill, &config);
   if (!status.ok()) {
     result->append(status.ToString());
     return false;
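Editorial sketch (not part of the diff) of the behavior the test harness above can now exercise through its autofill argument. The path and expectations are hypothetical; only GetNormalizedModelConfig() and the platform constant come from this change.

// A model directory with no config.pbtxt but a version subdirectory that
// holds a TensorRT plan file under the default filename should normalize
// successfully when autofill is enabled.
nvidia::inferenceserver::ModelConfig config;
tensorflow::Status status = nvidia::inferenceserver::GetNormalizedModelConfig(
    "/tmp/model_store/plan_only_model", /*autofill=*/true, &config);
// Expected: status.ok(), config.name() == "plan_only_model", and
// config.platform() == kTensorRTPlanPlatform.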