Skip to content

Commit

Permalink
feat(trt): recompile engine if wrong version is detected
Browse files Browse the repository at this point in the history
  • Loading branch information
Bycob authored and mergify[bot] committed Nov 7, 2022
1 parent 1132760 commit 0f0bb62
Show file tree
Hide file tree
Showing 3 changed files with 153 additions and 4 deletions.
127 changes: 127 additions & 0 deletions src/backends/tensorrt/error_recorder.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
/**
* DeepDetect
* Copyright (c) 2022 Jolibrain
* Authors: Louis Jean <louis.jean@jolibrain.com>
*
* This file is part of deepdetect.
*
* deepdetect is free software: you can redistribute it and/or modify
* it under the terms of the GNU Lesser General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* deepdetect is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public License
* along with deepdetect. If not, see <http://www.gnu.org/licenses/>.
*/

#ifndef DD_TRT_ERROR_RECORDER
#define DD_TRT_ERROR_RECORDER

#include <vector>
#include <string>
#include <mutex>
#include <exception>
#include <atomic>

#include <NvInferRuntimeCommon.h>

namespace dd
{

struct Error
{
nvinfer1::ErrorCode _code;
std::string _desc;
};

class TRTErrorRecorder : public nvinfer1::IErrorRecorder
{
public:
TRTErrorRecorder(std::shared_ptr<spdlog::logger> logger) : _logger(logger)
{
}

~TRTErrorRecorder() noexcept override
{
}

int32_t getNbErrors() const noexcept override
{
return static_cast<int32_t>(_errors.size());
}

nvinfer1::ErrorCode getErrorCode(int32_t errorIdx) const noexcept override
{
return _errors.at(errorIdx)._code;
}

nvinfer1::IErrorRecorder::ErrorDesc
getErrorDesc(int32_t errorIdx) const noexcept override
{
return _errors.at(errorIdx)._desc.c_str();
}

bool hasOverflowed() const noexcept override
{
return false;
}

void clear() noexcept override
{
try
{
std::lock_guard<std::mutex> guard(_errors_mtx);
_errors.clear();
}
catch (std::exception &e)
{
_logger->error("TRTErroRecorder::clear error: {}", e.what());
}
}

// API used by TensorRT to report Error information to the application.
bool
reportError(nvinfer1::ErrorCode val,
nvinfer1::IErrorRecorder::ErrorDesc desc) noexcept override
{
try
{
std::lock_guard<std::mutex> guard(_errors_mtx);
_errors.push_back(Error{ val, std::string(desc) });
_logger->error("TRT Error code={}: {}", static_cast<int32_t>(val),
std::string(desc));
}
catch (std::exception &e)
{
_logger->error("TRTErroRecorder::reportError error: {}", e.what());
}
return true;
}

RefCount incRefCount() noexcept override
{
return ++_ref_count;
}

RefCount decRefCount() noexcept override
{
return --_ref_count;
}

private:
std::vector<Error> _errors;
std::shared_ptr<spdlog::logger> _logger; /**< dd logger */

// Mutex to hold when locking mErrorStack.
std::mutex _errors_mtx;

std::atomic<int32_t> _ref_count{ 0 };
};
}

#endif // DD_TRT_ERROR_RECORDER
28 changes: 24 additions & 4 deletions src/backends/tensorrt/tensorrtlib.cc
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,7 @@ namespace dd
initLibNvInferPlugins(&trtLogger, "");
_runtime = std::shared_ptr<nvinfer1::IRuntime>(
nvinfer1::createInferRuntime(trtLogger));
_runtime->setErrorRecorder(new TRTErrorRecorder(this->_logger));

if (ad.has("tensorRTEngineFile"))
_engineFileName = ad.get("tensorRTEngineFile").get<std::string>();
Expand Down Expand Up @@ -594,15 +595,34 @@ namespace dd
trtModelStream.resize(size);
file.read(trtModelStream.data(), size);
file.close();

auto *errors = _runtime->getErrorRecorder();
errors->clear();
_engine = std::shared_ptr<nvinfer1::ICudaEngine>(
_runtime->deserializeCudaEngine(trtModelStream.data(),
trtModelStream.size()));

if (_engine == nullptr)
throw MLLibInternalException(
"Engine could not be deserialized");
bool shouldRecompile = false;
for (int i = 0; i < errors->getNbErrors(); ++i)
{
std::string desc = errors->getErrorDesc(i);
if (desc.find("Version tag does not match")
!= std::string::npos)
{
this->_logger->warn(
"Engine is outdated and will be recompiled");
shouldRecompile = true;
}
}

if (!shouldRecompile)
{
if (_engine == nullptr)
throw MLLibInternalException(
"Engine could not be deserialized");

engineRead = true;
engineRead = true;
}
}
}

Expand Down
2 changes: 2 additions & 0 deletions src/backends/tensorrt/tensorrtlib.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@
#include "NvCaffeParser.h"
#include "NvInfer.h"

#include "error_recorder.hpp"

namespace dd
{

Expand Down

0 comments on commit 0f0bb62

Please sign in to comment.