Skip to content

Commit

Permalink
Use inmodel_name_ when MakeCallable (#84)
Browse files Browse the repository at this point in the history
  • Loading branch information
kthui authored Dec 8, 2022
1 parent 7d73d0a commit 69aa6a0
Showing 1 changed file with 9 additions and 3 deletions.
12 changes: 9 additions & 3 deletions src/tensorflow.cc
Original file line number Diff line number Diff line change
Expand Up @@ -973,10 +973,12 @@ ModelState::CreateModel(
std::string io_data_type;
RETURN_IF_ERROR(io.MemberAsString("data_type", &io_data_type));

input_names.push_back(io_names.back().c_str());
const auto& itr = lmodel.input_name_map_.find(io_names.back());
input_names.push_back(
itr != lmodel.input_name_map_.end() ? itr->second.c_str()
: io_names.back().c_str());
input_types.push_back(ConvertDataType(io_data_type));
}

triton::common::TritonJson::Value config_outputs;
RETURN_IF_ERROR(ModelConfig().MemberAsArray("output", &config_outputs));
for (size_t i = 0; i < config_outputs.ArraySize(); i++) {
Expand All @@ -987,9 +989,13 @@ ModelState::CreateModel(
std::string io_data_type;
RETURN_IF_ERROR(io.MemberAsString("data_type", &io_data_type));

output_names.push_back(io_names.back().c_str());
const auto& itr = lmodel.output_name_map_.find(io_names.back());
output_names.push_back(
itr != lmodel.output_name_map_.end() ? itr->second.c_str()
: io_names.back().c_str());
output_types.push_back(ConvertDataType(io_data_type));
}

RETURN_IF_TRITONTF_ERROR(TRITONTF_ModelMakeCallable(
lmodel.tritontf_model_.get(), input_names.data(), input_types.data(),
config_inputs.ArraySize(), output_names.data(), output_types.data(),
Expand Down

0 comments on commit 69aa6a0

Please sign in to comment.