Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fused ImageDecoderRandomCropResize #3644

Merged
Merged — 10 commits merged on Oct 5, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions cmake/third_party.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,9 @@ if (BUILD_CUDA)
get_filename_component(cuda_lib_dir ${CUDA_cudart_static_LIBRARY} DIRECTORY)
endif()
set(extra_cuda_libs libculibos.a libcurand_static.a)
if(CUDA_VERSION VERSION_GREATER_EQUAL "10.2")
list(APPEND extra_cuda_libs libnvjpeg_static.a libnppc_static.a libnppig_static.a)
endif()
foreach(extra_cuda_lib ${extra_cuda_libs})
list(APPEND CUDA_LIBRARIES ${cuda_lib_dir}/${extra_cuda_lib})
endforeach()
Expand Down
1 change: 1 addition & 0 deletions oneflow/core/actor/naive_actor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,5 +29,6 @@ void NaiveActor::VirtualAsyncSendNaiveProducedRegstMsgToConsumer() {

REGISTER_ACTOR(TaskType::kSliceBoxing, NaiveActor);
REGISTER_ACTOR(TaskType::kBoxingIdentity, NaiveActor);
REGISTER_ACTOR(TaskType::kDecodeH2D, NaiveActor);

} // namespace oneflow
31 changes: 31 additions & 0 deletions oneflow/core/device/cuda_util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,27 @@ const char* CurandGetErrorString(curandStatus_t error) {
return "Unknown curand status";
}

#if CUDA_VERSION >= 10020

// Return the symbolic name of an nvjpeg status code, for error reporting.
// Stringify each case with a macro so the case label and the returned text
// can never drift apart.
const char* NvjpegGetErrorString(nvjpegStatus_t error) {
#define OF_NVJPEG_STATUS_TO_STRING(status) \
  case status: return #status
  switch (error) {
    OF_NVJPEG_STATUS_TO_STRING(NVJPEG_STATUS_SUCCESS);
    OF_NVJPEG_STATUS_TO_STRING(NVJPEG_STATUS_NOT_INITIALIZED);
    OF_NVJPEG_STATUS_TO_STRING(NVJPEG_STATUS_INVALID_PARAMETER);
    OF_NVJPEG_STATUS_TO_STRING(NVJPEG_STATUS_BAD_JPEG);
    OF_NVJPEG_STATUS_TO_STRING(NVJPEG_STATUS_JPEG_NOT_SUPPORTED);
    OF_NVJPEG_STATUS_TO_STRING(NVJPEG_STATUS_ALLOCATOR_FAILURE);
    OF_NVJPEG_STATUS_TO_STRING(NVJPEG_STATUS_EXECUTION_FAILED);
    OF_NVJPEG_STATUS_TO_STRING(NVJPEG_STATUS_ARCH_MISMATCH);
    OF_NVJPEG_STATUS_TO_STRING(NVJPEG_STATUS_INTERNAL_ERROR);
    OF_NVJPEG_STATUS_TO_STRING(NVJPEG_STATUS_IMPLEMENTATION_NOT_SUPPORTED);
  }
#undef OF_NVJPEG_STATUS_TO_STRING
  // Any status value not covered above (e.g. added by a newer nvjpeg)
  // falls through to a generic message.
  return "Unknown nvjpeg status";
}

#endif

