Skip to content

Commit

Permalink
Add unit-testing for savedmodel configuration.
Browse files Browse the repository at this point in the history
  • Loading branch information
David Goodwin committed Nov 14, 2018
1 parent c00d67d commit fadfbbe
Show file tree
Hide file tree
Showing 33 changed files with 382 additions and 70 deletions.
1 change: 0 additions & 1 deletion src/core/model_config_manager.cc
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,6 @@ ModelConfigManager::GetModelConfigPlatform(
TF_RETURN_IF_ERROR(singleton->GetModelConfigInternal(name, &mc));
*platform = GetPlatform(mc.platform());
singleton->platforms_.emplace(name, *platform);
LOG_INFO << "Got platform for " << name;
} else {
*platform = itr->second;
}
Expand Down
103 changes: 103 additions & 0 deletions src/servables/tensorflow/base_bundle.cc
Original file line number Diff line number Diff line change
Expand Up @@ -657,4 +657,107 @@ operator<<(std::ostream& out, const BaseBundle& pb)
return out;
}

bool
CompareDims(
    const tensorflow::TensorShapeProto& model_shape, const DimsList& dims)
{
  // A leading -1 in the model shape serves as a placeholder for the
  // batch dimension; the configuration 'dims' never includes it.
  const bool has_batch_dim =
      (model_shape.dim().size() >= 1) && (model_shape.dim(0).size() == -1);
  const int offset = has_batch_dim ? 1 : 0;

  // Rank must agree once the (optional) batch dimension is accounted for.
  if (model_shape.dim().size() != (dims.size() + offset)) {
    return false;
  }

  // Element-wise comparison, skipping the batch dimension if present.
  for (int i = 0; i < dims.size(); ++i) {
    if (model_shape.dim(i + offset).size() != dims[i]) {
      return false;
    }
  }

  return true;
}

bool
CompareDataType(tensorflow::DataType model_dtype, DataType dtype)
{
  // A configuration data-type with no TensorFlow equivalent can never
  // match the model's data-type.
  const tensorflow::DataType converted = ConvertDataType(dtype);
  return (converted != tensorflow::DT_INVALID) && (model_dtype == converted);
}

const std::string
DimsDebugString(const DimsList& dims)
{
  // Render a model configuration shape as "[d0,d1,...]".
  std::string str("[");
  for (int i = 0; i < dims.size(); ++i) {
    if (i > 0) {
      str.append(",");
    }
    str.append(std::to_string(dims[i]));
  }
  str.append("]");
  return str;
}

const std::string
DimsDebugString(const tensorflow::TensorShapeProto& dims)
{
  // Render a TensorFlow shape as "[d0,d1,...]".
  std::string str("[");
  for (int i = 0; i < dims.dim().size(); ++i) {
    if (i > 0) {
      str.append(",");
    }
    str.append(std::to_string(dims.dim(i).size()));
  }
  str.append("]");
  return str;
}

