Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

stream_context_observer #6143

Merged
merged 8 commits into from
Sep 3, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 65 additions & 1 deletion oneflow/core/actor/actor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,19 @@ class KernelContextImpl : public KernelContext {
public:
OF_DISALLOW_COPY_AND_MOVE(KernelContextImpl);
explicit KernelContextImpl(const JobDesc* job_desc, DeviceCtx* device_ctx)
: job_desc_(job_desc), device_ctx_(device_ctx), state_(nullptr) {}
: job_desc_(job_desc),
device_ctx_(device_ctx),
state_(nullptr),
stream_kernel_observer_(nullptr) {
auto* stream_context_provider = dynamic_cast<StreamContextProvider*>(device_ctx);
if (stream_context_provider != nullptr) {
auto* kernel_observer_provider =
dynamic_cast<KernelObserverProvider*>(stream_context_provider->GetStreamContext());
if (kernel_observer_provider != nullptr) {
stream_kernel_observer_ = kernel_observer_provider->GetKernelObserver();
}
}
}
~KernelContextImpl() = default;

// Returns the DeviceCtx pointer stored in device_ctx_.
DeviceCtx* device_ctx() const override { return device_ctx_; }
Expand All @@ -44,6 +56,15 @@ class KernelContextImpl : public KernelContext {

// Returns the JobDesc pointer stored in job_desc_.
const JobDesc* job_desc() const override { return job_desc_; }

void WillForward(KernelContext* kernel_ctx, const Kernel* kernel) override;
void DidForward(KernelContext* kernel_ctx, const Kernel* kernel) override;

void WillForwardHeader(KernelContext* kernel_ctx, const Kernel* kernel) override;
void DidForwardHeader(KernelContext* kernel_ctx, const Kernel* kernel) override;

void WillForwardDataContent(KernelContext* kernel_ctx, const Kernel* kernel) override;
void DidForwardDataContent(KernelContext* kernel_ctx, const Kernel* kernel) override;

// Installs `fn` as the new bn_in_op2blob_fn_ callback. The previously stored callable ends up in
// the by-value parameter and is destroyed when this function returns.
void UpdateBnInOp2BlobFn(std::function<Blob*(const std::string&)> fn) {
  bn_in_op2blob_fn_.swap(fn);
}
Expand All @@ -53,8 +74,51 @@ class KernelContextImpl : public KernelContext {
DeviceCtx* device_ctx_;
std::function<Blob*(const std::string&)> bn_in_op2blob_fn_;
void* state_;
KernelObserver* stream_kernel_observer_;
};

// Relays the WillForward event to the global KernelObserver and then, when
// stream_kernel_observer_ is set, to that observer as well.
void KernelContextImpl::WillForward(KernelContext* kernel_ctx, const Kernel* kernel) {
  Global<KernelObserver>::Get()->WillForward(kernel_ctx, kernel);
  if (stream_kernel_observer_ == nullptr) { return; }
  stream_kernel_observer_->WillForward(kernel_ctx, kernel);
}

// Relays the DidForward event to the global KernelObserver and then, when
// stream_kernel_observer_ is set, to that observer as well.
void KernelContextImpl::DidForward(KernelContext* kernel_ctx, const Kernel* kernel) {
  Global<KernelObserver>::Get()->DidForward(kernel_ctx, kernel);
  if (stream_kernel_observer_ == nullptr) { return; }
  stream_kernel_observer_->DidForward(kernel_ctx, kernel);
}

// Relays the WillForwardHeader event to the global KernelObserver and then, when
// stream_kernel_observer_ is set, to that observer as well.
void KernelContextImpl::WillForwardHeader(KernelContext* kernel_ctx, const Kernel* kernel) {
  Global<KernelObserver>::Get()->WillForwardHeader(kernel_ctx, kernel);
  if (stream_kernel_observer_ == nullptr) { return; }
  stream_kernel_observer_->WillForwardHeader(kernel_ctx, kernel);
}

// Relays the DidForwardHeader event to the global KernelObserver and then, when
// stream_kernel_observer_ is set, to that observer as well.
void KernelContextImpl::DidForwardHeader(KernelContext* kernel_ctx, const Kernel* kernel) {
  Global<KernelObserver>::Get()->DidForwardHeader(kernel_ctx, kernel);
  if (stream_kernel_observer_ == nullptr) { return; }
  stream_kernel_observer_->DidForwardHeader(kernel_ctx, kernel);
}

// Relays the WillForwardDataContent event to the global KernelObserver and then, when
// stream_kernel_observer_ is set, to that observer as well.
void KernelContextImpl::WillForwardDataContent(KernelContext* kernel_ctx, const Kernel* kernel) {
  Global<KernelObserver>::Get()->WillForwardDataContent(kernel_ctx, kernel);
  if (stream_kernel_observer_ == nullptr) { return; }
  stream_kernel_observer_->WillForwardDataContent(kernel_ctx, kernel);
}

// Relays the DidForwardDataContent event to the global KernelObserver and then, when
// stream_kernel_observer_ is set, to that observer as well.
void KernelContextImpl::DidForwardDataContent(KernelContext* kernel_ctx, const Kernel* kernel) {
  Global<KernelObserver>::Get()->DidForwardDataContent(kernel_ctx, kernel);
  if (stream_kernel_observer_ == nullptr) { return; }
  stream_kernel_observer_->DidForwardDataContent(kernel_ctx, kernel);
}

void CheckInplaceRegstDescId(const TaskProto& task_proto) {
HashSet<int64_t> consumed_regst_desc_ids;
for (const auto& pair : task_proto.consumed_regst_desc_id()) {
Expand Down
57 changes: 56 additions & 1 deletion oneflow/core/actor/light_actor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,19 @@ class LightActor : public ActorBase, public KernelContext {
public:
OF_DISALLOW_COPY_AND_MOVE(LightActor);
explicit LightActor(std::shared_ptr<DeviceCtx> device_ctx)
: thread_(nullptr), device_ctx_(std::move(device_ctx)), job_desc_(nullptr) {}
: thread_(nullptr),
device_ctx_(std::move(device_ctx)),
job_desc_(nullptr),
stream_kernel_observer_(nullptr) {
auto* stream_context_provider = dynamic_cast<StreamContextProvider*>(device_ctx_.get());
if (stream_context_provider != nullptr) {
auto* kernel_observer_provider =
dynamic_cast<KernelObserverProvider*>(stream_context_provider->GetStreamContext());
if (kernel_observer_provider != nullptr) {
stream_kernel_observer_ = kernel_observer_provider->GetKernelObserver();
}
}
}
// When exec_kernel is true, asks the kernel in kernel_info_[0] to destroy the state stored in
// that same entry; otherwise does nothing.
~LightActor() override {
  if (!exec_kernel) { return; }
  kernel_info_[0]->kernel->DestroyState(kernel_info_[0]->state);
}
Expand Down Expand Up @@ -521,6 +533,48 @@ class LightActor : public ActorBase, public KernelContext {

// Returns the JobDesc pointer stored in job_desc_.
const JobDesc* job_desc() const override { return job_desc_; }

// Notifies the global KernelObserver of WillForward, then stream_kernel_observer_ when it is
// non-null.
void WillForward(KernelContext* kernel_ctx, const Kernel* kernel) override {
  Global<KernelObserver>::Get()->WillForward(kernel_ctx, kernel);
  if (stream_kernel_observer_ == nullptr) { return; }
  stream_kernel_observer_->WillForward(kernel_ctx, kernel);
}

// Notifies the global KernelObserver of DidForward, then stream_kernel_observer_ when it is
// non-null.
void DidForward(KernelContext* kernel_ctx, const Kernel* kernel) override {
  Global<KernelObserver>::Get()->DidForward(kernel_ctx, kernel);
  if (stream_kernel_observer_ == nullptr) { return; }
  stream_kernel_observer_->DidForward(kernel_ctx, kernel);
}

// Notifies the global KernelObserver of WillForwardHeader, then stream_kernel_observer_ when it
// is non-null.
void WillForwardHeader(KernelContext* kernel_ctx, const Kernel* kernel) override {
  Global<KernelObserver>::Get()->WillForwardHeader(kernel_ctx, kernel);
  if (stream_kernel_observer_ == nullptr) { return; }
  stream_kernel_observer_->WillForwardHeader(kernel_ctx, kernel);
}

// Notifies the global KernelObserver of DidForwardHeader, then stream_kernel_observer_ when it
// is non-null.
void DidForwardHeader(KernelContext* kernel_ctx, const Kernel* kernel) override {
  Global<KernelObserver>::Get()->DidForwardHeader(kernel_ctx, kernel);
  if (stream_kernel_observer_ == nullptr) { return; }
  stream_kernel_observer_->DidForwardHeader(kernel_ctx, kernel);
}

// Notifies the global KernelObserver of WillForwardDataContent, then stream_kernel_observer_
// when it is non-null.
void WillForwardDataContent(KernelContext* kernel_ctx, const Kernel* kernel) override {
  Global<KernelObserver>::Get()->WillForwardDataContent(kernel_ctx, kernel);
  if (stream_kernel_observer_ == nullptr) { return; }
  stream_kernel_observer_->WillForwardDataContent(kernel_ctx, kernel);
}

// Notifies the global KernelObserver of DidForwardDataContent, then stream_kernel_observer_
// when it is non-null.
void DidForwardDataContent(KernelContext* kernel_ctx, const Kernel* kernel) override {
  Global<KernelObserver>::Get()->DidForwardDataContent(kernel_ctx, kernel);
  if (stream_kernel_observer_ == nullptr) { return; }
  stream_kernel_observer_->DidForwardDataContent(kernel_ctx, kernel);
}

RegstIndex regst_desc_id_index_;
StateContainer index2state_;
IndexType total_reading_cnt_;
Expand All @@ -542,6 +596,7 @@ class LightActor : public ActorBase, public KernelContext {
std::vector<ActorMsg> async_post_act_msgs_;
std::unique_ptr<TaskProto> task_proto_;
const JobDesc* job_desc_;
KernelObserver* stream_kernel_observer_;
};

std::shared_ptr<DeviceCtx> NewDefaultDeviceCtx(const TaskProto& task_proto,
Expand Down
20 changes: 18 additions & 2 deletions oneflow/core/job/env_global_objects_scope.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,10 @@ limitations under the License.
#include "oneflow/core/comm_network/comm_network.h"
#include "oneflow/core/comm_network/epoll/epoll_comm_network.h"
#include "oneflow/core/comm_network/ibverbs/ibverbs_comm_network.h"
#include "oneflow/core/kernel/kernel_observer_manager.h"
#include "oneflow/core/kernel/chain_kernel_observer.h"
#include "oneflow/core/kernel/sync_check_kernel_observer.h"
#include "oneflow/core/kernel/blob_access_checker_kernel_observer.h"
#include "oneflow/core/kernel/profiler_kernel_observer.h"
#ifdef WITH_RDMA
#include "oneflow/core/platform/include/ibv.h"
#endif // WITH_RDMA
Expand Down Expand Up @@ -197,7 +200,20 @@ Maybe<void> EnvGlobalObjectsScope::Init(const EnvProto& env_proto) {
}
#endif // __linux__
}
Global<KernelObserver>::SetAllocated(new KernelObserverManager());
{
std::vector<std::shared_ptr<KernelObserver>> kernel_observers;
if (ParseBooleanFromEnv("ONEFLOW_DEBUG_KERNEL_SYNC_CHECK", false)) {
LOG(WARNING)
<< "Environment variable ONEFLOW_DEBUG_KERNEL_SYNC_CHECK has been set to a truthy "
"value, it will impact performance";
kernel_observers.emplace_back(new SyncCheckKernelObserver());
}
if (!ParseBooleanFromEnv("ONEFLOW_KERNEL_DISABLE_BLOB_ACCESS_CHECKER", false)) {
kernel_observers.emplace_back(new BlobAccessCheckerKernelObserver());
}
kernel_observers.emplace_back(new ProfilerKernelObserver());
Global<KernelObserver>::SetAllocated(new ChainKernelObserver(kernel_observers));
}
return Maybe<void>::Ok();
}

Expand Down
51 changes: 51 additions & 0 deletions oneflow/core/kernel/chain_kernel_observer.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/kernel/chain_kernel_observer.h"
#include "oneflow/core/kernel/kernel.h"

namespace oneflow {

// Calls WillForward on every observer in kernel_observers_, in container order.
void ChainKernelObserver::WillForward(KernelContext* kernel_ctx, const Kernel* kernel) {
  for (const auto& child : kernel_observers_) {
    child->WillForward(kernel_ctx, kernel);
  }
}

// Calls DidForward on every observer in kernel_observers_, in container order.
void ChainKernelObserver::DidForward(KernelContext* kernel_ctx, const Kernel* kernel) {
  for (const auto& child : kernel_observers_) {
    child->DidForward(kernel_ctx, kernel);
  }
}

// Calls WillForwardHeader on every observer in kernel_observers_, in container order.
void ChainKernelObserver::WillForwardHeader(KernelContext* kernel_ctx, const Kernel* kernel) {
  for (const auto& child : kernel_observers_) { child->WillForwardHeader(kernel_ctx, kernel); }
}

// Calls DidForwardHeader on every observer in kernel_observers_, in container order.
void ChainKernelObserver::DidForwardHeader(KernelContext* kernel_ctx, const Kernel* kernel) {
  for (const auto& child : kernel_observers_) {
    child->DidForwardHeader(kernel_ctx, kernel);
  }
}

// Calls WillForwardDataContent on every observer in kernel_observers_, in container order.
void ChainKernelObserver::WillForwardDataContent(KernelContext* kernel_ctx, const Kernel* kernel) {
  for (const auto& child : kernel_observers_) { child->WillForwardDataContent(kernel_ctx, kernel); }
}

// Calls DidForwardDataContent on every observer in kernel_observers_, in container order.
void ChainKernelObserver::DidForwardDataContent(KernelContext* kernel_ctx, const Kernel* kernel) {
  for (const auto& child : kernel_observers_) { child->DidForwardDataContent(kernel_ctx, kernel); }
}

} // namespace oneflow
Original file line number Diff line number Diff line change
Expand Up @@ -13,18 +13,19 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_KERNEL_KERNEL_OBSERVER_MANAGER_H_
#define ONEFLOW_CORE_KERNEL_KERNEL_OBSERVER_MANAGER_H_
#ifndef ONEFLOW_CORE_KERNEL_CHAIN_KERNEL_OBSERVER_H_
#define ONEFLOW_CORE_KERNEL_CHAIN_KERNEL_OBSERVER_H_

#include "oneflow/core/kernel/kernel_observer.h"

namespace oneflow {

class KernelObserverManager final : public KernelObserver {
class ChainKernelObserver final : public KernelObserver {
public:
OF_DISALLOW_COPY_AND_MOVE(KernelObserverManager);
KernelObserverManager();
~KernelObserverManager() override = default;
OF_DISALLOW_COPY_AND_MOVE(ChainKernelObserver);
explicit ChainKernelObserver(std::vector<std::shared_ptr<KernelObserver>> kernel_observers)
: kernel_observers_(std::move(kernel_observers)) {}
~ChainKernelObserver() override = default;

void WillForward(KernelContext* kernel_ctx, const Kernel* kernel) override;
void DidForward(KernelContext* kernel_ctx, const Kernel* kernel) override;
Expand All @@ -36,9 +37,9 @@ class KernelObserverManager final : public KernelObserver {
void DidForwardDataContent(KernelContext* kernel_ctx, const Kernel* kernel) override;

private:
std::vector<std::unique_ptr<KernelObserver>> kernel_observers_;
std::vector<std::shared_ptr<KernelObserver>> kernel_observers_;
};

} // namespace oneflow

#endif // ONEFLOW_CORE_KERNEL_KERNEL_OBSERVER_MANAGER_H_
#endif // ONEFLOW_CORE_KERNEL_CHAIN_KERNEL_OBSERVER_H_
101 changes: 0 additions & 101 deletions oneflow/core/kernel/check_numerics_kernel_observer.cu

This file was deleted.

Loading