Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Replace src tick with wait and send ids #5603

Merged
merged 37 commits into from
Jul 27, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
37 commits
Select commit Hold shift + click to select a range
238b840
nn.Graph call and launch impl
chengtbf Jul 23, 2021
b17428a
Fix bug of 1.pybind11 list->vector; 2.GlobalJobDesc in NNGraph Compile
chengtbf Jul 23, 2021
f50fdf1
Fix bug: AutoTick Multi-Client callback notifier lbn
chengtbf Jul 23, 2021
1402aae
Fix bug: callback notifier op check in multi-client
chengtbf Jul 23, 2021
7946369
Fix bug: CheckOpGraph in multi-client
chengtbf Jul 23, 2021
5170170
Fix bug of MultiClientSessionContext new ProfileConf
chengtbf Jul 23, 2021
8e4c148
Fix bug of create NONE type kernel; add note; refine PlanUtil::Popula…
chengtbf Jul 23, 2021
72bbd30
Fix bug: CallbackNotifier kernel dtype with uint8; change tick dtype;…
chengtbf Jul 23, 2021
6fb3101
inputs tensor list
chengtbf Jul 23, 2021
2bb38e2
rollback of hack
chengtbf Jul 23, 2021
e8b62e4
Merge branch 'master' of https://github.com/Oneflow-Inc/oneflow into …
chengtbf Jul 23, 2021
926c422
Using TensorTuple for RunLazyNNGraph
chengtbf Jul 23, 2021
6cf7a37
Fix bug: WaitAndSendIdsKernel handle for Multi-Client
chengtbf Jul 23, 2021
bd79523
skip graph runnable test
chengtbf Jul 23, 2021
7de278d
Merge branch 'master' into dev_cc_nn_graph_run
chengtbf Jul 23, 2021
7184c0e
Merge branch 'master' into dev_cc_nn_graph_run
oneflow-ci-bot Jul 23, 2021
3ebd938
refine
chengtbf Jul 23, 2021
4d63535
Merge branch 'master' into dev_cc_nn_graph_run
chengtbf Jul 23, 2021
5039ef6
Merge branch 'master' into dev_cc_nn_graph_run
oneflow-ci-bot Jul 23, 2021
a964b76
Merge branch 'master' into dev_cc_nn_graph_run
oneflow-ci-bot Jul 23, 2021
21f0417
skip test graph optimizer for nn.Graph train job completer error
chengtbf Jul 24, 2021
16844ae
Merge branch 'master' of https://github.com/Oneflow-Inc/oneflow into …
leaves-zwx Jul 26, 2021
fe00bcb
debug code
leaves-zwx Jul 27, 2021
01a9dfe
rm InputKernel Forward overwrite
leaves-zwx Jul 27, 2021
9015e5a
rm debug code
leaves-zwx Jul 27, 2021
6e0064a
revert useless change
leaves-zwx Jul 27, 2021
5b6a4cd
revert debug code in graph
leaves-zwx Jul 27, 2021
c967a98
rm skip in test
leaves-zwx Jul 27, 2021
4c67b6e
revert source tick op
leaves-zwx Jul 27, 2021
96a3287
Merge branch 'master' of https://github.com/Oneflow-Inc/oneflow into …
leaves-zwx Jul 27, 2021
65860e6
rm useless import
leaves-zwx Jul 27, 2021
88ea937
mv test_graph_relu.py
leaves-zwx Jul 27, 2021
4df677e
correct break
leaves-zwx Jul 27, 2021
285e8f1
change the code of finding tick
leaves-zwx Jul 27, 2021
fb69dcc
Merge branch 'master' into replace_src_tick
leaves-zwx Jul 27, 2021
e75b175
Merge branch 'master' into replace_src_tick
oneflow-ci-bot Jul 27, 2021
7cb541d
Merge branch 'master' into replace_src_tick
oneflow-ci-bot Jul 27, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 29 additions & 8 deletions oneflow/core/job_rewriter/autotick.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -427,6 +427,8 @@ Maybe<void> MultiClientAddWaitAndSendIds(JobBuilder* job_builder, int64_t machin
parallel_conf.set_device_tag("cpu");
parallel_conf.add_device_name(std::string("@") + std::to_string(machine_id) + ":0");
}

