Comm broadcast #6213

Merged · 18 commits · Sep 9, 2021
@@ -242,25 +242,17 @@ Maybe<one::UserOpExpr> EagerNcclBroadcast(Symbol<ParallelDesc> parallel_desc, in
.Build();
}

Maybe<one::UserOpExpr> FindOrCreatEagerNcclBroadcastOpExpr(Symbol<ParallelDesc> parallel_desc) {
static thread_local HashMap<Symbol<ParallelDesc>, std::shared_ptr<one::UserOpExpr>>
parallel_desc2eager_nccl_broadcast;
auto iter = parallel_desc2eager_nccl_broadcast.find(parallel_desc);
if (iter == parallel_desc2eager_nccl_broadcast.end()) {
int64_t root = JUST(parallel_desc->MachineId4ParallelId(0));
std::shared_ptr<UserOpExpr> op_expr = JUST(EagerNcclBroadcast(parallel_desc, root));
iter = parallel_desc2eager_nccl_broadcast.emplace(parallel_desc, op_expr).first;
}
return iter->second;
}
auto* CachedEagerNcclBroadcastOpExpr = DECORATE(&EagerNcclBroadcast, ThreadLocal);

} // namespace

Maybe<Tensor> Broadcast(const std::shared_ptr<Tensor>& tensor, Symbol<ParallelDesc> parallel_desc,
bool inplace) {
Maybe<Tensor> Broadcast(const std::shared_ptr<Tensor>& tensor, int64_t src_rank,
Symbol<ParallelDesc> parallel_desc, bool inplace) {
CHECK_OR_RETURN(parallel_desc->containing_current_rank());
if (parallel_desc->parallel_num() == 1 /* no broadcast */) { return tensor; }
std::shared_ptr<UserOpExpr> op_expr = JUST(FindOrCreatEagerNcclBroadcastOpExpr(parallel_desc));
if (JUST(parallel_desc->MachineId4ParallelId(0)) == GlobalProcessCtx::Rank() || inplace) {
int64_t root = JUST(parallel_desc->MachineId4ParallelId(src_rank));
std::shared_ptr<UserOpExpr> op_expr = JUST(CachedEagerNcclBroadcastOpExpr(parallel_desc, root));
if (root == GlobalProcessCtx::Rank() || inplace) {
TensorTuple outputs{tensor};
JUST(OpInterpUtil::Dispatch(*op_expr, {tensor}, &outputs,
one::OpExprInterpContext(AttrMap{}, parallel_desc)));
@@ -280,7 +272,7 @@ Maybe<Tensor> GetSyncedTensorIfBroadcast(const std::shared_ptr<Tensor>& tensor,
JUST(GetTensorDevice4CurrentProcessCtx(parallel_desc, &parallel_id));
if (!parallel_id.has_value()) { return tensor; }
const auto& broadcast_parallel_desc = JUST(GetBroadcastSubParallelDesc(parallel_desc, nd_sbp));
return Broadcast(tensor, broadcast_parallel_desc, false);
return Broadcast(tensor, /* root */ 0, broadcast_parallel_desc, false);
}

Maybe<Shape> CalcPhysicalShape(Symbol<ConsistentTensorMeta> consistent_tensor_meta) {
@@ -26,8 +26,8 @@ namespace one {
class Tensor;

Maybe<void> RunEmptyOp(TensorTuple* outputs);
Maybe<Tensor> Broadcast(const std::shared_ptr<Tensor>& tensor, Symbol<ParallelDesc> parallel_desc,
bool inplace);
Maybe<Tensor> Broadcast(const std::shared_ptr<Tensor>& tensor, int64_t src_rank,
Symbol<ParallelDesc> parallel_desc, bool inplace);

} // namespace one
} // namespace oneflow
2 changes: 1 addition & 1 deletion oneflow/core/functional/functional_api.yaml
@@ -1042,7 +1042,7 @@
bind_python: True

- name: "broadcast"
signature: "Tensor (Tensor x, *, Bool inplace=True) => Broadcast"
signature: "Tensor (Tensor x, *, Int64 src_rank=0, Bool inplace=True) => Broadcast"
bind_python: True

- name: "local_all_reduce"
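As in Python signatures, the `*` in this DSL entry makes `src_rank` and `inplace` keyword-only, so existing callers of `flow._C.broadcast` keep working and the source rank defaults to 0. A minimal sketch of how the regenerated binding would be called (the two-rank environment is assumed, not part of this change):

```python
# Hedged sketch of calling the regenerated functional binding.
# A two-rank launch environment is assumed here.
import oneflow as flow

t = flow.tensor([[1, 2], [3, 4]], device="cuda")
flow._C.broadcast(t, inplace=True)              # old call sites still work; src_rank defaults to 0
flow._C.broadcast(t, src_rank=1, inplace=True)  # new: take the data held by source rank 1
```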
5 changes: 3 additions & 2 deletions oneflow/core/functional/impl/comm_functor.cpp
@@ -119,13 +119,14 @@ auto* CachedEagerNcclS2SOpExpr = DECORATE(&EagerNcclS2S, ThreadLocal)
class BroadcastFunctor {
public:
BroadcastFunctor() = default;
Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x, bool inplace) const {
Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x, int64_t src_rank,
bool inplace) const {
const auto& rank_group = JUST(RankGroupScope::CurrentRankGroup());
std::string device_type_str = JUST(x->device())->type();
CHECK_OR_RETURN(device_type_str == "cuda" || device_type_str == "cpu");
DeviceType device_type = device_type_str == "cuda" ? DeviceType::kGPU : DeviceType::kCPU;
const auto& parallel_desc = JUST(RankGroup::GetDefaultParallelDesc(device_type, rank_group));
return one::Broadcast(x, parallel_desc, inplace);
return one::Broadcast(x, src_rank, parallel_desc, inplace);
}
};

1 change: 1 addition & 0 deletions python/oneflow/comm/__init__.py
@@ -15,4 +15,5 @@
"""
from oneflow.comm.comm_ops import all_reduce
from oneflow.comm.comm_ops import all_gather
from oneflow.comm.comm_ops import broadcast
from oneflow._C import send, recv
32 changes: 32 additions & 0 deletions python/oneflow/comm/comm_ops.py
@@ -103,3 +103,35 @@ def all_gather(tensor_list, tensor):
    )
    for i in range(tensor.shape[0]):
        tensor_list[i] = tensor[i].to_local()


def broadcast(tensor, src):
    """
    Broadcasts the tensor to the whole group.

    ``tensor`` must have the same number of elements in all processes
    participating in the collective.

    Args:
        tensor (Tensor): Data to be sent if ``src`` is the rank of the current
            process, and tensor to be used to save received data otherwise.
        src (int): Source rank.

    .. code-block:: python

        >>> # We have 1 process group, 2 ranks.
        >>> import oneflow as flow
        >>> tensor = flow.tensor([[1, 2], [3, 4]], device="cuda") + flow.env.get_local_rank()
        >>> tensor # doctest: +ONLY_CHECK_RANK_0
        tensor([[1, 2],
                [3, 4]], device='cuda:0', dtype=oneflow.int64)
        >>> tensor # doctest: +ONLY_CHECK_RANK_1
        tensor([[2, 3],
                [4, 5]], device='cuda:1', dtype=oneflow.int64)
        >>> flow.comm.broadcast(tensor, 0)
        >>> tensor # doctest: +ONLY_CHECK_RANK_0
        tensor([[1, 2],
                [3, 4]], device='cuda:0', dtype=oneflow.int64)
        >>> tensor # doctest: +ONLY_CHECK_RANK_1
        tensor([[1, 2],
                [3, 4]], device='cuda:1', dtype=oneflow.int64)

    """
    assert isinstance(tensor, flow._oneflow_internal.Tensor)
    assert isinstance(src, int)
    flow._C.broadcast(tensor, src_rank=src, inplace=True)
@@ -54,6 +54,22 @@ def test_all_gather_1n2d(test_case):
)


class TestBroadCast(flow.unittest.TestCase):
    @flow.unittest.skip_unless_1n2d()
    def test_broadcast_1n2d(test_case):
        if flow.env.get_rank() == 0:
            np_arr = np.array([[1, 2], [3, 4]])
        elif flow.env.get_rank() == 1:
            np_arr = np.array([[4, 5], [6, 7]])
        tensor = flow.tensor(np_arr, device="cuda", dtype=flow.int32)
        flow.comm.broadcast(tensor, 1)
        test_case.assertTrue(np.allclose(tensor.numpy(), np.array([[4, 5], [6, 7]])))

        tensor = flow.tensor(np_arr, device="cuda", dtype=flow.int32)
        flow.comm.broadcast(tensor, 0)
        test_case.assertTrue(np.allclose(tensor.numpy(), np.array([[1, 2], [3, 4]])))
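
The test above only runs under the 1n2d unittest harness. To exercise the new collective by hand, a standalone script along the same lines should behave identically; the launcher invocation in the comment is an assumption, not something this PR adds:

```python
# broadcast_demo.py -- hedged sketch; assumes a two-process launch such as
#   python3 -m oneflow.distributed.launch --nproc_per_node 2 broadcast_demo.py
import numpy as np
import oneflow as flow

rank = flow.env.get_rank()
# Each rank starts with different data.
np_arr = np.array([[1, 2], [3, 4]]) if rank == 0 else np.array([[4, 5], [6, 7]])
tensor = flow.tensor(np_arr, device="cuda", dtype=flow.int32)

flow.comm.broadcast(tensor, 1)  # in place: every rank now holds rank 1's data
assert np.allclose(tensor.numpy(), np.array([[4, 5], [6, 7]]))
```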


@flow.unittest.skip_unless_1n2d()
class TestDocs(flow.unittest.TestCase):
    def test_docs(test_case):