tensorflow::DataType
ConvertDataType(DataType dtype)
{
  // Map a model configuration data-type to the corresponding TensorFlow
  // data-type. Unknown types (and TYPE_INVALID) map to DT_INVALID.
  switch (dtype) {
    case DataType::TYPE_BOOL: return tensorflow::DT_BOOL;
    case DataType::TYPE_UINT8: return tensorflow::DT_UINT8;
    case DataType::TYPE_UINT16: return tensorflow::DT_UINT16;
    case DataType::TYPE_UINT32: return tensorflow::DT_UINT32;
    case DataType::TYPE_UINT64: return tensorflow::DT_UINT64;
    case DataType::TYPE_INT8: return tensorflow::DT_INT8;
    case DataType::TYPE_INT16: return tensorflow::DT_INT16;
    case DataType::TYPE_INT32: return tensorflow::DT_INT32;
    case DataType::TYPE_INT64: return tensorflow::DT_INT64;
    case DataType::TYPE_FP16: return tensorflow::DT_HALF;
    case DataType::TYPE_FP32: return tensorflow::DT_FLOAT;
    case DataType::TYPE_FP64: return tensorflow::DT_DOUBLE;
    case DataType::TYPE_INVALID:
    default:
      return tensorflow::DT_INVALID;
  }
}

}} // namespace nvidia::inferenceserver
19 changes: 19 additions & 0 deletions src/servables/tensorflow/base_bundle.h
Original file line number Diff line number Diff line change
Expand Up @@ -149,4 +149,23 @@ class BaseBundle : public InferenceServable {

std::ostream& operator<<(std::ostream& out, const BaseBundle& pb);

/// Compare a TensorFlow shape to a model configuration shape. A
/// leading -1 model dimension is treated as a batch placeholder that
/// does not appear in the configuration 'dims'.
/// \param model_shape The TensorFlow shape.
/// \param dims The model configuration shape.
/// \return true if a TensorFlow shape matches a model configuration
/// shape.
bool CompareDims(
const tensorflow::TensorShapeProto& model_shape, const DimsList& dims);

/// Compare a TensorFlow data-type to a model configuration data-type.
/// \param model_dtype The TensorFlow data-type.
/// \param dtype The model configuration data-type.
/// \return true if a TensorFlow data-type matches a model
/// configuration data-type.
bool CompareDataType(tensorflow::DataType model_dtype, DataType dtype);

/// \param dims The model configuration shape.
/// \return the string representation of a model configuration shape.
const std::string DimsDebugString(const DimsList& dims);

/// \param dims The TensorFlow shape.
/// \return the string representation of a TensorFlow shape.
const std::string DimsDebugString(const tensorflow::TensorShapeProto& dims);

/// \param dtype The model configuration data-type.
/// \return the TensorFlow data-type that corresponds to a model
/// configuration data-type, or DT_INVALID if there is no
/// correspondence.
tensorflow::DataType ConvertDataType(DataType dtype);

}} // namespace nvidia::inferenceserver

}} // namespace nvidia::inferenceserver
74 changes: 14 additions & 60 deletions src/servables/tensorflow/savedmodel_bundle.cc
Original file line number Diff line number Diff line change
Expand Up @@ -38,66 +38,6 @@

namespace nvidia { namespace inferenceserver {

namespace {

// Returns true when the TensorFlow shape matches the model
// configuration shape.
// NOTE(review): identical helper exists in base_bundle.cc — prefer
// keeping a single shared copy.
bool
CompareDims(
const tensorflow::TensorShapeProto& model_shape, const DimsList& dims)
{
// The first model dimension can be -1 to serve as a placeholder for
// batch. The batch dim doesn't appear in the configuration 'dims'.
const bool has_batch_dim =
(model_shape.dim().size() >= 1) && (model_shape.dim(0).size() == -1);
if (model_shape.dim().size() != (dims.size() + (has_batch_dim ? 1 : 0))) {
return false;
}

// Compare element-wise, skipping over the batch dimension if present.
for (int i = 0; i < dims.size(); ++i) {
if (model_shape.dim(i + (has_batch_dim ? 1 : 0)).size() != dims[i]) {
return false;
}
}

return true;
}

// Returns a "[d0,d1,...]" rendering of a model configuration shape.
// NOTE(review): identical helper exists in base_bundle.cc — prefer
// keeping a single shared copy.
const std::string
DimsDebugString(const DimsList& dims)
{
bool first = true;
std::string str;
str.append("[");
for (int i = 0; i < dims.size(); ++i) {
// Separate dimensions with commas, no trailing separator.
if (!first) {
str.append(",");
}
str.append(std::to_string(dims[i]));
first = false;
}
str.append("]");
return str;
}

// Returns a "[d0,d1,...]" rendering of a TensorFlow shape.
// NOTE(review): identical helper exists in base_bundle.cc — prefer
// keeping a single shared copy.
const std::string
DimsDebugString(const tensorflow::TensorShapeProto& dims)
{
bool first = true;
std::string str;
str.append("[");
for (int i = 0; i < dims.dim().size(); ++i) {
// Separate dimensions with commas, no trailing separator.
if (!first) {
str.append(",");
}
str.append(std::to_string(dims.dim(i).size()));
first = false;
}
str.append("]");
return str;
}


} // namespace

tensorflow::Status
SavedModelBundle::Init(
const tensorflow::StringPiece& path, const ModelConfig& config)
Expand Down Expand Up @@ -208,6 +148,13 @@ SavedModelBundle::CreateSession(
DimsDebugString(iitr->second.tensor_shape()),
" don't match configuration dims ", DimsDebugString(io.dims()));
}
if (!CompareDataType(iitr->second.dtype(), io.data_type())) {
return tensorflow::errors::InvalidArgument(
"unable to load model '", Name(), "', input '", io.name(),
"' data-type ", tensorflow::DataType_Name(iitr->second.dtype()),
" doesn't match configuration data-type ",
DataType_Name(io.data_type()));
}
}