// add wait_and_send_ids op conf
OperatorConf wait_and_send_ids_op_conf;
{
wait_and_send_ids_op_conf.set_name(std::string("System-Src-WaitAndSendIds_") + NewUniqueId());
Expand All @@ -438,14 +440,33 @@ Maybe<void> MultiClientAddWaitAndSendIds(JobBuilder* job_builder, int64_t machin
// wait_and_send_ids_conf->id_list() is unused in multi-client mode.
}
JUST(job_builder->AddOp(parallel_conf, wait_and_send_ids_op_conf));
OperatorConf source_tick_op = JUST(job_builder->OpConf4OpName(src_op_name));
{
CHECK_OR_RETURN(source_tick_op.has_source_tick_conf());
auto* source_tick_op_conf = source_tick_op.mutable_source_tick_conf();
CHECK_OR_RETURN(!source_tick_op_conf->has_wait_in());
source_tick_op_conf->set_wait_in(GenLogicalBlobName(wait_and_send_ids_op_conf.name(), "out"));
}
JUST(job_builder->MutOpOnlyOnce(source_tick_op));

// connect wait_and_send_ids to tick op which was connected to the src tick op
OperatorConf tick_op_conf;
bool find_src_tick_consumer_tick = false;
JUST(job_builder->ForEachOperator([&](const Operator& op) -> Maybe<void> {
// skip if the op is not a tick op
if (!op.op_conf().has_tick_conf()) { return Maybe<void>::Ok(); }
for (const auto& ibn : op.input_bns()) {
const auto& input_lbi = op.BnInOp2Lbi(ibn);
if (input_lbi.op_name() == src_op_name) {
CHECK_OR_RETURN(!find_src_tick_consumer_tick);
tick_op_conf.CopyFrom(op.op_conf());
find_src_tick_consumer_tick = true;
}
}
return Maybe<void>::Ok();
}));
CHECK_OR_RETURN(find_src_tick_consumer_tick);
CHECK_OR_RETURN(tick_op_conf.has_tick_conf());
CHECK_EQ_OR_RETURN(tick_op_conf.tick_conf().tick_size(), 1);
tick_op_conf.mutable_tick_conf()->clear_tick();
tick_op_conf.mutable_tick_conf()->add_tick(
GenLogicalBlobName(wait_and_send_ids_op_conf.name(), "out"));
JUST(job_builder->MutOpOnlyOnce(tick_op_conf));

// erase the src tick op
job_builder->DelOps({src_op_name});
return Maybe<void>::Ok();
}

Expand Down
2 changes: 0 additions & 2 deletions oneflow/core/kernel/input_kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,6 @@ class InputKernel final : public KernelIf<device_type> {
~InputKernel() = default;

private:
void Forward(const KernelCtx& ctx,
std::function<Blob*(const std::string&)> BnInOp2Blob) const override {}
void ForwardDataContent(const KernelCtx& ctx,
std::function<Blob*(const std::string&)> BnInOp2Blob) const override {
if (CHECK_JUST(*Global<Maybe<bool>, MultiClient>::Get())) {
Expand Down
3 changes: 1 addition & 2 deletions oneflow/core/operator/op_conf.proto
Original file line number Diff line number Diff line change
Expand Up @@ -220,8 +220,7 @@ message DstSubsetTickOpConf {
}

message SourceTickOpConf {
optional string wait_in = 1;
required string out = 2;
required string out = 1;
}

message SinkTickOpConf {
Expand Down
4 changes: 1 addition & 3 deletions oneflow/core/operator/source_tick_op.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ namespace oneflow {
Maybe<void> SourceTickOp::InitFromOpConf() {
CHECK(op_conf().has_source_tick_conf());
CHECK(op_conf().ctrl_in_op_name().empty());
if (op_conf().source_tick_conf().has_wait_in()) { EnrollInputBn("wait_in", false); }
EnrollOutputBn("out", false);
return Maybe<void>::Ok();
}
Expand All @@ -46,8 +45,7 @@ Maybe<void> SourceTickOp::InferOutBlobDescs(
}

Maybe<void> SourceTickOp::GetSbpSignatures(cfg::SbpSignatureList* sbp_sig_list) const {
auto* sbp_signature = sbp_sig_list->mutable_sbp_signature()->Add();
SbpSignatureBuilder().Broadcast(input_bns()).Broadcast(output_bns()).Build(sbp_signature);
SbpSignatureBuilder().Broadcast(output_bns()).Build(sbp_sig_list->mutable_sbp_signature()->Add());
return Maybe<void>::Ok();
}

Expand Down
15 changes: 6 additions & 9 deletions python/oneflow/test/graph/test_graph_relu.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,25 +13,21 @@
See the License for the specific language governing permissions and
limitations under the License.
"""

import os
import unittest

import numpy as np

import oneflow as flow
import oneflow.framework.graph_build_util as graph_build_util
import oneflow.unittest


@unittest.skip(" nn.Graph cannnot run right now ")
class TestReluGraph(flow.unittest.TestCase):
class TestReluGraph(oneflow.unittest.TestCase):
def test_relu_graph(test_case):
data = np.array([2.0, 1.0, 0.0, -1.0, -2.0])
x = flow.tensor(data, dtype=flow.float32)

MyRelu = flow.nn.ReLU()
y_eager = MyRelu(x)
print("eager out :", y_eager)
# print("eager out :", y_eager)

class ReluGraph(flow.nn.Graph):
def __init__(self):
Expand All @@ -42,8 +38,9 @@ def build(self, x):
return self.cc_relu(x)

relu_g = ReluGraph()
y_lazy = relu_g(x)[0]
print("lazy out :", y_lazy)
y_lazy = relu_g(x)
# print(f"type of lazy y: {type(y_lazy)}")
# print(f"lazy y shape: {y_lazy.shape}, data: {y_lazy}")
test_case.assertTrue(np.array_equal(y_eager.numpy(), y_lazy.numpy()))


Expand Down