src/tflite_cpu.cpp (37 changes: 21 additions & 16 deletions)
@@ -15,6 +15,7 @@
#include <fstream>
#include <iostream>
#include <mutex>
#include <shared_mutex>
#include <sstream>
#include <stdexcept>

@@ -73,15 +74,13 @@ class MLModelServiceTFLite : public vsdk::MLModelService,

/// @brief Stops the MLModelServiceTFLite from running.
void stop() noexcept {
const std::lock_guard<std::mutex> lock(state_lock_);
const std::unique_lock<std::shared_mutex> state_wlock(state_rwmutex_);
state_.reset();
}

void reconfigure(const vsdk::Dependencies& dependencies,
const vsdk::ResourceConfig& configuration) final {


const std::lock_guard<std::mutex> lock(state_lock_);
const std::unique_lock<std::shared_mutex> state_wlock(state_rwmutex_);
check_stopped_inlock_();
state_.reset();
state_ = configure_(dependencies, configuration);
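
For context, a minimal standalone sketch of the reader/writer pattern these two methods now rely on (the names below are illustrative, not taken from this module): writer paths such as stop() and reconfigure() take the exclusive side of the std::shared_mutex, while read-only paths take the shared side.

#include <memory>
#include <mutex>
#include <shared_mutex>

struct State {};  // stand-in for the module's configured state

class Service {
 public:
  // Writer paths: swapping the state out requires exclusive access.
  void stop() noexcept {
    const std::unique_lock<std::shared_mutex> wlock(rwmutex_);
    state_.reset();
  }

  void reconfigure(std::unique_ptr<State> next) {
    const std::unique_lock<std::shared_mutex> wlock(rwmutex_);
    state_ = std::move(next);
  }

  // Reader path: many threads may observe the current state concurrently.
  bool configured() const {
    const std::shared_lock<std::shared_mutex> rlock(rwmutex_);
    return state_ != nullptr;
  }

 private:
  mutable std::shared_mutex rwmutex_;
  std::unique_ptr<State> state_;
};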
@@ -90,12 +89,10 @@ class MLModelServiceTFLite : public vsdk::MLModelService,
std::shared_ptr<named_tensor_views> infer(const named_tensor_views& inputs,
const vsdk::ProtoStruct& extra) final {

// We serialize access to the interpreter. We use a
// `unique_lock` instead of a `lock_guard` because we will
// move the lock into the shared state that we return,
// allowing the higher level to effect a direct copy out of
// the tflite buffers while the interpreter is still locked.
std::unique_lock<std::mutex> lock(state_lock_);
// We need to lock state so we are protected against reconfiguration, but
// we don't want to block access to `metadata`. We use a shared lock here,
// and an exclusive lock to protect the interpreter itself, below.
std::shared_lock<std::shared_mutex> state_rlock(state_rwmutex_);
check_stopped_inlock_();

// Ensure that enough inputs were provided.
Expand All @@ -107,6 +104,9 @@ class MLModelServiceTFLite : public vsdk::MLModelService,
throw std::invalid_argument(buffer.str());
}

// Only one thread may interact with `state->interpreter` at a time.
std::unique_lock<std::mutex> interpreter_lock(state_->interpreter_mutex);

// Walk the inputs, and copy the data from each of the input
// tensor views we were given into the associated tflite input
// tensor buffer.
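
A minimal sketch of this two-level locking, with hypothetical names: the shared_mutex guards the state pointer against reconfiguration, while the per-state std::mutex serializes use of the interpreter itself.

#include <mutex>
#include <shared_mutex>

struct InterpreterState {
  // Only one inference may drive the interpreter at a time.
  std::mutex interpreter_mutex;
  // ... interpreter, tensor indices, metadata ...
};

void infer_sketch(std::shared_mutex& state_rwmutex, InterpreterState& state) {
  // Shared: keeps stop()/reconfigure() from swapping the state out from
  // under us, but lets metadata() readers run concurrently.
  std::shared_lock<std::shared_mutex> state_rlock(state_rwmutex);

  // Exclusive: serialize access to the interpreter and its buffers.
  std::unique_lock<std::mutex> interpreter_lock(state.interpreter_mutex);

  // ... copy inputs into the tflite tensors, invoke, build output views ...
}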
@@ -153,14 +153,15 @@ class MLModelServiceTFLite : public vsdk::MLModelService,
// we can avoid copying the data by letting the views alias
// the tensorflow tensor buffers and keep the interpreter lock
// held until the gRPC work is done. Note that this means the
// interpreter lock will remain held until the
// state and interpreter locks will remain held until the
// inference_result_type object tracked by the shared pointer
// we return is destroyed. Callers that want to make use of
// the inference results without keeping the interpreter
// locked would need to copy the data out of the views and
// then release the return value.
struct inference_result_type {
std::unique_lock<std::mutex> state_lock;
std::shared_lock<std::shared_mutex> state_rlock;
std::unique_lock<std::mutex> interpreter_lock;
named_tensor_views views;
};
auto inference_result = std::make_shared<inference_result_type>();
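
The lock-carrying result type above is returned through shared_ptr's aliasing constructor further down. A self-contained sketch of that idiom, with a placeholder type standing in for named_tensor_views:

#include <memory>
#include <mutex>
#include <shared_mutex>
#include <string>

using Views = std::string;  // placeholder for named_tensor_views

struct result_holder {
  std::shared_lock<std::shared_mutex> state_rlock;
  std::unique_lock<std::mutex> interpreter_lock;
  Views views;
};

std::shared_ptr<Views> make_result(std::shared_lock<std::shared_mutex> state_rlock,
                                   std::unique_lock<std::mutex> interpreter_lock) {
  auto holder = std::make_shared<result_holder>();
  holder->state_rlock = std::move(state_rlock);
  holder->interpreter_lock = std::move(interpreter_lock);
  holder->views = "views aliasing the interpreter's tensor buffers";

  // Aliasing constructor: callers see a shared_ptr<Views>, but the control
  // block owns the holder, so both locks stay held until the last copy of
  // the returned pointer is destroyed.
  return std::shared_ptr<Views>(holder, &holder->views);
}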
@@ -181,8 +182,9 @@ class MLModelServiceTFLite : public vsdk::MLModelService,
// The views created in the loop above are only valid until
// the interpreter lock is released, so we keep the lock held
// by moving the unique_lock into the inference_result
// object.
inference_result->state_lock = std::move(lock);
// object. We also need the state lock to protect our configuration.
inference_result->state_rlock = std::move(state_rlock);
inference_result->interpreter_lock = std::move(interpreter_lock);

// Finally, construct an aliasing shared_ptr which appears to
// the caller as a shared_ptr to views, but in fact manages
Expand All @@ -196,7 +198,7 @@ class MLModelServiceTFLite : public vsdk::MLModelService,

struct metadata metadata(const vsdk::ProtoStruct& extra) final {
// Just return a copy of our metadata from leased state.
std::lock_guard<std::mutex> lock(state_lock_);
const std::shared_lock<std::shared_mutex> state_rlock(state_rwmutex_);
check_stopped_inlock_();
return state_->metadata;
}
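
Illustrative only: the practical effect of the shared lock is that metadata() no longer queues behind a long-running infer(), since both hold the shared side; only stop() and reconfigure(), which take the exclusive side, exclude them. A toy demonstration of that reader/writer behaviour:

#include <mutex>
#include <shared_mutex>
#include <thread>

std::shared_mutex rwmutex;

void read_metadata() {
  std::shared_lock<std::shared_mutex> rlock(rwmutex);  // shared with other readers
  // ... copy metadata ...
}

void swap_state() {
  std::unique_lock<std::shared_mutex> wlock(rwmutex);  // excludes readers and writers
  // ... reset or replace state ...
}

int main() {
  std::thread a(read_metadata);
  std::thread b(read_metadata);  // may hold the lock concurrently with `a`
  a.join();
  b.join();
  swap_state();  // waits for all readers to release
  return 0;
}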
@@ -487,6 +489,9 @@ class MLModelServiceTFLite : public vsdk::MLModelService,
std::unordered_map<std::string, int> input_tensor_indices_by_name;
std::unordered_map<std::string, int> output_tensor_indices_by_name;

// Protects interpreter_error_data and interpreter
std::mutex interpreter_mutex;

// The `Report` method will overwrite this string.
std::string interpreter_error_data;

@@ -598,7 +603,7 @@ class MLModelServiceTFLite : public vsdk::MLModelService,
// Access to the module state is serialized. All configuration
// state is held in the `state` type to make it easier to destroy
// the current state and replace it with a new one.
std::mutex state_lock_;
std::shared_mutex state_rwmutex_;

// In C++17, this could be `std::optional`.
std::unique_ptr<struct state_> state_;
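
The comment above mentions the C++17 alternative; a minimal sketch of what that could look like, assuming the state only needs in-place storage (nothing in this diff guarantees that):

#include <optional>

struct state_sketch {
  // interpreter, interpreter_mutex, metadata, tensor indices, ...
};

class holder_sketch {
 public:
  void stop() noexcept { state_.reset(); }
  void reconfigure() { state_.emplace(); }
  bool configured() const { return state_.has_value(); }

 private:
  std::optional<state_sketch> state_;
};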