
Commit d212bbd

update
snnn committed Aug 13, 2019
1 parent 177dac6 commit d212bbd
Showing 6 changed files with 24 additions and 24 deletions.
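
Every hunk below is the same mechanical change: onnxruntime::concurrency::ThreadPool moves from pass-by-reference to pass-by-pointer. A plausible motivation is to make the pool optional, so call sites without one can pass nullptr rather than dereference a possibly-null pointer. A minimal sketch of the pattern, assuming that intent (the Worker class and Run method are hypothetical; only the ThreadPool type appears in the diff):

namespace onnxruntime { namespace concurrency { class ThreadPool; } }

class Worker {
 public:
  // The pool may be null and is not owned by this class.
  explicit Worker(onnxruntime::concurrency::ThreadPool* tp) : tp_(tp) {}

  void Run() {
    if (tp_ == nullptr) {
      // No pool was supplied: do the work on the calling thread.
    } else {
      // A pool was supplied: schedule the work on tp_.
    }
  }

 private:
  onnxruntime::concurrency::ThreadPool* tp_;  // nullable
};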
20 changes: 10 additions & 10 deletions onnxruntime/core/providers/cpu/rnn/deep_cpu_gru.cc
@@ -170,7 +170,7 @@ class UniDirectionalGru {
bool linear_before_reset, Direction direction, const gsl::span<const T>& bias,
const gsl::span<const T>& initial_hidden_state, const ActivationFuncs::Entry& activation_func_f,
const ActivationFuncs::Entry& activation_func_g, float clip,
-onnxruntime::concurrency::ThreadPool& ttp);
+onnxruntime::concurrency::ThreadPool* ttp);

void Compute(const gsl::span<const T>& inputs, const gsl::span<const int>& sequence_lengths, int num_directions,
const gsl::span<const T>& input_weights, const gsl::span<const T>& recurrent_weights,
@@ -237,7 +237,7 @@ class UniDirectionalGru {

void AllocateBuffers();

-onnxruntime::concurrency::ThreadPool& ttp_;
+onnxruntime::concurrency::ThreadPool* ttp_;
};
} // namespace detail

@@ -375,23 +375,23 @@ Status DeepCpuGruOp::ComputeImpl(OpKernelContext& context) const {
linear_before_reset_, Direction::kForward, bias_1, initial_hidden_1,
activation_funcs_.Entries()[0],
activation_funcs_.Entries()[1],
-clip_, *thread_pool);
+clip_, thread_pool);
fw.Compute(input, sequence_lens_span, num_directions_, input_weights_1, recurrent_weights_1,
output_1, hidden_output_1);

detail::UniDirectionalGru<T> bw(alloc, seq_length, batch_size, input_size, hidden_size_,
linear_before_reset_, Direction::kReverse, bias_2, initial_hidden_2,
activation_funcs_.Entries()[2],
activation_funcs_.Entries()[3],
-clip_, *thread_pool);
+clip_, thread_pool);
bw.Compute(input, sequence_lens_span, num_directions_, input_weights_2, recurrent_weights_2,
output_2, hidden_output_2);
} else {
detail::UniDirectionalGru<T> gru_p(alloc, seq_length, batch_size, input_size, hidden_size_,
linear_before_reset_, direction_, bias_1, initial_hidden_1,
activation_funcs_.Entries()[0],
activation_funcs_.Entries()[1],
-clip_, *thread_pool);
+clip_, thread_pool);
gru_p.Compute(input, sequence_lens_span, num_directions_, input_weights_1, recurrent_weights_1,
output_1, hidden_output_1);
}
@@ -420,7 +420,7 @@ UniDirectionalGru<T>::UniDirectionalGru(AllocatorPtr allocator,
const gsl::span<const T>& initial_hidden_state,
const ActivationFuncs::Entry& activation_func_f,
const ActivationFuncs::Entry& activation_func_g,
-const float clip, onnxruntime::concurrency::ThreadPool& ttp)
+const float clip, onnxruntime::concurrency::ThreadPool* ttp)
: allocator_(allocator),
seq_length_(seq_length),
batch_size_(batch_size),
@@ -549,7 +549,7 @@ void UniDirectionalGru<T>::Compute(const gsl::span<const T>& inputs_arg,
input_weights.cbegin(), input_weights.cend(),
input_size_, beta,
outputZRH_.begin(), outputZRH_.end(),
-hidden_size_x3, &ttp_);
+hidden_size_x3, ttp_);

DumpMatrix("inputs with weights applied", outputZRH_.data(), seq_length_ * batch_size_ * 3, hidden_size_);

@@ -615,7 +615,7 @@ void UniDirectionalGru<T>::Compute(const gsl::span<const T>& inputs_arg,
recurrent_weightsZR.cbegin(), recurrent_weightsZR.cend(),
hidden_size_, beta,
outputZRH_.begin() + out_added_offset, outputZRH_.end(),
-hidden_size_x3, &ttp_);
+hidden_size_x3, ttp_);

DumpMatrix("Ht-1 * R[zr] + Xt*(W[zr]^T)" + seqno_str,
outputZRH_.data() + out_added_offset, batch_size_, hidden_size_x2, 0, hidden_size_x3);
@@ -631,7 +631,7 @@ void UniDirectionalGru<T>::Compute(const gsl::span<const T>& inputs_arg,
recurrent_weightsH.cbegin(), recurrent_weightsH.cend(), // Rh^T
hidden_size_, beta,
linear_output_.begin(), linear_output_.end(), // pre: Rbh, post:output
-hidden_size_, &ttp_);
+hidden_size_, ttp_);

DumpMatrix("Ht-1 * (Rh^T) + Rbh " + seqno_str, linear_output_.data(), batch_size_, hidden_size_);
}
@@ -702,7 +702,7 @@ void UniDirectionalGru<T>::Compute(const gsl::span<const T>& inputs_arg,
recurrent_weightsH.cbegin(), recurrent_weightsH.cend(), // Rh^T
hidden_size_, beta,
out_H, outputZRH_.end(),
-hidden_size_x3, &ttp_);
+hidden_size_x3, ttp_);
}

DumpMatrix("Xt*(Wh^T) + (" + label + ")" + seqno_str, outputZRH_.data() + out_added_offset,
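Worth noting at the DeepCpuGruOp::ComputeImpl call sites above: the constructor argument changes from *thread_pool to thread_pool. If the pointer returned by the context can be null, the old form already dereferenced it at construction time; the new form forwards the pointer and leaves any null handling to the code that uses it (the GEMM helper now receives ttp_ rather than &ttp_). A condensed before/after, with the unrelated arguments elided:

// Before: dereference at the call site - undefined behavior if thread_pool is null.
detail::UniDirectionalGru<T> fw(/*...*/, clip_, *thread_pool);
// After: forward the pointer; the callee decides what a null pool means.
detail::UniDirectionalGru<T> fw(/*...*/, clip_, thread_pool);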
14 changes: 7 additions & 7 deletions onnxruntime/core/providers/cpu/rnn/deep_cpu_lstm.cc
@@ -197,7 +197,7 @@ class UniDirectionalLstm {
const ActivationFuncs::Entry& activation_func_f, const ActivationFuncs::Entry& activation_func_g,
const ActivationFuncs::Entry& activation_func_h, float clip,
concurrency::ThreadPool& lstm_tp_,
concurrency::ThreadPool& mlas_tp_);
concurrency::ThreadPool* mlas_tp_);

void Compute(const gsl::span<const T>& inputs, const gsl::span<const int>& sequence_lengths, int num_directions,
const gsl::span<const T>& input_weights, const gsl::span<const T>& recurrent_weights,
@@ -280,7 +280,7 @@ class UniDirectionalLstm {
ActivationInfo<deepcpu::LstmMergeGatesFuncPtr> activation_h_;

concurrency::ThreadPool& lstm_tp_;
-concurrency::ThreadPool& mlas_tp_;
+concurrency::ThreadPool* mlas_tp_;
};

} // namespace detail
@@ -315,7 +315,7 @@ DeepCpuLstmOp::Compute(OpKernelContext* context) const {
template <typename T>
Status DeepCpuLstmOp::ComputeImpl(OpKernelContext& context) const {
auto ctx_internal = static_cast<OpKernelContextInternal*>(&context);
-concurrency::ThreadPool& mlas_thread_pool = *ctx_internal->GetOperatorThreadPool();
+concurrency::ThreadPool* mlas_thread_pool = ctx_internal->GetOperatorThreadPool();

auto& logger = context.Logger();

@@ -555,7 +555,7 @@ UniDirectionalLstm<T>::UniDirectionalLstm(AllocatorPtr allocator,
const ActivationFuncs::Entry& activation_func_h,
const float clip,
concurrency::ThreadPool& lstm_tp,
-concurrency::ThreadPool& mlas_tp)
+concurrency::ThreadPool* mlas_tp)
: allocator_(allocator),
logger_(logger),
seq_length_(seq_length),
@@ -784,7 +784,7 @@ void UniDirectionalLstm<T>::Compute(const gsl::span<const T>& inputs_arg,
input_weights.cbegin(), input_weights.cend(), // W[iofc]
input_size_, beta,
output_iofc_.begin(), output_iofc_.end(),
-hidden_size_x4, &mlas_tp_);
+hidden_size_x4, mlas_tp_);

DumpMatrix("Xt*(W[iofc]^T)", output_iofc_.data(), total_rows, hidden_size_x4);

@@ -833,7 +833,7 @@ void UniDirectionalLstm<T>::Compute(const gsl::span<const T>& inputs_arg,
recurrent_weights.cbegin(), recurrent_weights.cend(), // R[iofc]
hidden_size_, beta,
step_out_IOFC, output_iofc_.end(), // input contains Xt*(W[iofc]^T)
-hidden_size_x4, &mlas_tp_);
+hidden_size_x4, mlas_tp_);

DumpMatrix("Xt*(W[iofc]^T) + Ht-t*R[iofc]" + row_str,
&*step_out_IOFC, local_fused_hidden_rows, hidden_size_x4);
@@ -911,7 +911,7 @@ void UniDirectionalLstm<T>::Compute(const gsl::span<const T>& inputs_arg,
recurrent_weights.cbegin(), recurrent_weights.cend(), // R[iofc]
hidden_size_, beta,
step_out_IOFC, output_iofc_.end(), // input contains Xt*(W[iofc]^T)
-hidden_size_x4, &mlas_tp_);
+hidden_size_x4, mlas_tp_);

span_T_iter batched_output;
span_T_iter batched_output_end;
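One asymmetry in deep_cpu_lstm.cc is worth calling out: only the MLAS pool changes type. After this commit UniDirectionalLstm holds both kinds of handle (excerpted from the hunks above):

concurrency::ThreadPool& lstm_tp_;  // unchanged: a reference, so a pool is mandatory
concurrency::ThreadPool* mlas_tp_;  // changed: a pointer, so presumably may be null

The LSTM's own parallel work still requires a pool, while the GEMM calls that take mlas_tp_ can evidently proceed without one.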
2 changes: 1 addition & 1 deletion onnxruntime/test/framework/allocation_planner_test.cc
@@ -161,7 +161,7 @@ class PlannerTest : public ::testing::Test {
std::unique_ptr<SequentialExecutionPlan> plan_;

public:
-PlannerTest() : model_("test"), graph_(model_.MainGraph()), tp_("test", 1), state_(execution_providers_, false, tp_) {
+PlannerTest() : model_("test"), graph_(model_.MainGraph()), tp_("test", 1), state_(execution_providers_, false, &tp_) {
std_kernel_ = KernelDefBuilder().SetName("Transpose").Build();
in_place_kernel_ = KernelDefBuilder().SetName("Clip").MayInplace(0, 0).Build();
CPUExecutionProviderInfo epi;
6 changes: 3 additions & 3 deletions onnxruntime/test/framework/execution_frame_test.cc
@@ -62,7 +62,7 @@ TEST_F(ExecutionFrameTest, TensorAllocationTest) {
status = kernel_registry_manager.RegisterKernels(execution_providers);
EXPECT_TRUE(status.IsOK()) << status.ErrorMessage();

-SessionState state{execution_providers, true, tp_};
+SessionState state{execution_providers, true, &tp_};
state.SetGraphViewer(std::make_unique<GraphViewer>(graph));

OrtValueNameIdxMap& mlvalue_name_idx_map{state.GetOrtValueNameIdxMap()};
@@ -145,7 +145,7 @@ TEST_F(ExecutionFrameTest, FeedInDataTest) {
execution_providers.Add(xp_typ, std::move(cpu_xp));
EXPECT_TRUE(kernel_registry_manager.RegisterKernels(execution_providers).IsOK());

-SessionState state{execution_providers, true, tp_};
+SessionState state{execution_providers, true, &tp_};
state.SetGraphViewer(std::make_unique<GraphViewer>(graph));

OrtValueNameIdxMap& mlvalue_name_idx_map{state.GetOrtValueNameIdxMap()};
@@ -197,7 +197,7 @@ TEST_F(ExecutionFrameTest, MemPatternTest) {
execution_providers.Add(xp_type, std::move(cpu_xp));
kernel_registry_manager.RegisterKernels(execution_providers);
//1. prepare input
-SessionState state{execution_providers, true, tp_};
+SessionState state{execution_providers, true, &tp_};
state.SetGraphViewer(std::make_unique<GraphViewer>(graph));

OrtValueNameIdxMap& mlvalue_name_idx_map{state.GetOrtValueNameIdxMap()};
4 changes: 2 additions & 2 deletions onnxruntime/test/framework/session_state_test.cc
@@ -40,7 +40,7 @@ TEST(SessionStateTest, AddGetKernelTest) {
.SetDoc("Input variable.")
.Output(0, "output_1", "docstr for output_1.", "tensor(int32)");
ExecutionProviders execution_providers;
-SessionState s{execution_providers, true, tp};
+SessionState s{execution_providers, true, &tp};

onnxruntime::Model model("graph_1");
auto& graph = model.MainGraph();
@@ -103,7 +103,7 @@ TEST_P(SessionStateTestP, TestInitializerProcessing) {
status = krm.RegisterKernels(execution_providers);
ASSERT_TRUE(status.IsOK()) << status;

-SessionState session_state(execution_providers, param.enable_mem_pattern, tp);
+SessionState session_state(execution_providers, param.enable_mem_pattern, &tp);
SessionStateInitializer session_initializer(param.enable_mem_pattern, ToWideString(model_path), graph,
session_state, execution_providers, krm);

2 changes: 1 addition & 1 deletion onnxruntime/test/providers/memcpy_test.cc
@@ -28,7 +28,7 @@ TEST(MemcpyTest, copy1) {
CPUExecutionProviderInfo epi;
auto st = execution_providers.Add(onnxruntime::kCpuExecutionProvider, std::make_unique<CPUExecutionProvider>(epi));
ASSERT_TRUE(st.IsOK()) << st.ErrorMessage();
-SessionState s{execution_providers, true, tp};
+SessionState s{execution_providers, true, &tp};
s.SetLogger(logging::LoggingManager::DefaultLogger());
KernelRegistryManager kernel_registry_manager;
kernel_registry_manager.RegisterKernels(execution_providers);
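All four test files adapt in the same way: SessionState now takes the operator thread pool by pointer, so each test passes &tp (or &tp_) where it previously passed the reference. A condensed sketch of the updated setup, assembled from the hunks above with the kernel and provider registrations elided (the constructor arguments for tp are taken from PlannerTest):

onnxruntime::concurrency::ThreadPool tp("test", 1);
ExecutionProviders execution_providers;
SessionState state{execution_providers, /*enable_mem_pattern=*/true, &tp};  // pointer, not reference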

