Tmp compute #8570

Merged
merged 89 commits into from
Aug 1, 2022
89 commits
6e8e9c9
ThreadLocalGuard
lixinqi May 12, 2022
08e9178
Merge branch 'master' of github.com:Oneflow-Inc/oneflow
lixinqi May 14, 2022
f59d17d
Merge branch 'master' of github.com:Oneflow-Inc/oneflow
lixinqi May 18, 2022
3eb809a
Merge branch 'master' of github.com:Oneflow-Inc/oneflow
lixinqi May 18, 2022
55c163c
Merge branch 'master' of github.com:Oneflow-Inc/oneflow
lixinqi May 20, 2022
8aa2e8f
Merge branch 'master' of github.com:Oneflow-Inc/oneflow
lixinqi Jun 1, 2022
7612597
Merge branch 'master' of github.com:Oneflow-Inc/oneflow
lixinqi Jun 6, 2022
de5f971
Merge branch 'master' of github.com:Oneflow-Inc/oneflow
lixinqi Jun 8, 2022
8e86949
Merge branch 'master' of github.com:Oneflow-Inc/oneflow
lixinqi Jun 9, 2022
2ca0707
Merge branch 'master' of github.com:Oneflow-Inc/oneflow
lixinqi Jun 16, 2022
8537b7e
Merge branch 'master' of github.com:Oneflow-Inc/oneflow
lixinqi Jun 16, 2022
55c5160
Merge branch 'master' of github.com:Oneflow-Inc/oneflow
lixinqi Jun 17, 2022
e643eb1
Merge branch 'master' of github.com:Oneflow-Inc/oneflow
lixinqi Jun 18, 2022
eccdfe6
Merge branch 'master' of github.com:Oneflow-Inc/oneflow
lixinqi Jun 20, 2022
043accc
Merge branch 'master' of github.com:Oneflow-Inc/oneflow
lixinqi Jun 20, 2022
97b0eef
Merge branch 'master' of github.com:Oneflow-Inc/oneflow
lixinqi Jun 20, 2022
1591853
Merge branch 'master' of github.com:Oneflow-Inc/oneflow
lixinqi Jun 23, 2022
ba6f2d7
Merge branch 'master' of github.com:Oneflow-Inc/oneflow
lixinqi Jun 23, 2022
5e1a86a
Merge branch 'master' of github.com:Oneflow-Inc/oneflow
lixinqi Jun 23, 2022
1ee004c
Merge branch 'master' of github.com:Oneflow-Inc/oneflow
lixinqi Jun 24, 2022
e853c71
Merge branch 'master' of github.com:Oneflow-Inc/oneflow
lixinqi Jun 24, 2022
c5afe82
Merge branch 'master' of github.com:Oneflow-Inc/oneflow
lixinqi Jun 24, 2022
14226d6
Merge branch 'master' of github.com:Oneflow-Inc/oneflow
lixinqi Jun 26, 2022
754d6a7
Merge branch 'master' of github.com:Oneflow-Inc/oneflow
lixinqi Jun 27, 2022
acb7c98
Merge branch 'master' of github.com:Oneflow-Inc/oneflow
lixinqi Jun 28, 2022
5916848
Merge branch 'master' of github.com:Oneflow-Inc/oneflow
lixinqi Jun 28, 2022
913f6f5
Merge branch 'master' of github.com:Oneflow-Inc/oneflow
lixinqi Jul 1, 2022
fa3867e
Merge branch 'master' of github.com:Oneflow-Inc/oneflow
lixinqi Jul 2, 2022
61bee99
Merge branch 'master' of github.com:Oneflow-Inc/oneflow
lixinqi Jul 3, 2022
7eb2d72
Merge branch 'master' of github.com:Oneflow-Inc/oneflow
lixinqi Jul 4, 2022
dd39ed9
StreamRole::kTmpCompute
lixinqi Jul 5, 2022
a1305bb
SoftSyncStream in InstructionsBuilder::TouchTensors
lixinqi Jul 5, 2022
88a2b2a
Merge branch 'master' into tmp_compute
lixinqi Jul 5, 2022
5862a95
Merge branch 'master' of github.com:Oneflow-Inc/oneflow
lixinqi Jul 6, 2022
29ad00c
Merge branch 'master' of github.com:Oneflow-Inc/oneflow
lixinqi Jul 6, 2022
7297192
Merge branch 'master' of github.com:Oneflow-Inc/oneflow
lixinqi Jul 6, 2022
67e5d6f
Merge branch 'master' into tmp_compute
lixinqi Jul 7, 2022
0a54078
Merge branch 'master' of github.com:Oneflow-Inc/oneflow
lixinqi Jul 8, 2022
cec8a1d
Merge branch 'master' of github.com:Oneflow-Inc/oneflow
lixinqi Jul 11, 2022
b50e236
Merge branch 'master' of github.com:Oneflow-Inc/oneflow
lixinqi Jul 11, 2022
21b3e9e
Merge branch 'master' into tmp_compute
lixinqi Jul 13, 2022
050b5cf
fix conflicts
lixinqi Jul 13, 2022
a6c5d07
Merge branch 'master' of github.com:Oneflow-Inc/oneflow
lixinqi Jul 15, 2022
b6b73a2
Merge branch 'master' of github.com:Oneflow-Inc/oneflow
lixinqi Jul 16, 2022
43197bb
Merge branch 'master' of github.com:Oneflow-Inc/oneflow
lixinqi Jul 16, 2022
4453c58
Merge branch 'master' of github.com:Oneflow-Inc/oneflow
lixinqi Jul 17, 2022
582e11f
Merge branch 'master' of github.com:Oneflow-Inc/oneflow
lixinqi Jul 18, 2022
496e08d
Merge branch 'master' into tmp_compute
lixinqi Jul 18, 2022
5cfeb91
ONEFLOW_AD_PUT_LOSS_ON_TMP_COMPUTE_STREAM
lixinqi Jul 18, 2022
d37b327
Merge branch 'master' into tmp_compute
lixinqi Jul 18, 2022
4001637
Merge branch 'master' of github.com:Oneflow-Inc/oneflow
lixinqi Jul 19, 2022
7fdc675
Merge branch 'master' of github.com:Oneflow-Inc/oneflow
lixinqi Jul 19, 2022
c65ccdc
merge master
lixinqi Jul 19, 2022
b447d9c
Merge branch 'master' into tmp_compute
lixinqi Jul 19, 2022
1555f70
Merge branch 'master' of github.com:Oneflow-Inc/oneflow
lixinqi Jul 20, 2022
44d363f
Merge branch 'master' into tmp_compute
lixinqi Jul 20, 2022
cea5d58
Merge branch 'master' of github.com:Oneflow-Inc/oneflow
lixinqi Jul 22, 2022
30d130b
Merge branch 'master' into tmp_compute
lixinqi Jul 22, 2022
3cdfe94
Merge branch 'tmp_compute' of github.com:Oneflow-Inc/oneflow into tmp…
lixinqi Jul 22, 2022
ccbddef
Merge branch 'master' of github.com:Oneflow-Inc/oneflow
lixinqi Jul 22, 2022
56fc83f
Merge branch 'master' into tmp_compute
lixinqi Jul 22, 2022
c914f2f
Merge branch 'master' of github.com:Oneflow-Inc/oneflow
lixinqi Jul 25, 2022
6b7885f
Merge branch 'master' of github.com:Oneflow-Inc/oneflow
lixinqi Jul 25, 2022
09489b2
Merge branch 'master' of github.com:Oneflow-Inc/oneflow
lixinqi Jul 26, 2022
e52056d
Merge branch 'master' into tmp_compute
lixinqi Jul 26, 2022
ee14204
Merge branch 'master' of github.com:Oneflow-Inc/oneflow
lixinqi Jul 27, 2022
4720413
Merge branch 'master' of github.com:Oneflow-Inc/oneflow
lixinqi Jul 28, 2022
4a94f74
merge master
lixinqi Jul 28, 2022
55dd98b
merge master
lixinqi Jul 28, 2022
97b697d
Merge branch 'master' of github.com:Oneflow-Inc/oneflow
lixinqi Jul 28, 2022
2cccecb
Merge branch 'master' of github.com:Oneflow-Inc/oneflow
lixinqi Jul 28, 2022
755199c
Merge branch 'master' of github.com:Oneflow-Inc/oneflow
lixinqi Jul 29, 2022
31a5022
Merge branch 'master' of github.com:Oneflow-Inc/oneflow
lixinqi Jul 29, 2022
d690538
Merge branch 'master' of github.com:Oneflow-Inc/oneflow
lixinqi Jul 29, 2022
d0b5287
merge master
lixinqi Jul 29, 2022
c731a5f
Merge branch 'master' into tmp_compute
lixinqi Jul 29, 2022
a3a6056
Merge branch 'master' of github.com:Oneflow-Inc/oneflow
lixinqi Jul 29, 2022
dcaacc6
Merge branch 'master' of github.com:Oneflow-Inc/oneflow
lixinqi Jul 30, 2022
700c39a
Merge branch 'master' of github.com:Oneflow-Inc/oneflow
lixinqi Jul 31, 2022
1c6f65f
Merge branch 'master' of github.com:Oneflow-Inc/oneflow
lixinqi Aug 1, 2022
20eeb77
merge master
lixinqi Aug 1, 2022
f03fe1c
AsyncedDevice2Host
lixinqi Aug 1, 2022
07dee86
Merge branch 'tmp_compute' of github.com:Oneflow-Inc/oneflow into tmp…
lixinqi Aug 1, 2022
ded3c18
Merge branch 'master' into tmp_compute
clackhan Aug 1, 2022
8d6bd62
Merge branch 'master' into tmp_compute
mergify[bot] Aug 1, 2022
55d1635
Merge branch 'master' into tmp_compute
mergify[bot] Aug 1, 2022
f817242
Merge branch 'master' into tmp_compute
mergify[bot] Aug 1, 2022
6fcfeeb
Merge branch 'master' into tmp_compute
mergify[bot] Aug 1, 2022
e28c391
Merge branch 'master' into tmp_compute
mergify[bot] Aug 1, 2022
40 changes: 39 additions & 1 deletion oneflow/core/autograd/autograd_engine.cpp
@@ -19,8 +19,10 @@ limitations under the License.
#include <queue>
#include "oneflow/core/autograd/autograd_engine.h"
#include "oneflow/core/autograd/autograd_meta.h"
#include "oneflow/core/framework/stream.h"
#include "oneflow/core/framework/tensor.h"
#include "oneflow/core/framework/tensor_arg.h"
#include "oneflow/core/framework/tensor_methods.h"
#include "oneflow/core/framework/tensor_tuple.h"
#include "oneflow/core/framework/tensor_rpc_util.h"
#include "oneflow/core/autograd/autograd_mode.h"
@@ -29,6 +31,7 @@ limitations under the License.
#include "oneflow/core/framework/global_param_grad_sync_mode.h"
#include "oneflow/core/common/container_util.h"
#include "oneflow/core/profiler/profiler.h"
#include "oneflow/core/common/env_var/autograd.h"

namespace oneflow {
namespace one {
@@ -114,6 +117,32 @@ Maybe<void> CheckGlobalTensorsMeta(const TensorTuple& tensor_tuple) {
return Maybe<void>::Ok();
}

Maybe<void> TouchInTmpComputeStream(const TensorTuple& inputs) {
for (auto input : inputs) {
if (input->is_global()) { input = JUST(input->cur_rank_phy_tensor()); }
if (input) {
Symbol<Device> device = JUST(input->device());
auto stream = JUST(Stream::New(device, StreamRole::kTmpCompute));
JUST(Touch(input, stream));
}
}
return Maybe<void>::Ok();
}

constexpr static int kSmallTensorThreshold = 1024;

Maybe<TensorTuple> TryCopyForSmallTensor(const TensorTuple& inputs) {
auto outputs = std::make_shared<TensorTuple>();
outputs->reserve(inputs.size());
for (auto input : inputs) {
if (input->shape()->elem_cnt() <= kSmallTensorThreshold) {
input = JUST(functional::Identity(input));
}
outputs->push_back(input);
}
return outputs;
}

} // namespace

Maybe<void> AutogradEngine::RunBackwardAndSaveGrads4LeafTensorIf(const TensorTuple& outputs,
@@ -123,7 +152,16 @@ Maybe<void> AutogradEngine::RunBackwardAndSaveGrads4LeafTensorIf(const TensorTup
JUST(CheckGlobalTensorsMeta(outputs));
JUST(CheckGlobalTensorsMeta(out_grads));
DisableCheckGlobalTensorMetaScope disable_meta_check;
- return RunBackwardAndSaveGrads4LeafTensor(outputs, out_grads, retain_graph, create_graph);
+ if (ThreadLocalEnvBool<ONEFLOW_AD_PUT_LOSS_ON_TMP_COMPUTE_STREAM>()) {
+ // Put outputs into kTmpCompute stream for reducing blocking time of outputs[i].numpy() in main
+ // thread.
+ auto copied_outputs = JUST(TryCopyForSmallTensor(outputs));
+ JUST(TouchInTmpComputeStream(outputs));
+ return RunBackwardAndSaveGrads4LeafTensor(*copied_outputs, out_grads, retain_graph,
+ create_graph);
+ } else {
+ return RunBackwardAndSaveGrads4LeafTensor(outputs, out_grads, retain_graph, create_graph);
+ }
}

Maybe<TensorTuple> AutogradEngine::RunBackwardAndReturnInputsTensorGradIf(
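The small-tensor path in the hunk above can be illustrated in isolation. The following is a minimal sketch, not the OneFlow implementation: `ToyTensor` and `IdentityCopy` are hypothetical stand-ins for `one::Tensor` and `functional::Identity`, and the sketch assumes that the identity op yields a distinct tensor object. Only the threshold value and the control flow are taken from the diff.

```cpp
#include <cassert>
#include <cstdint>
#include <memory>
#include <vector>

// Hypothetical stand-in for a tensor: only the element count and an id matter here.
struct ToyTensor {
  int64_t elem_cnt;
  int id;
};
using ToyTensorPtr = std::shared_ptr<ToyTensor>;

// Same threshold value as kSmallTensorThreshold in the diff.
constexpr int64_t kSmallTensorThreshold = 1024;

// Hypothetical stand-in for functional::Identity: assumed to return a new object.
ToyTensorPtr IdentityCopy(const ToyTensorPtr& t) { return std::make_shared<ToyTensor>(*t); }

// Tensors at or below the threshold are replaced by a copy; larger ones are passed through.
std::vector<ToyTensorPtr> TryCopyForSmallTensor(const std::vector<ToyTensorPtr>& inputs) {
  std::vector<ToyTensorPtr> outputs;
  outputs.reserve(inputs.size());
  for (ToyTensorPtr input : inputs) {
    assert(input != nullptr);
    if (input->elem_cnt <= kSmallTensorThreshold) { input = IdentityCopy(input); }
    outputs.push_back(input);
  }
  return outputs;
}
```

In the diff, backward then runs on the copied tuple while the original `outputs` are touched on the `kTmpCompute` stream. Per the code comment, the intent is to shorten the time `outputs[i].numpy()` blocks the main thread.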
27 changes: 27 additions & 0 deletions oneflow/core/common/env_var/autograd.h
@@ -0,0 +1,27 @@
/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_COMMON_ENV_VAR_AUTOGRAD_H_
#define ONEFLOW_CORE_COMMON_ENV_VAR_AUTOGRAD_H_

#include "oneflow/core/common/env_var/env_var.h"

namespace oneflow {

DEFINE_THREAD_LOCAL_ENV_BOOL(ONEFLOW_AD_PUT_LOSS_ON_TMP_COMPUTE_STREAM, true);

}

#endif // ONEFLOW_CORE_COMMON_ENV_VAR_AUTOGRAD_H_
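The macro `DEFINE_THREAD_LOCAL_ENV_BOOL` is defined in `env_var.h`, which is not part of this diff. As a rough sketch of the idea only, assuming the flag is read from the process environment once per thread and falls back to the declared default (`true`) when unset, the behaviour could look like the code below. `ParseBoolEnv`, `PutLossOnTmpComputeStream`, and the set of accepted truthy strings are hypothetical and are not taken from OneFlow.

```cpp
#include <algorithm>
#include <cassert>
#include <cctype>
#include <cstdlib>
#include <string>

// Hypothetical helper: returns default_value when the variable is unset.
// The accepted spellings ("1", "true", "on", "yes") are an assumption of this sketch.
bool ParseBoolEnv(const char* name, bool default_value) {
  assert(name != nullptr);
  const char* raw = std::getenv(name);
  if (raw == nullptr) { return default_value; }
  std::string value(raw);
  std::transform(value.begin(), value.end(), value.begin(),
                 [](unsigned char c) { return static_cast<char>(std::tolower(c)); });
  return value == "1" || value == "true" || value == "on" || value == "yes";
}

// Hypothetical accessor: the environment is consulted once per thread, then cached.
bool PutLossOnTmpComputeStream() {
  thread_local const bool value =
      ParseBoolEnv("ONEFLOW_AD_PUT_LOSS_ON_TMP_COMPUTE_STREAM", /*default_value=*/true);
  return value;
}
```

Under these assumptions, the new code path in `autograd_engine.cpp` is taken unless the variable is explicitly set to a falsy value before the thread first reads it.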
7 changes: 6 additions & 1 deletion oneflow/core/common/stream_role.h
@@ -28,12 +28,14 @@ enum class StreamRole {
kCompute,
kHost2Device,
kDevice2Host,
+ kAsyncedDevice2Host,
  kSyncedLaunchedCommNet,
  kAsyncedLaunchedCommNet,
  kBarrier,
  kCriticalSection,
  kLazyJobLauncher,
- kPinnedCompute
+ kPinnedCompute,
+ kTmpCompute
};

template<typename DerivedT>
@@ -45,6 +47,8 @@ struct StreamRoleVisitor {
case StreamRole::kCompute: return DerivedT::VisitCompute(std::forward<Args>(args)...);
case StreamRole::kHost2Device: return DerivedT::VisitHost2Device(std::forward<Args>(args)...);
case StreamRole::kDevice2Host: return DerivedT::VisitDevice2Host(std::forward<Args>(args)...);
case StreamRole::kAsyncedDevice2Host:
return DerivedT::VisitAsyncedDevice2Host(std::forward<Args>(args)...);
case StreamRole::kSyncedLaunchedCommNet:
return DerivedT::VisitSyncedLaunchedCommNet(std::forward<Args>(args)...);
case StreamRole::kAsyncedLaunchedCommNet:
@@ -56,6 +60,7 @@ struct StreamRoleVisitor {
return DerivedT::VisitLazyJobLauncher(std::forward<Args>(args)...);
case StreamRole::kPinnedCompute:
return DerivedT::VisitPinnedCompute(std::forward<Args>(args)...);
case StreamRole::kTmpCompute: return DerivedT::VisitTmpCompute(std::forward<Args>(args)...);
}
LOG(FATAL) << "invalid stream role";
}
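Every header touched later in this PR adds a `VisitAsyncedDevice2Host` and a `VisitTmpCompute` method because `StreamRoleVisitor` dispatches statically to the derived class. The reduced sketch below shows that CRTP pattern with a hypothetical four-value `Role` enum and `RoleVisitor` base. The two derived structs mirror the return values of `GetStreamRoleName` and `StreamOnIndependentThread` from the diff, and `std::abort()` stands in for `LOG(FATAL)`.

```cpp
#include <cassert>
#include <cstdlib>
#include <string>
#include <utility>

// Hypothetical, reduced role enum; the real StreamRole has more values.
enum class Role { kCompute, kDevice2Host, kAsyncedDevice2Host, kTmpCompute };

// CRTP base: maps a runtime role to a static method of DerivedT.
template<typename DerivedT>
struct RoleVisitor {
  template<typename... Args>
  static auto Visit(Role role, Args&&... args) {
    switch (role) {
      case Role::kCompute: return DerivedT::VisitCompute(std::forward<Args>(args)...);
      case Role::kDevice2Host: return DerivedT::VisitDevice2Host(std::forward<Args>(args)...);
      case Role::kAsyncedDevice2Host:
        return DerivedT::VisitAsyncedDevice2Host(std::forward<Args>(args)...);
      case Role::kTmpCompute: return DerivedT::VisitTmpCompute(std::forward<Args>(args)...);
    }
    assert(false && "invalid role");
    std::abort();  // stands in for LOG(FATAL) in the original
  }
};

// One trait per derived struct; a missing Visit* method is a compile error on first use.
struct RoleName : RoleVisitor<RoleName> {
  static const char* VisitCompute() { return "compute"; }
  static const char* VisitDevice2Host() { return "d2h"; }
  static const char* VisitAsyncedDevice2Host() { return "asynced_d2h"; }
  static const char* VisitTmpCompute() { return "tmp_compute"; }
};

struct OnIndependentThread : RoleVisitor<OnIndependentThread> {
  static bool VisitCompute() { return false; }
  static bool VisitDevice2Host() { return false; }
  static bool VisitAsyncedDevice2Host() { return true; }
  static bool VisitTmpCompute() { return VisitCompute(); }
};
```

A call such as `RoleName::Visit(Role::kTmpCompute)` resolves to `RoleName::VisitTmpCompute()`. Delegating to an existing method, as `VisitTmpCompute` does here, keeps a new role aligned with the role it is modelled on.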
11 changes: 9 additions & 2 deletions oneflow/core/framework/instructions_builder.cpp
@@ -409,12 +409,19 @@ Maybe<void> InstructionsBuilder::ReleaseTensor(
return Maybe<void>::Ok();
}

- Maybe<void> InstructionsBuilder::TouchTensors(const vm::EagerBlobObjectListPtr& eager_blob_object) {
+ Maybe<void> InstructionsBuilder::TouchTensors(
+ const vm::EagerBlobObjectListPtr& eager_blob_objects) {
  Symbol<Device> device = JUST(Device::New("cpu"));
  Symbol<Stream> stream = JUST(GetDefaultStreamByDevice(device));
+ return TouchTensors(eager_blob_objects, stream);
+ }
+
+ Maybe<void> InstructionsBuilder::TouchTensors(const vm::EagerBlobObjectListPtr& eager_blob_objects,
+ Symbol<Stream> stream) {
+ JUST(SoftSyncStream(*eager_blob_objects, stream));
  auto instruction = intrusive::make_shared<vm::Instruction>(
  JUST(Singleton<VirtualMachine>::Get()->GetVmStream(stream)),
- std::make_shared<vm::TouchTensorsInstructionPolicy>(*eager_blob_object));
+ std::make_unique<vm::TouchTensorsInstructionPolicy>(*eager_blob_objects));
instruction_list_->EmplaceBack(std::move(instruction));
return Maybe<void>::Ok();
}
5 changes: 4 additions & 1 deletion oneflow/core/framework/instructions_builder.h
@@ -76,7 +76,10 @@ class InstructionsBuilder : public std::enable_shared_from_this<InstructionsBuil

Maybe<void> ReleaseTensor(const std::shared_ptr<vm::EagerBlobObject>& eager_blob_object);

- Maybe<void> TouchTensors(const vm::EagerBlobObjectListPtr& eager_blob_object);
+ Maybe<void> TouchTensors(const vm::EagerBlobObjectListPtr& eager_blob_objects);
+
+ Maybe<void> TouchTensors(const vm::EagerBlobObjectListPtr& eager_blob_objects,
+ Symbol<Stream> stream);

template<typename T>
Maybe<void> SyncAccessBlobByCallback(
2 changes: 2 additions & 0 deletions oneflow/core/framework/stream_allocator_is_pinned.h
@@ -25,12 +25,14 @@ struct IsStreamAllocatorPinned : public StreamRoleVisitor<IsStreamAllocatorPinne
static bool VisitCompute() { return false; }
static bool VisitHost2Device() { return false; }
static bool VisitDevice2Host() { return false; }
static bool VisitAsyncedDevice2Host() { return VisitDevice2Host(); }
static bool VisitSyncedLaunchedCommNet() { return false; }
static bool VisitAsyncedLaunchedCommNet() { return false; }
static bool VisitBarrier() { return false; }
static bool VisitCriticalSection() { return false; }
static bool VisitLazyJobLauncher() { return false; }
static bool VisitPinnedCompute() { return true; }
static bool VisitTmpCompute() { return false; }
};

} // namespace oneflow
4 changes: 3 additions & 1 deletion oneflow/core/framework/stream_get_stream_role_name.h
@@ -28,12 +28,14 @@ struct GetStreamRoleName : public StreamRoleVisitor<GetStreamRoleName> {
static const char* VisitCompute() { return "compute"; }
static const char* VisitHost2Device() { return "h2d"; }
static const char* VisitDevice2Host() { return "d2h"; }
+ static const char* VisitAsyncedDevice2Host() { return "asynced_d2h"; }
  static const char* VisitSyncedLaunchedCommNet() { return "synced_launched_comm_net"; }
  static const char* VisitAsyncedLaunchedCommNet() { return "asynced_launched_comm_net"; }
  static const char* VisitBarrier() { return "barrier"; }
  static const char* VisitCriticalSection() { return "critical_section"; }
  static const char* VisitLazyJobLauncher() { return "lazy_job_launcher"; }
- static const char* VisitPinnedCompute() { return "pin_memory"; }
+ static const char* VisitPinnedCompute() { return "pinned_compute"; }
+ static const char* VisitTmpCompute() { return "tmp_compute"; }
};

} // namespace oneflow
Expand Down
2 changes: 2 additions & 0 deletions oneflow/core/framework/stream_is_comm_net_stream.h
@@ -25,12 +25,14 @@ struct IsCommNetStream final : public StreamRoleVisitor<IsCommNetStream> {
static bool VisitCompute() { return false; }
static bool VisitHost2Device() { return false; }
static bool VisitDevice2Host() { return false; }
static bool VisitAsyncedDevice2Host() { return VisitDevice2Host(); }
static bool VisitSyncedLaunchedCommNet() { return true; }
static bool VisitAsyncedLaunchedCommNet() { return true; }
static bool VisitBarrier() { return false; }
static bool VisitCriticalSection() { return false; }
static bool VisitLazyJobLauncher() { return false; }
static bool VisitPinnedCompute() { return VisitCompute(); }
static bool VisitTmpCompute() { return VisitCompute(); }
};

} // namespace oneflow
Expand Down
4 changes: 4 additions & 0 deletions oneflow/core/framework/stream_need_soft_sync.h
@@ -26,12 +26,16 @@ struct NeedSoftSync : public StreamRoleVisitor<NeedSoftSync> {
static bool VisitCompute(DeviceType device_type) { return device_type != kCPU; }
static bool VisitHost2Device(DeviceType) { return false; }
static bool VisitDevice2Host(DeviceType) { return false; }
static bool VisitAsyncedDevice2Host(DeviceType device_type) {
return VisitDevice2Host(device_type);
}
static bool VisitSyncedLaunchedCommNet(DeviceType device_type) { return false; }
static bool VisitAsyncedLaunchedCommNet(DeviceType) { return false; }
static bool VisitBarrier(DeviceType) { return false; }
static bool VisitCriticalSection(DeviceType) { return false; }
static bool VisitLazyJobLauncher(DeviceType) { return false; }
static bool VisitPinnedCompute(DeviceType device_type) { return VisitCompute(device_type); }
static bool VisitTmpCompute(DeviceType device_type) { return false; }
};

} // namespace oneflow
Expand Down
2 changes: 2 additions & 0 deletions oneflow/core/framework/stream_on_independent_thread.h
@@ -25,12 +25,14 @@ struct StreamOnIndependentThread : public StreamRoleVisitor<StreamOnIndependentT
static bool VisitCompute() { return false; }
static bool VisitHost2Device() { return false; }
static bool VisitDevice2Host() { return false; }
static bool VisitAsyncedDevice2Host() { return true; }
static bool VisitSyncedLaunchedCommNet() { return false; }
static bool VisitAsyncedLaunchedCommNet() { return false; }
static bool VisitBarrier() { return false; }
static bool VisitCriticalSection() { return true; }
static bool VisitLazyJobLauncher() { return true; }
static bool VisitPinnedCompute() { return VisitCompute(); }
static bool VisitTmpCompute() { return VisitCompute(); }
};

} // namespace oneflow
Expand Down
11 changes: 11 additions & 0 deletions oneflow/core/framework/tensor_methods.cpp
@@ -546,5 +546,16 @@ Maybe<Tensor> Diagonal(const std::shared_ptr<Tensor>& input, const int32_t offse
}

} // namespace view

Maybe<void> Touch(std::shared_ptr<Tensor> input, Symbol<Stream> stream) {
auto eager_blob_objects = std::make_shared<vm::EagerBlobObjectList>();
if (input->is_global()) { input = JUST(input->cur_rank_phy_tensor()); }
if (input) { eager_blob_objects->push_back(JUST(input->eager_blob_object())); }
JUST(PhysicalRun([&](InstructionsBuilder* builder) -> Maybe<void> {
return builder->TouchTensors(eager_blob_objects, stream);
}));
return Maybe<void>::Ok();
}

} // namespace one
} // namespace oneflow
7 changes: 7 additions & 0 deletions oneflow/core/framework/tensor_methods.h
@@ -20,6 +20,9 @@ limitations under the License.
#include "oneflow/core/framework/tensor.h"

namespace oneflow {

class Stream;

namespace one {

class Tensor;
@@ -67,7 +70,11 @@ Maybe<Tensor> Diagonal(const std::shared_ptr<Tensor>& input, const int32_t offse
const int32_t dim1, const int32_t dim2);

} // namespace view

Maybe<void> Touch(std::shared_ptr<Tensor> input, Symbol<Stream> stream);

} // namespace one

} // namespace oneflow

#endif // ONEFLOW_CORE_FRAMEWORK_TENSOR_METHOD_H_
10 changes: 10 additions & 0 deletions oneflow/core/functional/impl/array_functor.cpp
@@ -1478,13 +1478,23 @@ class CopyFunctor {
JUST(attrs.SetAttr<std::string>("device_type", device_type));
JUST(attrs.SetAttr<int64_t>("device_id", device_id));
JUST(attrs.SetAttr<bool>("pin_memory", pin_memory));
JUST(attrs.SetAttr<bool>("asynced_copy", JUST(GetAsyncedCopy(*x))));

#ifdef WITH_CUDA
if (device_type == "cuda") { InitCudaContextOnce(device_id); }
#endif
return OpInterpUtil::Dispatch<Tensor>(*op_, {x}, attrs);
}

Maybe<bool> GetAsyncedCopy(const one::Tensor& x) const {
if (!x.is_eager()) { return false; }
if (!x.is_local()) { return false; }
const auto& eager_blob_object = JUST(x.eager_blob_object());
const auto& opt_stream = eager_blob_object->last_used_stream();
if (!opt_stream.has_value()) { return false; }
return JUST(opt_stream)->stream_role() == StreamRole::kTmpCompute;
}

private:
std::shared_ptr<OpExpr> op_;
};
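The decision made by `GetAsyncedCopy` above reduces to three early exits followed by a role comparison. The sketch below restates it with a hypothetical `ToyTensorState` struct in place of `one::Tensor` and `vm::EagerBlobObject`, and `std::optional` in place of OneFlow's `Optional`. It is an illustration of the logic only.

```cpp
#include <cassert>  // only needed by a test driver exercising the function
#include <optional>

// Hypothetical, reduced role enum.
enum class StreamRole { kCompute, kTmpCompute };

// Hypothetical flattened view of the tensor properties the functor inspects.
struct ToyTensorState {
  bool is_eager;
  bool is_local;
  std::optional<StreamRole> last_used_stream_role;  // empty if the tensor was never used
};

// Asynchronous copy is requested only for eager, local tensors whose
// last-used stream has the kTmpCompute role.
bool GetAsyncedCopy(const ToyTensorState& x) {
  if (!x.is_eager) { return false; }
  if (!x.is_local) { return false; }
  if (!x.last_used_stream_role.has_value()) { return false; }
  return *x.last_used_stream_role == StreamRole::kTmpCompute;
}
```

Combined with the autograd change, this suggests that a loss tensor touched on the `kTmpCompute` stream is the case that sets `asynced_copy` to true, while tensors last used elsewhere keep the previous copy behaviour.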
10 changes: 10 additions & 0 deletions oneflow/core/vm/ep_record_event_instruction_policy.h
@@ -110,6 +110,11 @@ struct GetRecordEventInstructionPolicy : public StreamRoleVisitor<GetRecordEvent
new vm::EpRecordEventInstructionPolicy(std::forward<Args>(args)...));
}
template<typename... Args>
static Maybe<vm::InstructionPolicy> VisitAsyncedDevice2Host(DeviceType device_type,
Args&&... args) {
return VisitDevice2Host(device_type, std::forward<Args>(args)...);
}
template<typename... Args>
static Maybe<vm::InstructionPolicy> VisitSyncedLaunchedCommNet(DeviceType device_type,
Args&&... args) {
return std::shared_ptr<vm::InstructionPolicy>(
@@ -140,6 +145,11 @@ struct GetRecordEventInstructionPolicy : public StreamRoleVisitor<GetRecordEvent
return std::shared_ptr<vm::InstructionPolicy>(
new vm::EpRecordEventInstructionPolicy(std::forward<Args>(args)...));
}
template<typename... Args>
static Maybe<vm::InstructionPolicy> VisitTmpCompute(DeviceType device_type, Args&&... args) {
return std::shared_ptr<vm::InstructionPolicy>(
new vm::EpRecordEventInstructionPolicy(std::forward<Args>(args)...));
}
};

} // namespace oneflow
11 changes: 11 additions & 0 deletions oneflow/core/vm/release_tensor_instruction_policy.h
@@ -145,6 +145,11 @@ struct MakeReleaseTensorInstructionPolicy
const Optional<vm::Stream*>& stream) {
return Make(data_type, eager_blob_object, stream);
}
static Maybe<vm::InstructionPolicy> VisitAsyncedDevice2Host(
DataType data_type, const std::shared_ptr<vm::EagerBlobObject>& eager_blob_object,
const Optional<vm::Stream*>& stream) {
return VisitDevice2Host(data_type, eager_blob_object, stream);
}
static Maybe<vm::InstructionPolicy> VisitSyncedLaunchedCommNet(
DataType data_type, const std::shared_ptr<vm::EagerBlobObject>& eager_blob_object,
const Optional<vm::Stream*>& stream) {
@@ -178,6 +183,12 @@ struct MakeReleaseTensorInstructionPolicy
return VisitCompute(data_type, eager_blob_object, stream);
}

static Maybe<vm::InstructionPolicy> VisitTmpCompute(
DataType data_type, const std::shared_ptr<vm::EagerBlobObject>& eager_blob_object,
const Optional<vm::Stream*>& stream) {
return VisitCompute(data_type, eager_blob_object, stream);
}

private:
static Maybe<vm::InstructionPolicy> Make(
DataType data_type, const std::shared_ptr<vm::EagerBlobObject>& eager_blob_object,
6 changes: 6 additions & 0 deletions oneflow/core/vm/stream_get_stream_policy.h
@@ -40,6 +40,9 @@ struct CreateStreamPolicy final : public StreamRoleVisitor<CreateStreamPolicy> {
static Maybe<vm::StreamPolicy> VisitDevice2Host(Symbol<Device> device) {
return std::shared_ptr<vm::StreamPolicy>(new vm::EpD2HStreamPolicy(device));
}
static Maybe<vm::StreamPolicy> VisitAsyncedDevice2Host(Symbol<Device> device) {
return VisitDevice2Host(device);
}
static Maybe<vm::StreamPolicy> VisitSyncedLaunchedCommNet(Symbol<Device> device) {
return std::shared_ptr<vm::StreamPolicy>(new vm::EventRecordedEpStreamPolicy(device));
}
@@ -58,6 +61,9 @@ struct CreateStreamPolicy final : public StreamRoleVisitor<CreateStreamPolicy> {
static Maybe<vm::StreamPolicy> VisitPinnedCompute(Symbol<Device> device) {
return std::shared_ptr<vm::StreamPolicy>(new vm::PinnedEpStreamPolicy(device));
}
static Maybe<vm::StreamPolicy> VisitTmpCompute(Symbol<Device> device) {
return std::shared_ptr<vm::StreamPolicy>(new vm::EventRecordedEpStreamPolicy(device));
}
};

} // namespace oneflow
Expand Down
3 changes: 2 additions & 1 deletion oneflow/ir/include/OneFlow/OneFlowUserOps.td
@@ -8601,7 +8601,8 @@ def OneFlow_CopyOp : OneFlow_BaseOp<"copy", [NoSideEffect, DeclareOpInterfaceMet
let attrs = (ins
StrAttr:$device_type,
DefaultValuedAttr<SI64Attr, "0">:$device_id,
- DefaultValuedAttr<BoolAttr, "false">:$pin_memory
+ DefaultValuedAttr<BoolAttr, "false">:$pin_memory,
+ DefaultValuedAttr<BoolAttr, "false">:$asynced_copy
);
let has_logical_tensor_desc_infer_fn = 1;
let has_physical_tensor_desc_infer_fn = 1;