graph/block io check #5803

Merged
merged 20 commits on Aug 10, 2021
5 changes: 3 additions & 2 deletions python/oneflow/framework/graph_build_util.py
@@ -59,11 +59,12 @@ def __enter__(self):
         c_api_util.CurJobBuildAndInferCtx_SetJobConf(self._job_conf)

     def __exit__(self, exc_type, exc_val, exc_tb):
-        oneflow._oneflow_internal.CurJobBuildAndInferCtx_Complete()
-        oneflow._oneflow_internal.JobBuildAndInferCtx_Close()
+        if exc_type is None:
+            oneflow._oneflow_internal.CurJobBuildAndInferCtx_Complete()
+            oneflow._oneflow_internal.JobBuildAndInferCtx_Close()
+            return True
+        else:
+            oneflow._oneflow_internal.JobBuildAndInferCtx_Close()
+            return False
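
The behavioral point of this hunk is the context-manager protocol: `__exit__` now completes the job only on success, closes the build context on both paths, and returns False on failure so the original exception propagates. A minimal, self-contained sketch of that contract (plain Python, not OneFlow API):

```python
class BuildCtx:
    """Toy stand-in for JobBuildAndInferCtx; names here are illustrative."""

    def __enter__(self):
        print("open job context")
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        if exc_type is None:
            print("complete job")   # finalize only on the success path
        print("close job context")  # release the context on both paths
        return exc_type is None     # False on error -> exception propagates

try:
    with BuildCtx():
        raise RuntimeError("graph build failed")
except RuntimeError as e:
    print("caller still sees:", e)  # context was closed, error not swallowed
```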


3 changes: 3 additions & 0 deletions python/oneflow/framework/tensor.py
@@ -29,6 +29,9 @@


 def _tensor_numpy(eager_local_tensor):
+    assert (
+        not eager_local_tensor.is_lazy
+    ), "tensor.numpy() is not allowed to be called in nn.Graph.build(*args) or on a lazy tensor."
     if eager_local_tensor.dtype == flow.tensor_buffer:
         shapes, dtypes = eager_local_tensor._tensor_buffer_shapes_and_dtypes
         tensors = flow.tensor_buffer_to_list_of_tensors(
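
With the new assertion, calling `.numpy()` on a lazy tensor fails immediately with a clear message instead of misbehaving later. A hedged usage sketch (the `MyGraph` subclass is hypothetical; `flow.ones` is standard OneFlow):

```python
import oneflow as flow

class MyGraph(flow.nn.Graph):  # hypothetical graph for illustration
    def build(self, x):
        # x is a lazy tensor while the graph is being built; after this PR
        # the call below raises AssertionError instead of silently running
        # eager code against a lazy tensor.
        return x.numpy()

g = MyGraph()
g(flow.ones(2, 2))  # AssertionError: tensor.numpy() is not allowed ...
```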
327 changes: 241 additions & 86 deletions python/oneflow/nn/graph.py
@@ -29,7 +29,7 @@
 from oneflow.nn.graph_optimizer import OptimizerConfig, VariableConfig
 from oneflow.nn.module import Module
 from oneflow.nn.optimizer.optimizer import Optimizer
-from oneflow.nn.util import add_indent
+from oneflow.nn.util import add_indent, sys_exc_error_msg, list_to_func_return


class Graph(object):
@@ -118,118 +118,273 @@ def _generate_optimizer_and_variable_configs(self):
         )

     def _compile(self, *args):
-        assert not self._is_compiled, (
-            "nn.Graph " + self._name + " has already been compiled."
-        )
-        if self._debug:
-            print(self._shallow_repr() + " start graph construting.")
+        # Build forward graph
+        try:
+            if self._debug:
+                print(self._shallow_repr() + " start building forward graph.")
+            assert not self._is_compiled, (
+                "nn.Graph " + self._name + " has already been compiled."
+            )
+
+            eager_outputs = self._build_forward_graph(*args)
+
+            if self._debug:
+                print(self._shallow_repr() + " end building forward graph.")
+        except:
+            print(
+                "[ERROR]"
+                + self._shallow_repr()
+                + " build forward graph got error: "
+                + sys_exc_error_msg()
+            )
+            raise
+
+        # Compile and init Runtime
+        try:
+            if self._debug:
+                print(self._shallow_repr() + " start compiling and initializing graph runtime.")
+
+            self._c_nn_graph.complie_and_init_runtime()
+
+            if self._debug:
+                print(self._shallow_repr() + " end compiling and initializing graph runtime.")
+        except:
+            print(
+                "[ERROR]"
+                + self._shallow_repr()
+                + " compiling and initializing graph runtime got error: ",
+                sys_exc_error_msg(),
+            )
+            raise
+
+        self._is_compiled = True
+        return eager_outputs
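
`_compile` now wraps each phase in try/except so a failure is attributed to either forward-graph building or runtime initialization before being re-raised. A sketch of what a caller observes when `build()` itself raises (the subclass is hypothetical):

```python
import oneflow as flow

class BrokenGraph(flow.nn.Graph):  # hypothetical failing graph
    def build(self, x):
        raise ValueError("user bug inside build()")

g = BrokenGraph()
try:
    g(flow.ones(2))  # first call triggers _compile
except ValueError:
    # Before re-raising, _compile printed
    # "[ERROR]<shallow repr> build forward graph got error: " followed by
    # the formatted traceback from sys_exc_error_msg().
    pass
```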

+    def _build_forward_graph(self, *args):
         self._generate_optimizer_and_variable_configs()

         session = session_ctx.GetDefaultSession()
         assert type(session) is MultiClientSession
         session.TryInit()

         with graph_build_util.graph_build_context(self.config.proto, session):
-            # Deal with input
-            lazy_args = []
-            lazy_arg_op_names = []
-            for idx, arg in enumerate(args):
-                op_name = "_" + self.name + "-input_" + str(idx)
-                lazy_args.append(graph_build_util.build_graph_input_arg(op_name, arg))
-                lazy_arg_op_names.append(op_name)
-                in_str = "(INPUT:" + op_name + ":" + arg._meta_repr() + ")"
-                self._args_repr.append(in_str)
-                if self._debug:
-                    print(in_str)
+            # Deal with inputs
+            arg_op_names, lazy_args, self._args_repr = self._build_io(
+                "input", graph_build_util.build_graph_input_arg, *args
+            )

             # Deal with parameter and buffer
-            state_op_names = []
-            state_tensors = []
-            for state_block in self._state():
-                op_name = state_block.name_prefix + state_block.name
-                state_tensor = state_block.origin
-                state_op_names.append(op_name)
-                state_tensors.append(state_tensor)
-                if state_block.type == BlockType.PARAMETER:
-                    state_config = self._variables_conf[state_block.origin]
-                else:
-                    state_config = None
-                state_block.set_lazy_origin_builder(
-                    partial(
-                        graph_build_util.build_graph_state,
-                        op_name,
-                        state_tensor,
-                        state_config,
-                    )
-                )
-            self._variables = convert_to_tensor_tuple(state_tensors)
+            state_op_names, self._states_tensor_tuple = self._build_states()

             # Deal with module in self.build(*args)
             outputs = self.build(*lazy_args)
-            if outputs is None:
-                outputs = tuple()
-            elif isinstance(outputs, (list, tuple)):
-                outputs = tuple(outputs)
-            elif isinstance(outputs, Tensor):
-                outputs = (outputs,)
-            elif isinstance(outputs, TensorTuple):
-                pass
-            else:
-                raise RuntimeError(f"invalid outputs with type {type(outputs)}")
-
-            eager_outputs = []
-            eager_output_op_names = []
-            for idx, out in enumerate(outputs):
-                op_name = "_" + self.name + "-output_" + str(idx)
-                eager_outputs.append(graph_build_util.build_graph_output(op_name, out))
-                eager_output_op_names.append(op_name)
-
-                out_str = "(OUTPUT:" + op_name + ":" + out._meta_repr() + ")"
-                self._outs_repr.append(out_str)
-                if self._debug:
-                    print(out_str)
-
-            if len(eager_outputs) == 0:
-                eager_outputs = None
-            elif len(eager_outputs) == 1:
-                eager_outputs = eager_outputs[0]
-            else:
-                eager_outputs = tuple(eager_outputs)
-
-            self._outputs = convert_to_tensor_tuple(eager_outputs)
-            self._eager_outputs = eager_outputs
+            # Deal with outputs
+            if not (type(outputs) is tuple or type(outputs) is list):
+                if outputs is None:
+                    outputs = ()
+                else:
+                    outputs = (outputs,)
+            output_op_names, self._eager_outputs, self._outs_repr = self._build_io(
+                "output", graph_build_util.build_graph_output, *outputs
+            )
+            self._outputs_tensor_tuple = convert_to_tensor_tuple(
+                self._flatten_io("output", *self._eager_outputs)
+            )
+            self._eager_outputs = list_to_func_return(self._eager_outputs)

             # Register input/output/variable to _c_nn_graph
-            self._c_nn_graph.register_input_op_names(lazy_arg_op_names)
-            self._c_nn_graph.register_output_op_names(eager_output_op_names)
+            self._c_nn_graph.register_input_op_names(arg_op_names)
+            self._c_nn_graph.register_output_op_names(output_op_names)
             self._c_nn_graph.register_variable_op_names_and_tensors(
-                state_op_names, self._variables
+                state_op_names, self._states_tensor_tuple
             )

             # Save job proto for debug
             self._job_proto = c_api_util.GetCurrentJob()

-        # Complie and init Runtime
-        self._c_nn_graph.complie_and_init_runtime()
-        self._is_compiled = True
-        if self._debug:
-            print(self._shallow_repr() + " end graph construting.")
-        return eager_outputs
+        return self._eager_outputs
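
`list_to_func_return` (imported from `oneflow.nn.util` above) restores the user-facing return shape that the deleted length checks used to compute inline. A plausible sketch of its behavior, assumed from those checks rather than from the helper's actual source:

```python
def list_to_func_return(seq):
    # Assumed semantics, mirroring the deleted len() checks:
    # no outputs -> None, one output -> the bare item, many -> a tuple.
    if len(seq) == 0:
        return None
    if len(seq) == 1:
        return seq[0]
    return tuple(seq)

assert list_to_func_return([]) is None
assert list_to_func_return(["y"]) == "y"
assert list_to_func_return(["y0", "y1"]) == ("y0", "y1")
```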

-    def _launch(self, *args):
-        # oneflow._oneflow_internal.eager.multi_client.Sync() NOTE(chengcheng): Need Sync?
-        oneflow._oneflow_internal.nn.graph.RunLazyNNGraph(
-            convert_to_tensor_tuple(args),
-            self._outputs,
-            self._variables,
-            self._c_nn_graph,
-        )
+    def _run(self, *args):
+        try:
+            flattened_eager_args = self._flatten_io("input", *args)
+            # oneflow._oneflow_internal.eager.multi_client.Sync() NOTE(chengcheng): Need Sync?
+            oneflow._oneflow_internal.nn.graph.RunLazyNNGraph(
+                convert_to_tensor_tuple(flattened_eager_args),
+                self._outputs_tensor_tuple,
+                self._states_tensor_tuple,
+                self._c_nn_graph,
+            )
+        except:
+            print(
+                "[ERROR]"
+                + self._shallow_repr()
+                + " run got error: "
+                + sys_exc_error_msg()
+            )
+            raise
+        return self._eager_outputs

     def __call__(self, *args):
         if not self._is_compiled:
             self._compile(*args)
-        return self._launch(*args)
+        return self._run(*args)
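
Taken together, the first call compiles and later calls only run the cached NNGraph. A minimal usage sketch (the subclass and shapes are illustrative):

```python
import oneflow as flow

class LinearGraph(flow.nn.Graph):  # hypothetical example graph
    def __init__(self):
        super().__init__()
        self.linear = flow.nn.Linear(3, 4)

    def build(self, x):
        return self.linear(x)

g = LinearGraph()
x = flow.ones(2, 3)
y = g(x)   # first call: _compile builds the job and inits the runtime
y2 = g(x)  # subsequent calls: _run only, reusing the compiled graph
```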

+    def _build_io(self, io_type, build_func, *args):
+        assert io_type in ("input", "output")
+        io_type_upper = io_type.upper()
+        build_args = []
+        op_names = []
+        args_repr = []
+
+        def build_tensor_or_none(tensor, name, repr_str):
+            assert tensor is None or (isinstance(tensor, Tensor))
+            if isinstance(tensor, Tensor):
+                build_arg = build_func(name, tensor)
+                op_names.append(name)
+            else:
+                build_arg = None
+
+            args_repr.append(repr_str)
+            if self._debug:
+                print(repr_str)
+            return build_arg
+
+        for idx, arg in enumerate(args):
+            if isinstance(arg, Tensor) or arg is None:
+                if arg is None:
+                    name, repr_str = self._io_item_check_and_gen(
+                        arg, None, io_type, idx
+                    )
+                else:
+                    name, repr_str = self._io_item_check_and_gen(
+                        arg, Tensor, io_type, idx
+                    )
+                build_args.append(build_tensor_or_none(arg, name, repr_str))
+            elif isinstance(arg, (TensorTuple, list)):
+                if isinstance(arg, TensorTuple):
+                    seq_args = TensorTuple()
+                else:
+                    seq_args = list()
+                for i in range(len(arg)):
+                    name, repr_str = self._io_item_check_and_gen(
+                        arg[i], Tensor, io_type, idx, i
+                    )
+                    seq_args.append(build_tensor_or_none(arg[i], name, repr_str))
+                build_args.append(seq_args)
+            else:
+                self._io_item_check_and_gen(arg, Tensor, io_type, idx)
+
+        return op_names, build_args, args_repr
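
The op names recorded here follow the fixed scheme built in `_io_item_check_and_gen` below. A re-derivation of that naming rule as a standalone helper (for illustration only, not an API):

```python
def io_op_name(graph_name, io_type, idx, second_idx=None):
    # Mirrors the string building in _io_item_check_and_gen: graph name,
    # io direction, positional index, and an optional nested index.
    return (
        "_" + graph_name + "-" + io_type + "_" + str(idx)
        + ("" if second_idx is None else "_" + str(second_idx))
    )

assert io_op_name("g", "input", 0) == "_g-input_0"
assert io_op_name("g", "input", 1, 0) == "_g-input_1_0"   # element 0 of arg 1
assert io_op_name("g", "output", 0) == "_g-output_0"
```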

+    def _flatten_io(self, io_type, *args):
+        assert isinstance(args, tuple)
+        flattened_args = []
+        for idx, arg in enumerate(args):
+            if isinstance(arg, Tensor):
+                flattened_args.append(arg)
+            elif isinstance(arg, (TensorTuple, list)):
+                for i in range(len(arg)):
+                    self._io_item_check(arg[i], Tensor, io_type, idx, i)
+                    flattened_args.append(arg[i])
+            else:
+                self._io_item_check(arg, None, io_type, idx)
+        return flattened_args
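
`_flatten_io` undoes the one level of nesting `_build_io` accepts, skipping `None` entries, so the runtime receives a flat tensor list. A self-contained sketch of the rule (not the OneFlow implementation):

```python
def flatten_one_level(args):
    # Tensors pass through, a TensorTuple/list is flattened one level,
    # None is dropped; anything else would fail _io_item_check.
    flat = []
    for arg in args:
        if isinstance(arg, (list, tuple)):
            flat.extend(arg)
        elif arg is not None:
            flat.append(arg)
    return flat

assert flatten_one_level(("t0", ["t1", "t2"], None)) == ["t0", "t1", "t2"]
```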

+    def _io_item_check(self, item, expect_type, io_type, idx, second_idx=None):
+        if expect_type is None and item is None:
+            return
+        elif expect_type is not None and isinstance(item, expect_type):
+            return
+        else:
+            assert io_type in ("input", "output")
+            name = (
+                "_"
+                + self.name
+                + "-"
+                + io_type
+                + "_"
+                + str(idx)
+                + ("" if second_idx is None else "_" + str(second_idx))
+            )
+            repr_str = (
+                "[ERROR](" + io_type.upper() + ":" + name + ":" + str(type(item)) + ")"
+            )
+            print(repr_str)
+            raise NotImplementedError(
+                "nn.Graph.build()'s input/output only support types: Tensor/TensorTuple/list(Tensor)/None."
+            )
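
Unsupported I/O types are now rejected up front with an explicit error rather than failing deep inside compilation. A hedged sketch of the failure mode (the subclass is hypothetical; a dict output is not among the supported types):

```python
import oneflow as flow

class DictGraph(flow.nn.Graph):  # hypothetical: returns an unsupported dict
    def build(self, x):
        return {"y": x}

g = DictGraph()
try:
    g(flow.ones(2))
except NotImplementedError as e:
    print(e)  # nn.Graph.build()'s input/output only support types:
              # Tensor/TensorTuple/list(Tensor)/None.
```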

+    def _io_item_check_and_gen(self, item, expect_type, io_type, idx, second_idx=None):
+        assert io_type in ("input", "output")
+        name = (
+            "_"
+            + self.name
+            + "-"
+            + io_type
+            + "_"
+            + str(idx)
+            + ("" if second_idx is None else "_" + str(second_idx))
+        )
+        if expect_type is None and item is None:
+            repr_str = (
+                "[WARNING]("
+                + io_type.upper()
+                + ":"
+                + name
+                + ":"
+                + str(type(item))
+                + ")"
+            )
+            return name, repr_str
+        elif expect_type is not None and isinstance(item, expect_type):
+            if isinstance(item, Tensor):
+                repr_str = (
+                    "(" + io_type.upper() + ":" + name + ":" + item._meta_repr() + ")"
+                )
+            else:
+                repr_str = (
+                    "[WARNING]("
+                    + io_type.upper()
+                    + ":"
+                    + name
+                    + ":"
+                    + str(type(item))
+                    + ")"
+                )
+            return name, repr_str
+        else:
+            repr_str = (
+                "[ERROR](" + io_type.upper() + ":" + name + ":" + str(type(item)) + ")"
+            )
+            print(repr_str)
+            raise NotImplementedError(
+                "nn.Graph.build()'s input/output only support types: Tensor/TensorTuple/list(Tensor)/None."
+            )

+    def _build_states(self):
+        state_op_names = []
+        state_tensors = []
+        for state_block in self._state():
+            op_name = state_block.name_prefix + state_block.name
+            state_tensor = state_block.origin
+            state_op_names.append(op_name)
+            state_tensors.append(state_tensor)
+            if state_block.type == BlockType.PARAMETER:
+                state_config = self._variables_conf[state_block.origin]
+            else:
+                state_config = None
+            state_block.set_lazy_origin_builder(
+                partial(
+                    graph_build_util.build_graph_state,
+                    op_name,
+                    state_tensor,
+                    state_config,
+                )
+            )
+        state_tensor_tuple = convert_to_tensor_tuple(state_tensors)
+        return state_op_names, state_tensor_tuple
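
`_build_states` collects every Parameter and Buffer reachable from the graph's blocks; only PARAMETER blocks carry a `VariableConfig`, and the actual variable op is built lazily via the stored `partial`. A self-contained sketch of that deferred-builder pattern (the builder function and names are stand-ins, not OneFlow's):

```python
from functools import partial

def build_graph_state(op_name, tensor, config):  # stand-in, not OneFlow's
    print("building variable op:", op_name, "config:", config)

# Each state block keeps a zero-argument builder that creates its variable
# op later, when the block is actually used inside build().
builders = []
for op_name, tensor, config in [
    ("m.weight", "w", "var-conf"),  # PARAMETER: has a VariableConfig
    ("m.bias_buffer", "b", None),   # BUFFER: state_config is None
]:
    builders.append(partial(build_graph_state, op_name, tensor, config))

for b in builders:
    b()  # invoked lazily by the block on first use
```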

def _add_block(self, name: str, module: Module = None) -> None:
r"""Adds a module to the current graph as a block.