for (const auto& io : Config().output()) {
Expand All @@ -225,6 +172,13 @@ SavedModelBundle::CreateSession(
DimsDebugString(oitr->second.tensor_shape()),
" don't match configuration dims ", DimsDebugString(io.dims()));
}
if (!CompareDataType(oitr->second.dtype(), io.data_type())) {
return tensorflow::errors::InvalidArgument(
"unable to load model '", Name(), "', output '", io.name(),
"' data-type ", tensorflow::DataType_Name(oitr->second.dtype()),
" doesn't match configuration data-type ",
DataType_Name(io.data_type()));
}
}

*session = bundle->session.release();
Expand Down
16 changes: 15 additions & 1 deletion src/servables/tensorflow/savedmodel_bundle_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,21 @@ TEST_F(SavedModelBundleTest, ModelConfigSanity)
const std::string& path,
const ModelConfig& config) -> tensorflow::Status {
std::unique_ptr<SavedModelBundle> bundle(new SavedModelBundle());
return bundle->Init(path, config);
tensorflow::Status status = bundle->Init(path, config);
if (status.ok()) {
std::unordered_map<std::string, std::string> savedmodel_paths;
std::string filename = "model.savedmodel";
const auto savedmodel_path = tensorflow::io::JoinPath(path, filename);
savedmodel_paths.emplace(
std::piecewise_construct, std::make_tuple(filename),
std::make_tuple(savedmodel_path));

tensorflow::ConfigProto session_config;
status =
bundle->CreateExecutionContexts(session_config, savedmodel_paths);
}

return status;
};

// Standard testing...
Expand Down
Binary file not shown.
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
max_batch_size: 1
input [
{
name: "INPUT0"
data_type: TYPE_INT32
dims: [ 16, 1 ]
},
{
name: "INPUT1"
data_type: TYPE_INT32
dims: [ 16 ]
}
]
output [
{
name: "OUTPUT0"
data_type: TYPE_INT8
dims: [ 16 ]
},
{
name: "OUTPUT1"
data_type: TYPE_INT8
dims: [ 16 ]
}
]
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Invalid argument: unable to load model 'bad_input_dims', input 'INPUT0' dims [-1,16] don't match configuration dims [16,1]
Binary file not shown.
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
max_batch_size: 1
input [
{
name: "INPUT0"
data_type: TYPE_INT32
dims: [ 16 ]
},
{
name: "INPUT1"
data_type: TYPE_FP32
dims: [ 16 ]
}
]
output [
{
name: "OUTPUT0"
data_type: TYPE_INT8
dims: [ 16 ]
},
{
name: "OUTPUT1"
data_type: TYPE_INT8
dims: [ 16 ]
}
]
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Invalid argument: unable to load model 'bad_input_type', input 'INPUT1' data-type DT_INT32 doesn't match configuration data-type TYPE_FP32
Binary file not shown.
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
max_batch_size: 1
input [
{
name: "INPUT0"
data_type: TYPE_INT32
dims: [ 16 ]
},
{
name: "INPUT1"
data_type: TYPE_INT32
dims: [ 16 ]
}
]
output [
{
name: "OUTPUT0"
data_type: TYPE_INT8
dims: [ 16 ]
},
{
name: "OUTPUT1"
data_type: TYPE_INT8
dims: [ 1 ]
}
]
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Invalid argument: unable to load model 'bad_output_dims', output 'OUTPUT1' dims [-1,16] don't match configuration dims [1]
Binary file not shown.
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
max_batch_size: 1
input [
{
name: "INPUT0"
data_type: TYPE_INT32
dims: [ 16 ]
},
{
name: "INPUT1"
data_type: TYPE_INT32
dims: [ 16 ]
}
]
output [
{
name: "OUTPUT0"
data_type: TYPE_INT16
dims: [ 16 ]
},
{
name: "OUTPUT1"
data_type: TYPE_INT8
dims: [ 16 ]
}
]
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Invalid argument: unable to load model 'bad_output_type', output 'OUTPUT0' data-type DT_INT8 doesn't match configuration data-type TYPE_INT16
Empty file.
Binary file not shown.
Loading

0 comments on commit fadfbbe

Please sign in to comment.