LogicalSliceAssign support full slice sbp #8344

Merged: 46 commits from feat-logical_slice_assign_support_full_slice into master on Jun 9, 2022
Commits (46)
92e2a18 feat(SliceOp): slice ops support 2d sbp (wyg1997, May 25, 2022)
deaa871 Merge branch 'master' into feat-slice_ops_support_2d_sbp (wyg1997, May 25, 2022)
f0ee0ed fix(SliceOp): fix [B, P] 2d sbp bug (wyg1997, May 25, 2022)
1572353 refine error message (wyg1997, May 25, 2022)
1e11873 Merge remote-tracking branch 'origin/feat-slice_ops_support_2d_sbp' i… (wyg1997, May 25, 2022)
36c2093 fix bug in parallel_num == 1 (wyg1997, May 25, 2022)
c3f617b add comment (wyg1997, May 25, 2022)
64cb1c9 add warning and format (wyg1997, May 25, 2022)
cca0ad9 add NOLINT for boxing check (wyg1997, May 25, 2022)
61ebf3e Merge branch 'master' into feat-slice_ops_support_2d_sbp (wyg1997, May 25, 2022)
5a56356 Merge branch 'master' into feat-slice_ops_support_2d_sbp (mergify[bot], May 25, 2022)
38eb61c Merge branch 'master' into feat-slice_ops_support_2d_sbp (wyg1997, May 26, 2022)
d724598 Merge branch 'master' into feat-slice_ops_support_2d_sbp (hjchen2, May 26, 2022)
0afbab7 Merge branch 'master' into feat-slice_ops_support_2d_sbp (wyg1997, May 26, 2022)
f138436 feat(LogicalSliceOps): support all nd_sbp (wyg1997, May 26, 2022)
266ebb3 feat(LogicalSlice): support nd_sbp (wyg1997, May 27, 2022)
e8ca7d0 Merge remote-tracking branch 'origin/master' into feat-slice_ops_supp… (wyg1997, May 27, 2022)
797a6ca add error message (wyg1997, May 27, 2022)
4de5066 fix(AutoTest): fix auto_test bug in module.parameter pass (wyg1997, May 27, 2022)
3787d87 auto format by CI (oneflow-ci-bot, May 27, 2022)
44e7230 fix(LogicalSliceAssign): skip test when 1n1d (wyg1997, May 27, 2022)
cbda49b Merge branch 'feat-slice_ops_support_2d_sbp' of github.com:Oneflow-In… (wyg1997, May 27, 2022)
4f9f1f3 Merge remote-tracking branch 'origin/feat-slice_ops_support_2d_sbp' i… (wyg1997, May 30, 2022)
63abe10 Merge branch 'master' into feat-logical_slice_ops_support_all_sbp (wyg1997, May 30, 2022)
8284bb3 fix SliceParams memset error (wyg1997, May 30, 2022)
e20e423 Merge branch 'master' into feat-logical_slice_ops_support_all_sbp (mergify[bot], May 30, 2022)
97f356f remove memset (wyg1997, May 30, 2022)
9b2c4ff add CHECK_JUST (wyg1997, May 30, 2022)
9f52bcb fix(*): make sure split_axis >= 0 or equal to SPLIT_AXIS_FOR_NON_SPLIT (wyg1997, May 30, 2022)
8d0b8d0 Merge branch 'master' into feat-logical_slice_ops_support_all_sbp (mergify[bot], May 30, 2022)
85a2f60 Merge branch 'master' into feat-logical_slice_ops_support_all_sbp (wyg1997, May 30, 2022)
d343539 remove memset (wyg1997, May 31, 2022)
f5a6b68 fix spilit_info.axis bug (wyg1997, May 31, 2022)
75e0e1b feat(LogicalSliceOps): support grad (wyg1997, May 31, 2022)
750b4af add logical_slice gradient_funcs (wyg1997, May 31, 2022)
1b964e8 Merge remote-tracking branch 'origin/master' into feat-logical_slice_… (wyg1997, Jun 1, 2022)
2a4cd88 feat(LogicalSliceAssign): LogicalSliceAssign support full slice sbp (wyg1997, Jun 1, 2022)
677059d Merge remote-tracking branch 'origin/master' into feat-logical_slice_… (wyg1997, Jun 7, 2022)
4d97ac1 Merge branch 'master' into feat-logical_slice_assign_support_full_slice (wyg1997, Jun 7, 2022)
4115a75 Merge branch 'master' into feat-logical_slice_assign_support_full_slice (wyg1997, Jun 8, 2022)
4e339ee Merge branch 'master' into feat-logical_slice_assign_support_full_slice (wyg1997, Jun 9, 2022)
aa0a5c5 auto format by CI (oneflow-ci-bot, Jun 9, 2022)
a209d19 test(LogicalSlice): fix logical_slice dims (wyg1997, Jun 9, 2022)
4570728 Merge branch 'master' into feat-logical_slice_assign_support_full_slice (wyg1997, Jun 9, 2022)
1188cc0 Merge branch 'master' into feat-logical_slice_assign_support_full_slice (wyg1997, Jun 9, 2022)
1d69560 Merge branch 'master' into feat-logical_slice_assign_support_full_slice (mergify[bot], Jun 9, 2022)
84 changes: 51 additions & 33 deletions oneflow/user/kernels/slice_kernel.cpp
@@ -329,30 +329,6 @@ DEFINE_STATIC_SWITCH_FUNC(
));
#undef MAKE_WRITE_SLICE_SWITCH_ENTRY

std::shared_ptr<user_op::OpKernelCache> CreateSliceCache(user_op::KernelCacheContext* ctx,
const std::string& large_tensor_name) {
Review comment on lines -332 to -333 (Contributor Author): This helper can no longer be shared; the caches for LogicalSlice and LogicalSliceAssign are now derived separately.

SliceContext slice_ctx;
if (ctx->parallel_ctx().parallel_num() == 1) {
// split_axis == SPLIT_AXIS_FOR_NON_SPLIT means the sbp attribute is not 'split'
CHECK_JUST(slice_ctx.PushSplitInfo(SPLIT_AXIS_FOR_NON_SPLIT, 0, 0, 0));
} else {
const NdSbp& in_nd_sbp = ctx->NdSbp4ArgNameAndIndex(large_tensor_name, 0);
const Shape& parallel_hierarchy = *ctx->parallel_desc().hierarchy();
const Shape& logical_shape =
ctx->LogicalTensorDesc4ArgNameAndIndex(large_tensor_name, 0)->shape();
const int64_t parallel_id = ctx->parallel_ctx().parallel_id();
const TensorSliceView& slice_view =
GetTensorSliceView4ParallelId(parallel_hierarchy, in_nd_sbp, logical_shape, parallel_id);
for (int i = 0; i < logical_shape.NumAxes(); ++i) {
const Range& range = slice_view.At(i);
if (range.begin() != 0 || range.end() != logical_shape.At(i)) {
CHECK_JUST(slice_ctx.PushSplitInfo(i, range.begin(), range.end(), logical_shape.At(i)));
}
}
}
return std::make_shared<OpKernelCacheWrapper<SliceContext>>(slice_ctx);
}

template<typename T>
class LogicalSliceKernel final : public user_op::OpKernel {
public:
@@ -361,7 +337,25 @@ class LogicalSliceKernel final : public user_op::OpKernel {

std::shared_ptr<user_op::OpKernelCache> InitOpKernelCache(
user_op::KernelCacheContext* ctx) const override {
return CreateSliceCache(ctx, "x");
SliceContext slice_ctx;
if (ctx->parallel_ctx().parallel_num() == 1) {
// split_axis == SPLIT_AXIS_FOR_NON_SPLIT means the sbp attribute is not 'split'
CHECK_JUST(slice_ctx.PushSplitInfo(SPLIT_AXIS_FOR_NON_SPLIT, 0, 0, 0));
} else {
const NdSbp& in_nd_sbp = ctx->NdSbp4ArgNameAndIndex("x", 0);
const Shape& parallel_hierarchy = *ctx->parallel_desc().hierarchy();
const Shape& logical_shape = ctx->LogicalTensorDesc4ArgNameAndIndex("x", 0)->shape();
const int64_t parallel_id = ctx->parallel_ctx().parallel_id();
const TensorSliceView& slice_view =
GetTensorSliceView4ParallelId(parallel_hierarchy, in_nd_sbp, logical_shape, parallel_id);
for (int i = 0; i < logical_shape.NumAxes(); ++i) {
const Range& range = slice_view.At(i);
if (range.begin() != 0 || range.end() != logical_shape.At(i)) {
CHECK_JUST(slice_ctx.PushSplitInfo(i, range.begin(), range.end(), logical_shape.At(i)));
}
}
}
return std::make_shared<OpKernelCacheWrapper<SliceContext>>(slice_ctx);
}

private:
@@ -388,15 +382,39 @@ class LogicalSliceAssignKernel final : public user_op::OpKernel {

std::shared_ptr<user_op::OpKernelCache> InitOpKernelCache(
user_op::KernelCacheContext* ctx) const override {
if (ctx->parallel_ctx().parallel_num() > 1) {
const NdSbp& value_nd_sbp = ctx->NdSbp4ArgNameAndIndex("value", 0);
CHECK(std::all_of(value_nd_sbp.sbp_parallel().begin(), value_nd_sbp.sbp_parallel().end(),
[](const SbpParallel& sbp) {
return sbp.has_partial_sum_parallel() || sbp.has_broadcast_parallel();
}))
<< "value's sbp must be broadcast or partial_sum";
SliceContext slice_ctx;
if (ctx->parallel_ctx().parallel_num() == 1) {
// split_axis == SPLIT_AXIS_FOR_NON_SPLIT means the sbp attribute is not 'split'
CHECK_JUST(slice_ctx.PushSplitInfo(SPLIT_AXIS_FOR_NON_SPLIT, 0, 0, 0));
} else {
const Shape& parallel_hierarchy = *ctx->parallel_desc().hierarchy();
NdSbp ref_nd_sbp = ctx->NdSbp4ArgNameAndIndex("ref", 0);
{
const NdSbp value_nd_sbp = ctx->NdSbp4ArgNameAndIndex("value", 0);
// If ref and value both split in the same axis(full slice),
// we can consider the physical tensor is broadcast in this axis.
for (int i = 0; i < parallel_hierarchy.NumAxes(); ++i) {
const SbpParallel& ref_sbp = ref_nd_sbp.sbp_parallel(i);
const SbpParallel& value_sbp = value_nd_sbp.sbp_parallel(i);
if (ref_sbp.has_split_parallel() && value_sbp.has_split_parallel()) {
CHECK_EQ(ref_sbp.split_parallel().axis(), value_sbp.split_parallel().axis());
ref_nd_sbp.mutable_sbp_parallel(i)->clear_split_parallel();
ref_nd_sbp.mutable_sbp_parallel(i)->mutable_broadcast_parallel();
}
}
Review comment on lines +393 to +404 (Contributor Author): This is the core of the change: when ref and value are both split on the same axis (a full slice), the SbpParallel on that axis is handled as Broadcast (see the Python sketch after this file's diff).

}
const Shape& logical_shape = ctx->LogicalTensorDesc4ArgNameAndIndex("ref", 0)->shape();
const int64_t parallel_id = ctx->parallel_ctx().parallel_id();
const TensorSliceView& slice_view =
GetTensorSliceView4ParallelId(parallel_hierarchy, ref_nd_sbp, logical_shape, parallel_id);
for (int i = 0; i < logical_shape.NumAxes(); ++i) {
const Range& range = slice_view.At(i);
if (range.begin() != 0 || range.end() != logical_shape.At(i)) {
CHECK_JUST(slice_ctx.PushSplitInfo(i, range.begin(), range.end(), logical_shape.At(i)));
}
}
}
return CreateSliceCache(ctx, "ref");
return std::make_shared<OpKernelCacheWrapper<SliceContext>>(slice_ctx);
}

private:
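A minimal sketch (plain Python, not the actual OneFlow C++ API) of the normalization described in the review comment above: when both "ref" and "value" are split along the same hierarchy axis and the slice fully covers that axis, the physical tensors on every rank already line up, so that axis can be treated as broadcast when deriving the per-rank slice view. The helper name and string-based nd_sbp notation below are illustrative assumptions.

# Hypothetical helper; nd_sbp values are written as strings such as "S(0)", "B", "P".
def normalize_ref_nd_sbp(ref_nd_sbp, value_nd_sbp):
    normalized = []
    for ref_sbp, value_sbp in zip(ref_nd_sbp, value_nd_sbp):
        if ref_sbp.startswith("S") and value_sbp.startswith("S"):
            # Both sides are split on this hierarchy axis (the full-slice case),
            # so each rank already holds the matching piece: treat it as broadcast.
            assert ref_sbp == value_sbp, "split axes must match"
            normalized.append("B")
        else:
            normalized.append(ref_sbp)
    return normalized

# Example: ref sbp (S(0), B) and value sbp (S(0), B) -> cache derived as if (B, B).
print(normalize_ref_nd_sbp(["S(0)", "B"], ["S(0)", "B"]))  # ['B', 'B']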
13 changes: 11 additions & 2 deletions oneflow/user/ops/slice_op.cpp
@@ -154,13 +154,21 @@ bool IsFullSlice(int64_t start, int64_t stop, int64_t step, int64_t size) {
}

/*static*/ Maybe<void> LogicalSliceAssignOp::GetSbp(user_op::SbpContext* ctx) {
const user_op::TensorDesc& ref_desc = ctx->LogicalTensorDesc4InputArgNameAndIndex("ref", 0);
FOR_RANGE(int64_t, axis, 0, ref_desc.shape().NumAxes()) {
const Shape& x_shape = ctx->LogicalTensorDesc4InputArgNameAndIndex("ref", 0).shape();
const int64_t ndim = x_shape.NumAxes();
const auto& start_vec = ctx->Attr<std::vector<int64_t>>("start");
const auto& stop_vec = ctx->Attr<std::vector<int64_t>>("stop");
const auto& step_vec = ctx->Attr<std::vector<int64_t>>("step");
FOR_RANGE(int64_t, axis, 0, ndim) {
ctx->NewBuilder()
.Split(user_op::OpArg("ref", 0), axis)
.Broadcast(user_op::OpArg("value", 0))
.Split(user_op::OpArg("y", 0), axis)
.Build();
// FullSlice support S+S->S
if (IsFullSlice(start_vec[axis], stop_vec[axis], step_vec[axis], x_shape.At(axis))) {
ctx->NewBuilder().Split(ctx->inputs(), axis).Split(ctx->outputs(), axis).Build();
}
}
ctx->NewBuilder()
.PartialSum(user_op::OpArg("ref", 0))
@@ -260,6 +268,7 @@ bool IsFullSlice(int64_t start, int64_t stop, int64_t step, int64_t size) {
ctx->NewBuilder().PartialSum(ctx->inputs()).PartialSum(ctx->outputs()).Build();
return Maybe<void>::Ok();
}

/*static*/ Maybe<void> SliceUpdateOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {
const auto& x_desc = ctx->InputTensorDesc("x", 0);
const int64_t ndim = x_desc.shape().NumAxes();
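The extra S(axis) + S(axis) -> S(axis) signature in GetSbp above is registered only for axes where the slice is a full slice. Below is a short sketch of the assumed semantics of that check; the real IsFullSlice helper lives in oneflow/user/ops/slice_op.cpp and this is an illustration, not its exact implementation.

# Assumed semantics of IsFullSlice, mirrored in plain Python for illustration.
def is_full_slice(start, stop, step, size):
    # A slice is "full" when it selects every element of the axis, in order, with step 1.
    return step == 1 and start == 0 and stop == size

# Only full-slice axes get the extra Split + Split -> Split signature:
# for x[:, :8] on an (8, 16) tensor, axis 0 is a full slice, axis 1 is not.
print(is_full_slice(0, 8, 1, 8))   # True  (axis 0 of the example)
print(is_full_slice(0, 8, 1, 16))  # False (axis 1 of the example)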
12 changes: 6 additions & 6 deletions python/oneflow/test/modules/test_consistent_slice.py
@@ -117,18 +117,18 @@ def _test_logical_slice_with_bool(test_case, placement, sbp):


def _test_logical_slice_with_grad(test_case, placement, sbp):
x = random_tensor(2, 4, 4, requires_grad=True).oneflow
x = random_tensor(2, 8, 16, requires_grad=True).oneflow
x_numpy = x.detach().cpu().numpy()

class LogicalSliceWithGrad(flow.nn.Module):
def __init__(self):
super().__init__()
self.input_grad = flow.nn.Parameter(flow.zeros(4, 4))
self.input_grad = flow.nn.Parameter(flow.zeros(8, 16))

def forward(self, input):
x = input + self.input_grad
x = x.to_global(placement, sbp)
return x[:, :2]
return x[:, :8]

logical_slice_with_grad = LogicalSliceWithGrad().to_global(
placement, [flow.sbp.broadcast,] * len(sbp)
@@ -154,10 +154,10 @@ def build(self, x):
y = graph(input)

# output
test_case.assertTrue(np.array_equal(y.numpy(), x_numpy[:, :2]))
test_case.assertTrue(np.array_equal(y.numpy(), x_numpy[:, :8]))
# input_grad
x_grad_np = np.zeros((4, 4))
x_grad_np[:, :2] = 1
x_grad_np = np.zeros((8, 16))
x_grad_np[:, :8] = 1
test_case.assertTrue(
np.array_equal(-graph.module.input_grad.origin.numpy(), x_grad_np)
)
75 changes: 48 additions & 27 deletions python/oneflow/test/modules/test_consistent_slice_assign.py
@@ -23,39 +23,49 @@


def _test_logical_slice_assign(test_case, placement, sbp):
input = random_tensor(2, 4, 4, requires_grad=True).oneflow
x_numpy = input.detach().cpu().numpy()

input = random_tensor(2, 8, 16, requires_grad=True).oneflow
value = random_tensor(2, 8, 8, requires_grad=True).oneflow
x = (input + 0).to_global(
placement=placement, sbp=sbp
) # add 0 to change to non-leaf tensor
x[:, :2] = 3
y = value.to_global(placement, sbp=sbp)
x[:, :8] = y

ref_np = input.detach().cpu().numpy()
value_np = value.detach().cpu().numpy()

# forward
x_numpy[:, :2] = 3
ref_np[:, :8] = value_np
test_case.assertTrue(x.sbp == sbp)
test_case.assertTrue(np.array_equal(x.numpy(), x_numpy))
test_case.assertTrue(np.array_equal(x.numpy(), ref_np))

# backward
x.sum().backward()
input_grad_np = np.ones((4, 4))
input_grad_np[:, :2] = 0
test_case.assertTrue(np.array_equal(input.grad.numpy(), input_grad_np))
# ref grad
ref_grad_np = np.ones((8, 16))
ref_grad_np[:, :8] = 0
test_case.assertTrue(np.array_equal(input.grad.numpy(), ref_grad_np))
# value grad
value_grad_np = np.ones((8, 8))
test_case.assertTrue(np.array_equal(value.grad.numpy(), value_grad_np))


def _test_graph_logical_slice_assign(test_case, placement, sbp):
x = random_tensor(2, 4, 4, requires_grad=True).oneflow
x_numpy = x.detach().cpu().numpy()
ref = random_tensor(2, 8, 16, requires_grad=True).oneflow
value = random_tensor(2, 8, 8, requires_grad=True).oneflow

class LogicalSliceAssignWithGrad(flow.nn.Module):
def __init__(self):
super().__init__()
self.input_grad = flow.nn.Parameter(flow.zeros(4, 4))
self.ref_grad = flow.nn.Parameter(flow.zeros(8, 16))
self.value_grad = flow.nn.Parameter(flow.zeros(8, 8))

def forward(self, input):
x = input + self.input_grad
def forward(self, ref, value):
x = ref + self.ref_grad
y = value + self.value_grad
x = x.to_global(placement, sbp)
x[:, :2] = 3
y = y.to_global(placement, sbp)
x[:, :8] = y
return x

logical_slice_assign_with_grad = LogicalSliceAssignWithGrad().to_global(
@@ -72,27 +82,38 @@ def __init__(self):
self.module = logical_slice_assign_with_grad
self.add_optimizer(of_sgd)

def build(self, x):
out = self.module(x)
def build(self, x, y):
out = self.module(x, y)
z = out.sum()
z.backward()
return out

graph = LogicalSliceAssignTrainGraph()

input = x.to_global(placement=placement, sbp=sbp)
y = graph(input)
x = ref.to_global(placement=placement, sbp=sbp)
y = value.to_global(placement=placement, sbp=sbp)
z = graph(x, y)

test_case.assertTrue(z.sbp == sbp)

ref_np = ref.detach().cpu().numpy()
value_np = value.detach().cpu().numpy()

test_case.assertTrue(y.sbp == sbp)
# forward
ref_np[:, :8] = value_np
test_case.assertTrue(np.array_equal(z.numpy(), ref_np))

# output
x_numpy[:, :2] = 3
test_case.assertTrue(np.array_equal(y.numpy(), x_numpy))
# input_grad
x_grad_np = np.ones((4, 4))
x_grad_np[:, :2] = 0
# backward
# ref grad
ref_grad = np.ones((8, 16))
ref_grad[:, :8] = 0
test_case.assertTrue(
np.array_equal(-graph.module.ref_grad.origin.numpy(), ref_grad)
)
# value grad
value_grad = np.ones((8, 8))
test_case.assertTrue(
np.array_equal(-graph.module.input_grad.origin.numpy(), x_grad_np)
np.array_equal(-graph.module.value_grad.origin.numpy(), value_grad)
)


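For context, a hypothetical end-to-end usage sketch of what the new signature enables, analogous to the tests above: both the target tensor and the assigned value are split along dim 0, the column slice is full along that dim, so the assignment can keep the split layout instead of boxing value to broadcast. The placement construction and shapes are illustrative and may differ across OneFlow versions.

import oneflow as flow

# Assumed 2-rank GPU placement; adjust to your environment.
placement = flow.placement("cuda", ranks=[0, 1])
ref = flow.randn(8, 16).to_global(placement, sbp=flow.sbp.split(0))
value = flow.randn(8, 8).to_global(placement, sbp=flow.sbp.split(0))
# Axis 0 of ref[:, :8] is a full slice, so ref and value can both stay split
# on dim 0 during the assignment.
ref[:, :8] = value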