Migrate cast_to_static_shape to user op #4095

Merged · 8 commits · Jan 11, 2021
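This PR migrates `cast_to_static_shape` from a system op to a user op: the `CastToStaticShapeOpConf` proto message, the `CastToStaticShapeOp` operator class, and the dedicated system kernel are deleted; the autograd pass now builds the op via `user_op::UserOpConfWrapperBuilder`; the prune pass matches it by its user-op type name; a `flow.cast_to_static_shape` Python wrapper and tests are added; and the op is added to the AMP clear list.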
Changes from all commits
23 changes: 19 additions & 4 deletions oneflow/core/job_rewriter/auto_mixed_precision_lists.cpp
@@ -41,10 +41,25 @@ const AMPList& AutoMixedPrecisionLists::GrayList() {
 
 const AMPList& AutoMixedPrecisionLists::ClearList() {
   // TODO(niuchong): identity, tuple_identity, keep_header_only?
-  static AMPList clear_list = {
-      "gather", "max_pool_1d", "max_pool_2d", "max_pool_3d", "reshape", "relu",
-      "transpose", "random_mask_like", "concat", "pad", "same_padding", "tril",
-      "slice", "fused_scale_tril", "identity", "flatten", "squeeze", "expand_dims"};
+  static AMPList clear_list = {"gather",
+                               "max_pool_1d",
+                               "max_pool_2d",
+                               "max_pool_3d",
+                               "reshape",
+                               "relu",
+                               "transpose",
+                               "random_mask_like",
+                               "concat",
+                               "pad",
+                               "same_padding",
+                               "tril",
+                               "slice",
+                               "fused_scale_tril",
+                               "identity",
+                               "flatten",
+                               "squeeze",
+                               "expand_dims",
+                               "cast_to_static_shape"};
 
   return clear_list;
 }
19 changes: 9 additions & 10 deletions oneflow/core/job_rewriter/autograd.cpp
@@ -964,17 +964,16 @@ void AddDiffStaticShapeCast(const OpGraph& op_graph, JobBuilder* job_builder,
     LogicalBlobId& diff_lbi = pair.second;
     const OpNode* model_op_node = op_graph.OpNode4OpName(lbi.op_name());
     int64_t scope_symbol_id = model_op_node->op().op_conf().scope_symbol_id();
-    OperatorConf cast_to_static_shape_op_conf{};
-    cast_to_static_shape_op_conf.set_name("System-AutoGrad-StaticShapeCast-" + NewUniqueId());
-    CastToStaticShapeOpConf* cast_to_static_shape_conf =
-        cast_to_static_shape_op_conf.mutable_cast_to_static_shape_conf();
-    cast_to_static_shape_conf->set_in(GenLogicalBlobName(diff_lbi));
-    cast_to_static_shape_conf->set_out("out");
-    cast_to_static_shape_op_conf.set_scope_symbol_id(scope_symbol_id);
+    const auto cast_to_static_shape_op =
+        user_op::UserOpConfWrapperBuilder("System-AutoGrad-StaticShapeCast-" + NewUniqueId())
+            .Op("cast_to_static_shape")
+            .Input("input", GenLogicalBlobName(diff_lbi))
+            .Output("output")
+            .ScopeSymbolId(scope_symbol_id)
+            .Build();
     job_builder->AddOps(model_op_node->parallel_desc().parallel_conf(),
-                        {cast_to_static_shape_op_conf});
-    diff_lbi.set_op_name(cast_to_static_shape_op_conf.name());
-    diff_lbi.set_blob_name(cast_to_static_shape_conf->out());
+                        {cast_to_static_shape_op.op_conf()});
+    diff_lbi = GenLogicalBlobId(cast_to_static_shape_op.output("output", 0));
   }
 }
 
12 changes: 8 additions & 4 deletions oneflow/core/job_rewriter/prune_cast_to_static_shape_op_pass.cpp
@@ -13,6 +13,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
 */
+#include "oneflow/core/framework/framework.h"
 #include "oneflow/core/job_rewriter/job_pass.h"
 #include "oneflow/core/register/runtime_blob_desc.h"
 
@@ -49,13 +50,16 @@ Maybe<void> PruneCastToStaticShapeOpsPass::Apply(const OpGraph& op_graph,
   });
   op_graph.ForEachNode([&](const OpNode* op_node) {
     const OperatorConf& op_conf = op_node->op().op_conf();
-    if (!op_conf.has_cast_to_static_shape_conf()) { return; }
+    if (!op_conf.has_user_conf()) { return; }
+    const std::string& op_type_name = op_conf.user_conf().op_type_name();
+    if (op_type_name != "cast_to_static_shape") { return; }
     if (!op_conf.ctrl_in_op_name().empty()) { return; }
     if (ctrl_in_op_names.find(op_conf.name()) != ctrl_in_op_names.end()) { return; }
     if (op_node->in_edges().size() != 1) { return; }
-    const OpNode* producer = op_node->SoleInEdge()->src_node();
-    const LogicalBlobId& cast_in_lbi = op_node->op().BnInOp2Lbi("in");
-    const LogicalBlobId& cast_out_lbi = op_node->op().BnInOp2Lbi("out");
+    const user_op::UserOpConfWrapper user_op_conf(op_conf);
+    const LogicalBlobId& cast_in_lbi = GenLogicalBlobId(user_op_conf.input("input", 0));
+    const LogicalBlobId& cast_out_lbi = GenLogicalBlobId(user_op_conf.output("output", 0));
+    const OpNode* producer = op_graph.OpNode4OpName(cast_in_lbi.op_name());
     const BlobDesc& cast_in_logical_blob_desc = producer->LogicalBlobDesc4Lbi(cast_in_lbi);
     if (cast_in_logical_blob_desc.is_dynamic()) { return; }
     for (const OpEdge* out_edge : op_node->out_edges()) {
41 changes: 0 additions & 41 deletions oneflow/core/kernel/cast_to_static_shape_kernel.cpp

This file was deleted.
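The replacement kernel does not appear in this capture of the diff; it presumably moves under oneflow/user/kernels/. As a rough, unverified sketch only — file location, class name, and the exact kernel API calls below are assumptions, not taken from this PR — a user kernel with the old behavior (a straight copy from input to output) could look roughly like this:

// Hypothetical sketch, not part of this PR's visible diff.
#include "oneflow/core/framework/framework.h"
#include "oneflow/core/kernel/kernel_util.h"

namespace oneflow {

template<DeviceType device_type>
class CastToStaticShapeKernel final : public user_op::OpKernel {
 public:
  CastToStaticShapeKernel() = default;
  ~CastToStaticShapeKernel() override = default;

 private:
  void Compute(user_op::KernelComputeContext* ctx) const override {
    const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex("input", 0);
    user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("output", 0);
    // Same bytes and dtype; only the blob's dynamic-shape property changes.
    Memcpy<device_type>(ctx->device_ctx(), out->mut_dptr<char>(), in->dptr<char>(),
                        in->shape().elem_cnt() * GetSizeOfDataType(in->data_type()));
  }
  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }
};

REGISTER_USER_KERNEL("cast_to_static_shape")
    .SetCreateFn<CastToStaticShapeKernel<DeviceType::kCPU>>()
    .SetIsMatchedHob(user_op::HobDeviceTag() == "cpu");

}  // namespace oneflow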

46 changes: 0 additions & 46 deletions oneflow/core/operator/identity_op.cpp
@@ -198,50 +198,4 @@ REGISTER_OP(OperatorConf::kCastFromMirroredConf, CastFromMirroredOp);
 
 }  // namespace
 
-namespace {
-
-class CastToStaticShapeOp final : public Operator {
- public:
-  OF_DISALLOW_COPY_AND_MOVE(CastToStaticShapeOp);
-  CastToStaticShapeOp() = default;
-  ~CastToStaticShapeOp() override = default;
-
- private:
-  void InitFromOpConf() override {
-    EnrollInputBn("in");
-    EnrollOutputBn("out")->set_const_inplace_ibn("in");
-  }
-
-  Maybe<void> InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
-                             const ParallelContext* parallel_ctx) const override {
-    const BlobDesc* in = GetBlobDesc4BnInOp("in");
-    BlobDesc* out = GetBlobDesc4BnInOp("out");
-    *out = *in;
-    out->set_is_dynamic(false);
-    return Maybe<void>::Ok();
-  }
-
-  Maybe<void> InferBatchAxis(
-      std::function<OptInt64*(const std::string&)> BatchAxis4BnInOp) const override {
-    return NaiveInferBatchAxis(BatchAxis4BnInOp);
-  }
-
-  Maybe<void> GetSbpSignatures(
-      const std::function<Maybe<const BlobDesc&>(const std::string&)>& LogicalBlobDesc4Ibn,
-      SbpSignatureList* sbp_sig_list) const override {
-    SbpSignatureBuilder().PartialSum("in").PartialSum("out").Build(
-        sbp_sig_list->mutable_sbp_signature()->Add());
-    const int64_t num_axes = JUST(LogicalBlobDesc4Ibn("in")).shape().NumAxes();
-    FOR_RANGE(int64_t, i, 0, num_axes) {
-      SbpSignatureBuilder().Split("in", i).Split("out", i).Build(
-          sbp_sig_list->mutable_sbp_signature()->Add());
-    }
-    return Maybe<void>::Ok();
-  }
-};
-
-REGISTER_OP(OperatorConf::kCastToStaticShapeConf, CastToStaticShapeOp);
-
-}  // namespace
-
 }  // namespace oneflow
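The user-op definition replacing this deleted `Operator` is likewise absent from this capture; it presumably lives under oneflow/user/ops/. As a hedged sketch mirroring the deleted class's InferBlobDescs (copy the input desc, clear the dynamic flag) and GetSbpSignatures (PartialSum plus Split on every axis) with the user-op registration API of this era — names and signatures below are assumptions, not taken from this PR:

// Hypothetical sketch, not part of this PR's visible diff.
#include "oneflow/core/framework/framework.h"

namespace oneflow {

REGISTER_USER_OP("cast_to_static_shape")
    .Input("input")
    .Output("output")
    .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
      const user_op::TensorDesc* in = ctx->TensorDesc4ArgNameAndIndex("input", 0);
      user_op::TensorDesc* out = ctx->TensorDesc4ArgNameAndIndex("output", 0);
      *out = *in;
      out->set_is_dynamic(false);  // the core semantics: drop the dynamic flag
      return Maybe<void>::Ok();
    })
    .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe<void> {
      // Mirror the deleted GetSbpSignatures: PartialSum plus Split on every axis.
      const auto& in = ctx->LogicalTensorDesc4InputArgNameAndIndex("input", 0);
      ctx->NewBuilder()
          .PartialSum(user_op::OpArg("input", 0))
          .PartialSum(user_op::OpArg("output", 0))
          .Build();
      FOR_RANGE(int64_t, i, 0, in.shape().NumAxes()) {
        ctx->NewBuilder()
            .Split(user_op::OpArg("input", 0), i)
            .Split(user_op::OpArg("output", 0), i)
            .Build();
      }
      return Maybe<void>::Ok();
    });

}  // namespace oneflow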
6 changes: 0 additions & 6 deletions oneflow/core/operator/op_conf.proto
@@ -561,11 +561,6 @@ message BoxingIdentityOpConf {
   required LogicalBlobId lbi = 1;
 }
 
-message CastToStaticShapeOpConf {
-  required string in = 1;
-  required string out = 2;
-}
-
 message BoxingS2SAll2AllPackOpConf {
   required LogicalBlobId lbi = 1;
   required int64 dst_split_axis = 2;
 
@@ -649,7 +644,6 @@ message OperatorConf {
     CollectiveBoxingGenericOpConf collective_boxing_generic_conf = 170;
     BoxingIdentityOpConf boxing_identity_conf = 171;
     TensorListSplitOpConf tensor_list_split_conf = 172;
-    CastToStaticShapeOpConf cast_to_static_shape_conf = 173;
    BoxingS2SAll2AllPackOpConf boxing_s2s_all2all_pack_conf = 174;
     BoxingS2SAll2AllUnpackOpConf boxing_s2s_all2all_unpack_conf = 175;
     BoxingZerosOpConf boxing_zeros_conf = 176;
54 changes: 54 additions & 0 deletions oneflow/python/ops/array_ops.py
@@ -1934,6 +1934,60 @@ def identity_Job(x: tp.Numpy.Placeholder(shape=(1, 3), dtype=flow.int32),
     )
 
 
+@oneflow_export("cast_to_static_shape")
+def cast_to_static_shape(
+    x: oneflow_api.BlobDesc, name: Optional[str] = None
+) -> oneflow_api.BlobDesc:
+    r"""This operator returns a `Blob` with the same content and data type as the input `Blob`, but with its shape converted from dynamic to static.
+
+    Args:
+        x (oneflow_api.BlobDesc): The input Blob, which has a dynamic shape.
+        name (Optional[str], optional): The name for the operation. Defaults to None.
+
+    Returns:
+        oneflow_api.BlobDesc: The result Blob, identical to the input Blob but with a static shape.
+
+    For example:
+
+    .. code-block:: python
+
+        import oneflow as flow
+        import numpy as np
+        import oneflow.typing as tp
+
+        @flow.global_function()
+        def cast_to_static_shape_func(
+            x: tp.ListNumpy.Placeholder(shape=(3, 3), dtype=flow.float32),
+        ) -> tp.Numpy:
+            return flow.cast_to_static_shape(x)
+
+        x = np.array([[1, 1, 1],
+                      [2, 2, 2],
+                      [3, 3, 3]]).astype(np.float32)
+
+        out = cast_to_static_shape_func(x)
+
+        # out [[1 1 1]
+        #      [2 2 2]
+        #      [3 3 3]]
+
+    """
+    if not x.is_dynamic:
+        return x
+
+    if name is None:
+        name = id_util.UniqueStr("CastToStaticShape_")
+
+    op = (
+        flow.user_op_builder(name)
+        .Op("cast_to_static_shape")
+        .Input("input", [x])
+        .Output("output")
+        .Build()
+    )
+    return op.InferAndTryRun().SoleOutputBlob()
+
+
 @oneflow_export("squeeze")
 def squeeze(
     input: oneflow_api.BlobDesc,
127 changes: 127 additions & 0 deletions oneflow/python/test/ops/test_cast_to_static_shape.py
@@ -0,0 +1,127 @@
"""
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import unittest
import numpy as np
from collections import OrderedDict
import oneflow as flow
from test_util import GenArgDict, type_name_to_flow_type, type_name_to_np_type


def _make_cast_to_static_shape_fn(
test_case, shape, data_type, device_type, device_num, compare_diff_fn
):
dtype = type_name_to_flow_type[data_type]
require_grad = dtype is flow.float32

flow.clear_default_session()
if device_type == "gpu":
flow.config.gpu_device_num(device_num)
elif device_type == "cpu":
flow.config.cpu_device_num(device_num)
else:
raise ValueError

assert device_num > 0
func_config = flow.FunctionConfig()
func_config.default_data_type(dtype)
func_config.default_placement_scope(
flow.scope.placement(device_type, "0:0-{}".format(device_num - 1))
)
func_config.default_logical_view(flow.scope.mirrored_view())

@flow.global_function(
type="train" if require_grad else "predict", function_config=func_config
)
def cast_to_static_shape_fn(
x: flow.typing.ListNumpy.Placeholder(shape=shape, dtype=dtype)
) -> flow.typing.ListNumpy:
x_var = flow.get_variable(
name="x_var", shape=(1,), dtype=dtype, initializer=flow.zeros_initializer(),
)
x = x + x_var
y = flow.cast_to_static_shape(x)
test_case.assertFalse(y.is_dynamic)
if require_grad:
flow.watch_diff(x, compare_diff_fn)
flow.optimizer.SGD(
flow.optimizer.PiecewiseConstantScheduler([], [1e-3]), momentum=0
).minimize(y)
return y

return cast_to_static_shape_fn


def _random_input(shape, data_type):
dtype = type_name_to_np_type[data_type]
if data_type == "float32" or data_type == "double":
return np.random.random_sample(shape).astype(dtype)
elif data_type == "int32":
return np.random.randint(low=0, high=100, size=shape).astype(dtype)
else:
raise NotImplementedError


def _check_cast_to_static_shape(test_case, shape, data_type, device_type, device_num):
x = _random_input(shape, data_type)

def comp(x, y):
test_case.assertTrue(np.array_equal(x, y))

def comp_diff(diff):
dx = np.ones(shape)
for d in diff.numpy_list():
test_case.assertTrue(np.array_equal(d, dx))

cast_to_static_shape_fn = _make_cast_to_static_shape_fn(
test_case, shape, data_type, device_type, device_num, comp_diff
)
y = cast_to_static_shape_fn([x] * device_num)

if isinstance(y, list):
for y_ in y:
comp(x, y_)
elif isinstance(y, np.ndarray):
comp(x, y)
else:
raise ValueError


@flow.unittest.skip_unless_1n1d()
class TestCastToStaticShape(flow.unittest.TestCase):
def test_case_1(test_case):
arg_dict = OrderedDict()
arg_dict["shape"] = [(5, 4, 3), (10, 7)]
arg_dict["data_type"] = ["float32", "double", "int32"]
arg_dict["device_type"] = ["gpu", "cpu"]
arg_dict["device_num"] = [1]
for arg in GenArgDict(arg_dict):
_check_cast_to_static_shape(test_case, **arg)


@flow.unittest.skip_unless_1n4d()
class TestCastToStaticShapeParallel(flow.unittest.TestCase):
def test_case_1(test_case):
arg_dict = OrderedDict()
arg_dict["shape"] = [(10,)]
arg_dict["data_type"] = ["float32", "double", "int32"]
arg_dict["device_type"] = ["gpu", "cpu"]
arg_dict["device_num"] = [4]
for arg in GenArgDict(arg_dict):
_check_cast_to_static_shape(test_case, **arg)


if __name__ == "__main__":
unittest.main()