void InitGlobalCudaDeviceProp() {
CHECK(Global<cudaDeviceProp>::Get() == nullptr) << "initialized Global<cudaDeviceProp> twice";
Global<cudaDeviceProp>::New();
Expand Down Expand Up @@ -185,6 +206,16 @@ void NumaAwareCudaMallocHost(int32_t dev, void** ptr, size_t size) {
#endif
}

// Pin the calling thread to the CPU core set closest to GPU `dev`
// (see the matching declaration comment in cuda_util.h).
void CudaDeviceSetCpuAffinity(int32_t dev) {
#ifdef PLATFORM_POSIX
  cpu_set_t new_cpu_set;
  // Query the preferred CPU set for this device, then apply it to the
  // calling thread (pid 0 means "current thread" for sched_setaffinity).
  CudaDeviceGetCpuAffinity(dev, &new_cpu_set);
  CHECK_EQ(sched_setaffinity(0, sizeof(cpu_set_t), &new_cpu_set), 0);
#else
  // CPU-affinity control is only implemented on POSIX platforms.
  UNIMPLEMENTED();
#endif
}

cudaDataType_t GetCudaDataType(DataType val) {
#define MAKE_ENTRY(type_cpp, type_cuda) \
if (val == GetDataType<type_cpp>::value) { return type_cuda; }
Expand Down
27 changes: 26 additions & 1 deletion oneflow/core/device/cuda_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,24 @@ limitations under the License.
#include <cuda_fp16.h>
#include <device_launch_parameters.h>

#if CUDA_VERSION >= 10020

#include <nvjpeg.h>

#endif

namespace oneflow {

const char* CublasGetErrorString(cublasStatus_t error);

const char* CurandGetErrorString(curandStatus_t error);

#if CUDA_VERSION >= 10020

const char* NvjpegGetErrorString(nvjpegStatus_t error);

#endif

#define OF_CUDA_CHECK(condition) \
for (cudaError_t _of_cuda_check_status = (condition); _of_cuda_check_status != cudaSuccess;) \
LOG(FATAL) << "Check failed: " #condition " : " << cudaGetErrorString(_of_cuda_check_status) \
Expand Down Expand Up @@ -63,6 +75,16 @@ const char* CurandGetErrorString(curandStatus_t error);
LOG(FATAL) << "Check failed: " #condition " : " << ncclGetErrorString(_of_nccl_check_status) \
<< " (" << _of_nccl_check_status << ") "

#if CUDA_VERSION >= 10020

// Evaluate an nvjpeg API call and LOG(FATAL) with the symbolic status name
// if it does not return NVJPEG_STATUS_SUCCESS; mirrors OF_CUDA_CHECK /
// OF_NCCL_CHECK above. The for-loop form runs the check exactly once while
// keeping the macro usable as a single statement.
#define OF_NVJPEG_CHECK(condition) \
for (nvjpegStatus_t _of_nvjpeg_check_status = (condition); \
_of_nvjpeg_check_status != NVJPEG_STATUS_SUCCESS;) \
LOG(FATAL) << "Check failed: " #condition " : " << NvjpegGetErrorString(_of_nvjpeg_check_status) \
<< " (" << _of_nvjpeg_check_status << ") "

#endif

template<typename T>
void CudaCheck(T error);

Expand Down Expand Up @@ -110,7 +132,7 @@ size_t GetAvailableGpuMemSize(int dev_id);
OF_PP_MAKE_TUPLE_SEQ(kCopyD2H) \
OF_PP_MAKE_TUPLE_SEQ(kNccl) \
OF_PP_MAKE_TUPLE_SEQ(kMix) \
OF_PP_MAKE_TUPLE_SEQ(kMdUpdt)
OF_PP_MAKE_TUPLE_SEQ(kDecodeH2D)

enum class CudaWorkType {
#define DECLARE_CUDA_WORK_TYPE(type) type,
Expand All @@ -126,6 +148,9 @@ void NumaAwareCudaMallocHost(int32_t dev, T** ptr, size_t size) {
NumaAwareCudaMallocHost(dev, reinterpret_cast<void**>(ptr), size);
}

// Set the CPU affinity to the closest processor(s) of a particular GPU.
void CudaDeviceSetCpuAffinity(int32_t dev);

#define CUDA_DATA_TYPE_SEQ \
OF_PP_MAKE_TUPLE_SEQ(float, CUDA_R_32F) \
OF_PP_MAKE_TUPLE_SEQ(double, CUDA_R_64F) \
Expand Down
47 changes: 47 additions & 0 deletions oneflow/core/graph/decode_h2d_compute_task_node.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/graph/decode_h2d_compute_task_node.h"
#include "oneflow/core/graph/logical_node.h"

namespace oneflow {

void DecodeH2DCompTaskNode::ConsumeAllRegsts() {
ConsumeRegst("in", SoleInDataEdge()->GetSoleRegst());
}

void DecodeH2DCompTaskNode::ProduceAllRegstsAndBindEdges() {
std::shared_ptr<RegstDesc> out_regst = ProduceRegst("out", false, 2, 2);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

我感觉使用了这个DecodeH2D以后,我们之前CopyH2D regst num 强制 = 2的trick可以去掉了。这样还能优化boxing v2带来的2倍模型显存的开销?

ForEachOutDataEdge([&](TaskEdge* edge) { edge->AddRegst("out", out_regst); });
ProduceRegst("tmp", false);
}

void DecodeH2DCompTaskNode::BuildExecGphAndRegst() {
ExecNode* node = mut_exec_gph().NewNode();
std::shared_ptr<Operator> sole_op = this->logical_node()->SoleOp();
node->mut_op() = sole_op;
node->BindBnWithRegst(sole_op->SoleIbn(), GetSoleConsumedRegst("in"));
std::shared_ptr<RegstDesc> out_regst = GetProducedRegst("out");
out_regst->AddLbi(sole_op->BnInOp2Lbi(sole_op->SoleObn()));
node->BindBnWithRegst(sole_op->SoleObn(), out_regst);
node->AddBnToRegstAndBindIt(&Operator::tmp_bns, GetProducedRegst("tmp"));
node->InferBlobDescs(parallel_ctx());
}

void DecodeH2DCompTaskNode::InferProducedDataRegstTimeShape() {
  // The produced regsts follow the default (naive) time-shape inference;
  // decode-H2D does not change the time shape of the data stream.
  NaiveInferProducedDataRegstTimeShape();
}

} // namespace oneflow
48 changes: 48 additions & 0 deletions oneflow/core/graph/decode_h2d_compute_task_node.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_GRAPH_DECODE_H2D_COMPUTE_TASK_NODE_H_
#define ONEFLOW_CORE_GRAPH_DECODE_H2D_COMPUTE_TASK_NODE_H_

#include "oneflow/core/graph/compute_task_node.h"

namespace oneflow {

// Compute task that decodes data on the host and lands the result directly
// on the GPU (decode + H2D transfer fused into one task).
class DecodeH2DCompTaskNode final : public CompTaskNode {
 public:
  OF_DISALLOW_COPY_AND_MOVE(DecodeH2DCompTaskNode);
  DecodeH2DCompTaskNode() = default;
  ~DecodeH2DCompTaskNode() override = default;

  void ProduceAllRegstsAndBindEdges() override;
  void ConsumeAllRegsts() override;

  TaskType GetTaskType() const override { return TaskType::kDecodeH2D; }
  // Decode-H2D work is dispatched to its own dedicated per-GPU worker
  // (CudaWorkType::kDecodeH2D) rather than sharing the CopyH2D stream, so
  // a potentially synchronous decode cannot stall the async copy traffic.
  CudaWorkType GetCudaWorkType() const override {
#ifdef WITH_CUDA
    return CudaWorkType::kDecodeH2D;
#else
    UNIMPLEMENTED();
#endif
  }

 private:
  void BuildExecGphAndRegst() override;
  void InferProducedDataRegstTimeShape() override;
};

} // namespace oneflow

#endif // ONEFLOW_CORE_GRAPH_DECODE_H2D_COMPUTE_TASK_NODE_H_
14 changes: 9 additions & 5 deletions oneflow/core/graph/logical_node.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -157,8 +157,8 @@ void LogicalNode::GenSortedCompTaskNodes(
comp_task_node->set_thrd_id(id_mgr->GetGpuMixThrdId(dev_phy_id));
break;
}
case CudaWorkType::kMdUpdt: {
comp_task_node->set_thrd_id(id_mgr->GetGpuMdUpdtThrdId(dev_phy_id));
case CudaWorkType::kDecodeH2D: {
comp_task_node->set_thrd_id(id_mgr->GetGpuDecodeH2DThrdId(dev_phy_id));
break;
}
default: UNIMPLEMENTED();
Expand Down Expand Up @@ -216,9 +216,6 @@ BldSubTskGphMthd GetMthdForBldSubTskGph(const LogicalNode* src_node, const Logic
}
}
}
if (src_pd->parallel_num() == 1 && dst_pd->parallel_num() == 1) {
return &TaskGraph::BldSubTskGphByOneToOne;
}
std::string k = ConcatTypeName(src_node, dst_node);
auto it = GetFuncForFindBldSubTskGphMthd()->find(k);
if (it == GetFuncForFindBldSubTskGphMthd()->end()) {
Expand All @@ -228,6 +225,9 @@ BldSubTskGphMthd GetMthdForBldSubTskGph(const LogicalNode* src_node, const Logic
it = GetFuncForFindBldSubTskGphMthd()->find("*" + dst_node->TypeName());
}
if (it != GetFuncForFindBldSubTskGphMthd()->end()) { return it->second(src_node, dst_node); }
if (src_pd->parallel_num() == 1 && dst_pd->parallel_num() == 1) {
return &TaskGraph::BldSubTskGphByOneToOne;
}
if (src_pd->parallel_num() == dst_pd->parallel_num()
&& IsConnectedLbisAllSameSbpParallel(src_node, dst_node)) {
return &TaskGraph::BldSubTskGphByOneToOne;
Expand All @@ -247,6 +247,10 @@ REGISTER_BLD_SUB_TSK_GPH_MTHD("DistributeSplit"
"*",
&TaskGraph::BldSubTskGphByPartialOutLbiConnect);

REGISTER_BLD_SUB_TSK_GPH_MTHD("NormalForward"
"DecodeH2D",
&TaskGraph::BldSubTskGphNormalForwardToDecodeH2D);

#define LOGICAL_TYPE_SEQ \
OF_PP_MAKE_TUPLE_SEQ(DistributeConcat, kDataForwardArea) \
OF_PP_MAKE_TUPLE_SEQ(DistributeSplit, kDataForwardArea) \
Expand Down
2 changes: 2 additions & 0 deletions oneflow/core/graph/logical_node.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ limitations under the License.
#include "oneflow/core/graph/acc_compute_task_node.h"
#include "oneflow/core/graph/case_compute_task_node.h"
#include "oneflow/core/graph/esac_compute_task_node.h"
#include "oneflow/core/graph/decode_h2d_compute_task_node.h"

namespace oneflow {

Expand Down Expand Up @@ -208,6 +209,7 @@ DECLARE_DERIVED_FORWARD_LOGICAL_NODE_WITH_NEW_AREA_ID(RepeatForward);
DECLARE_DERIVED_FORWARD_LOGICAL_NODE_WITH_NEW_AREA_ID(Acc);
DECLARE_DERIVED_FORWARD_LOGICAL_NODE_WITH_NEW_AREA_ID(Case);
DECLARE_DERIVED_FORWARD_LOGICAL_NODE_WITH_NEW_AREA_ID(Esac);
DECLARE_DERIVED_FORWARD_LOGICAL_NODE_WITH_NEW_AREA_ID(DecodeH2D);

} // namespace oneflow

Expand Down
9 changes: 9 additions & 0 deletions oneflow/core/graph/task_graph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -529,6 +529,15 @@ DEFINE_BLD_SUB_TASK_GRAPH_METHOD(BldSubTskGphByPartialOutLbiConnect) {
}
}

DEFINE_BLD_SUB_TASK_GRAPH_METHOD(BldSubTskGphNormalForwardToDecodeH2D) {
  // NormalForward feeds DecodeH2D strictly one-to-one: both sides must have
  // the same parallelism, and the i-th producer connects to the i-th decoder.
  const size_t num_tasks = sorted_src_comp_tasks.size();
  CHECK_EQ(num_tasks, sorted_dst_comp_tasks.size());
  for (size_t i = 0; i < num_tasks; ++i) {
    Connect<TaskNode>(sorted_src_comp_tasks.at(i), NewEdge(), sorted_dst_comp_tasks.at(i));
  }
}

void TaskGraph::BuildTaskPath(
CompTaskNode* src, CompTaskNode* dst,
std::function<TaskNode**(CompTaskNode* src, int64_t machine_id, int32_t mem_zone_id)>
Expand Down
1 change: 1 addition & 0 deletions oneflow/core/graph/task_graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ class TaskGraph final : public Graph<TaskNode, TaskEdge> {
DECLARE_BLD_SUB_TASK_GRAPH_METHOD(BldSubTskGphByBroadcastToBroadcast);
DECLARE_BLD_SUB_TASK_GRAPH_METHOD(BldSubTskGphByPartialInLbiConnect);
DECLARE_BLD_SUB_TASK_GRAPH_METHOD(BldSubTskGphByPartialOutLbiConnect);
DECLARE_BLD_SUB_TASK_GRAPH_METHOD(BldSubTskGphNormalForwardToDecodeH2D);

private:
void AcyclicTopoForEachNode(std::function<bool(TaskNode* node)> IsAllowedStartNode,
Expand Down
2 changes: 1 addition & 1 deletion oneflow/core/job/id_manager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ int64_t IDMgr::GetGpuNcclThrdId(int64_t dev_phy_id) const {
int64_t IDMgr::GetGpuMixThrdId(int64_t dev_phy_id) const {
return gpu_device_num_ * 4 + dev_phy_id;
}
int64_t IDMgr::GetGpuMdUpdtThrdId(int64_t dev_phy_id) const {
int64_t IDMgr::GetGpuDecodeH2DThrdId(int64_t dev_phy_id) const {
return gpu_device_num_ * 5 + dev_phy_id;
}
int64_t IDMgr::GetCpuDeviceThrdId(int64_t dev_phy_id) const {
Expand Down
2 changes: 1 addition & 1 deletion oneflow/core/job/id_manager.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ class IDMgr final {
int64_t GetGpuD2HThrdId(int64_t dev_phy_id) const;
int64_t GetGpuNcclThrdId(int64_t dev_phy_id) const;
int64_t GetGpuMixThrdId(int64_t dev_phy_id) const;
int64_t GetGpuMdUpdtThrdId(int64_t dev_phy_id) const;
int64_t GetGpuDecodeH2DThrdId(int64_t dev_phy_id) const;
int64_t GetCpuDeviceThrdId(int64_t dev_phy_id) const;
int64_t CommNetThrdId() const;
int64_t TickTockThrdId() const;
Expand Down
1 change: 1 addition & 0 deletions oneflow/core/job/scope.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ class Scope final {
const OptMirroredParallel& opt_mirrored_parallel_conf() const {
return scope_proto_.opt_mirrored_parallel_conf();
}
const ScopeProto& scope_proto() const { return scope_proto_; }

private:
Maybe<void> Init();
Expand Down
1 change: 1 addition & 0 deletions oneflow/core/job/task.proto
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ enum TaskType {
kSliceBoxing = 57;
kCollectiveBoxingGeneric = 58;
kBoxingIdentity = 59;
kDecodeH2D = 60;
};

enum AreaType {
Expand Down
Loading