Skip to content

Support dispensable inputs for eager final state codegen #39743

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Feb 23, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,11 @@ def IntermediateValidationCheck(intermediate_outputs, forward_returns_list):
assert pos in intermediate_positions


def ParseDispensable(string):
    # Parse the comma-separated "optional" field of a forward API entry
    # (e.g. "X, Y") into a list of input names.
    #
    # Robustness fix: drop empty entries so an empty string, a trailing
    # comma, or doubled commas ("X,,Y") cannot yield a spurious "" name
    # that would never match a real input.
    return [v.strip() for v in string.split(",") if v.strip()]


def ParseIntermediate(string):
    # Split the comma-separated "intermediate" field (e.g. "XShape, Out")
    # into a list of whitespace-stripped output names.
    return list(map(str.strip, string.split(",")))

Expand Down Expand Up @@ -596,11 +601,11 @@ def GenerateNodeDefinition(fwd_api_name, bwd_api_name, backward_fwd_input_map,
return node_definition_str


def GenerateNodeCreationCodes(fwd_api_name, bwd_api_name,
forward_inputs_position_map,
forward_outputs_position_map, forward_attrs_list,
backward_fwd_input_map, backward_grad_input_map,
backward_grad_output_map, backward_attrs_list):
def GenerateNodeCreationCodes(
fwd_api_name, bwd_api_name, forward_inputs_position_map,
forward_outputs_position_map, forward_attrs_list,
backward_fwd_input_map, backward_grad_input_map,
backward_grad_output_map, backward_attrs_list, optional_inputs):
# fwd_api_name = ""
# forward_inputs_position_map = { "name" : [type, fwd_position] }
# forward_outputs_position_map = { "name" : [type, fwd_position] }
Expand Down Expand Up @@ -674,10 +679,17 @@ def GenerateNodeCreationCodes(fwd_api_name, bwd_api_name,
# SetTensorWrappers
set_tensor_wrappers_list = []
for name, (_, is_fwd_input, _) in backward_fwd_input_map.items():
is_optional = (name in optional_inputs)
if is_fwd_input:
set_tensor_wrappers = f" grad_node->SetTensorWrapper{name}({name}, true);"
if is_optional:
set_tensor_wrappers = f" if({name}.is_initialized()) grad_node->SetTensorWrapper{name}({name}, true);"
else:
set_tensor_wrappers = f" grad_node->SetTensorWrapper{name}({name}, true);"
else:
set_tensor_wrappers = f" grad_node->SetTensorWrapper{name}({name}, false);"
if is_optional:
set_tensor_wrappers = f" if({name}.is_initialized()) grad_node->SetTensorWrapper{name}({name}, false);"
else:
set_tensor_wrappers = f" grad_node->SetTensorWrapper{name}({name}, false);"
set_tensor_wrappers_list.append(set_tensor_wrappers)
set_tensor_wrappers_str = "\n".join(set_tensor_wrappers_list)

Expand Down Expand Up @@ -762,11 +774,12 @@ def GenerateNodeCreationCodes(fwd_api_name, bwd_api_name,
return node_creation_str


def GenerateForwardDefinition(
fwd_api_name, bwd_api_name, forward_inputs_position_map,
forward_outputs_position_map, forward_attrs_list,
backward_fwd_input_map, backward_grad_input_map,
backward_grad_output_map, backward_attrs_list, intermediate_outputs):
def GenerateForwardDefinition(fwd_api_name, bwd_api_name,
forward_inputs_position_map,
forward_outputs_position_map, forward_attrs_list,
backward_fwd_input_map, backward_grad_input_map,
backward_grad_output_map, backward_attrs_list,
optional_inputs, intermediate_outputs):
# fwd_api_name = ""
# forward_inputs_position_map = { "name" : [type, fwd_position] }
# forward_outputs_position_map = { "name" : [type, fwd_position] }
Expand All @@ -775,6 +788,7 @@ def GenerateForwardDefinition(
# backward_grad_input_map = { "name" : [type, fwd_position, orig_position] ...}
# backward_grad_output_map = { "name" : [type, fwd_position, orig_position] ...}
# backward_attrs_list = [ [attr_name, attr_type, default_value, orig_position], ...]
# optional_inputs = ["name0", ...]

# Get Function Args
num_inputs = len(forward_attrs_list) + len(forward_inputs_position_map.keys(
Expand All @@ -784,17 +798,18 @@ def GenerateForwardDefinition(
inputs_call_list = ["" for i in range(num_inputs)]
for name, (ttype, pos) in forward_inputs_position_map.items():
inputs_call_list[pos] = f"{name}"
is_optional = (name in optional_inputs)
if IsPlainTensorType(ttype):
inputs_args_definition_list[
pos] = f"const paddle::experimental::Tensor& {name}"
inputs_args_declaration_list[
pos] = f"const paddle::experimental::Tensor& {name}"
if is_optional:
arg_str = f"const paddle::optional<paddle::experimental::Tensor>& {name}"
else:
arg_str = f"const paddle::experimental::Tensor& {name}"
else:
assert IsVectorTensorType(ttype)
inputs_args_definition_list[
pos] = f"const std::vector<paddle::experimental::Tensor>& {name}"
inputs_args_declaration_list[
pos] = f"const std::vector<paddle::experimental::Tensor>& {name}"
arg_str = f"const std::vector<paddle::experimental::Tensor>& {name}"

inputs_args_definition_list[pos] = arg_str
inputs_args_declaration_list[pos] = arg_str

for name, atype, default_val, pos in forward_attrs_list:
inputs_call_list[pos] = name
Expand Down Expand Up @@ -849,7 +864,7 @@ def GenerateForwardDefinition(
fwd_api_name, bwd_api_name, forward_inputs_position_map,
forward_outputs_position_map, forward_attrs_list,
backward_fwd_input_map, backward_grad_input_map,
backward_grad_output_map, backward_attrs_list)
backward_grad_output_map, backward_attrs_list, optional_inputs)

FORWARD_FUNCTION_TEMPLATE = """
{} {}({}) {{
Expand Down Expand Up @@ -1053,6 +1068,12 @@ def GenerateForwardHFile(filepath, forward_function_declaration_str):
assert 'args' in bwd_api.keys()
assert 'output' in bwd_api.keys()
assert 'forward' in bwd_api.keys()

# Parse Dispensable Inputs
optional_inputs = []
if 'optional' in fwd_api.keys():
optional_inputs = ParseDispensable(fwd_api['optional'])

bwd_forward_str = bwd_api['forward']
bwd_args_str = bwd_api['args']
bwd_returns_str = bwd_api['output']
Expand Down Expand Up @@ -1128,7 +1149,8 @@ def GenerateForwardHFile(filepath, forward_function_declaration_str):
fwd_api_name, bwd_api_name, forward_inputs_position_map,
forward_outputs_position_map, forward_attrs_list,
backward_fwd_input_map, backward_grad_input_map,
backward_grad_output_map, backward_attrs_list, intermediate_outputs)
backward_grad_output_map, backward_attrs_list, optional_inputs,
intermediate_outputs)
print("Generated Forward Definition: ", forward_definition_str)
print("Generated Forward Declaration: ", forward_declaration_str)
forward_definition_str += definition_declaration_pair[0]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

import os
import argparse
from eager_gen import ReadFwdFile, GetForwardFunctionName, ParseYamlForward, DetermineForwardPositionMap
from eager_gen import ReadFwdFile, ParseDispensable, IsVectorTensorType, GetForwardFunctionName, ParseYamlForward, DetermineForwardPositionMap

atype_to_parsing_function = {
"bool": "CastPyArg2Boolean",
Expand Down Expand Up @@ -70,10 +70,12 @@ def FindParsingFunctionFromAttributeType(atype):


def GeneratePythonCFunction(fwd_api_name, forward_inputs_position_map,
forward_attrs_list, forward_outputs_position_map):
forward_attrs_list, forward_outputs_position_map,
optional_inputs):
# forward_inputs_position_map = { "name" : [type, fwd_position] }
# forward_outputs_position_map = { "name" : [type, fwd_position] }
# forward_attrs_list = [ [attr_name, attr_type, default_value, orig_position], ...]
# optional_inputs = [name0, ...]

# Get EagerTensor from args
# Get dygraph function call args
Expand All @@ -82,7 +84,14 @@ def GeneratePythonCFunction(fwd_api_name, forward_inputs_position_map,
dygraph_function_call_list = ["" for i in range(num_args)]
get_eager_tensor_str = ""
for name, (ttype, pos) in forward_inputs_position_map.items():
get_eager_tensor_str += f" auto& {name} = GetTensorFromArgs(\"{fwd_api_name}\", \"{name}\", args, {pos}, false);\n"
is_optional = (name in optional_inputs)
if IsVectorTensorType(ttype):
get_eager_tensor_str += f" auto {name} = GetTensorListFromArgs(\"{fwd_api_name}\", \"{name}\", args, {pos}, false);\n"
else:
if is_optional:
get_eager_tensor_str += f" auto {name} = GetOptionalTensorFromArgs(\"{fwd_api_name}\", \"{name}\", args, {pos}, false);\n"
else:
get_eager_tensor_str += f" auto {name} = GetTensorFromArgs(\"{fwd_api_name}\", \"{name}\", args, {pos}, false);\n"
dygraph_function_call_list[pos] = f"{name}"

parse_attributes_str = ""
Expand Down Expand Up @@ -267,6 +276,11 @@ def GeneratePythonCFile(filepath, python_c_str):
fwd_args_str = fwd_api['args']
fwd_returns_str = fwd_api['output']

# Parse Dispensable Inputs
optional_inputs = []
if 'optional' in fwd_api.keys():
optional_inputs = ParseDispensable(fwd_api['optional'])

# Collect Original Forward Inputs/Outputs and then perform validation checks
forward_inputs_list, forward_attrs_list, forward_returns_list = ParseYamlForward(
fwd_args_str, fwd_returns_str)
Expand All @@ -283,7 +297,7 @@ def GeneratePythonCFile(filepath, python_c_str):

python_c_function_str, python_c_function_reg_str = GeneratePythonCFunction(
fwd_api_name, forward_inputs_position_map, forward_attrs_list,
forward_outputs_position_map)
forward_outputs_position_map, optional_inputs)
python_c_function_list.append(python_c_function_str)
python_c_function_reg_list.append(python_c_function_reg_str)
print("Generated Python-C Function: ", python_c_function_str)
Expand Down
26 changes: 26 additions & 0 deletions paddle/fluid/pybind/eager_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -555,6 +555,32 @@ PyObject* ToPyObject(
return dict;
}

// For Final State Dygraph,
// We directly use paddle::optional(Tensor) as dispensable Tensor
//
// Extracts the positional argument at `arg_idx` from the Python `args`
// tuple and wraps the underlying tensor in a paddle::optional.
//  - op_type / arg_name: used only to build the error message.
//  - dispensable: when true, a missing/None argument returns an empty
//    optional instead of throwing.
// NOTE(review): PyTuple_GET_ITEM does no bounds checking — this assumes
// arg_idx < PyTuple_Size(args); confirm callers (generated Python-C
// bindings) always pass a large-enough tuple.
paddle::optional<paddle::experimental::Tensor> GetOptionalTensorFromArgs(
    const std::string& op_type, const std::string& arg_name, PyObject* args,
    ssize_t arg_idx, bool dispensable) {
  PyObject* obj = PyTuple_GET_ITEM(args, arg_idx);

  // Unwrap a 1-element tuple wrapper if the argument arrived as a tuple.
  // NOTE(review): assumes a non-empty tuple here; an empty tuple would be
  // undefined behavior — verify against the calling convention.
  if (PyTuple_Check(obj)) {
    obj = PyTuple_GET_ITEM(obj, 0);
  }

  if (obj == nullptr || obj == Py_None) {
    // Missing argument: only legal when the input is dispensable.
    if (!dispensable) {
      PADDLE_THROW(platform::errors::InvalidArgument(
          "%s(): argument '%s' (position %d) must be Tensor, but got None",
          op_type, arg_name, arg_idx));
    }
    return {};
  }

  // Copies the tensor handle out of the Python object into the optional.
  // NOTE(review): no type check that obj is actually a TensorObject before
  // the reinterpret_cast — a wrong-typed Python argument would be UB;
  // confirm the generated callers guarantee the type.
  return paddle::make_optional<paddle::experimental::Tensor>(
      reinterpret_cast<TensorObject*>(obj)->tensor);
}

// For Intermediate State Dygraph,
// we use an uninitialized Tensor to represent dispensable Tensor
paddle::experimental::Tensor& GetTensorFromArgs(const std::string& op_type,
const std::string& arg_name,
PyObject* args, ssize_t arg_idx,
Expand Down
6 changes: 6 additions & 0 deletions paddle/fluid/pybind/eager_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -89,10 +89,15 @@ PyObject* ToPyObject(const std::tuple<Args...>& out) {
return result;
}

paddle::optional<paddle::experimental::Tensor> GetOptionalTensorFromArgs(
const std::string& op_type, const std::string& arg_name, PyObject* args,
ssize_t arg_idx, bool dispensable = false);

paddle::experimental::Tensor& GetTensorFromArgs(const std::string& op_type,
const std::string& arg_name,
PyObject* args, ssize_t arg_idx,
bool dispensable = false);

std::vector<paddle::experimental::Tensor> GetTensorListFromArgs(
const std::string& op_type, const std::string& arg_name, PyObject* args,
ssize_t arg_idx, bool dispensable = false);
Expand All @@ -102,6 +107,7 @@ paddle::experimental::Tensor* GetTensorPtrFromArgs(const std::string& op_type,
PyObject* args,
ssize_t arg_idx,
bool dispensable = false);

std::vector<paddle::experimental::Tensor*> GetTensorPtrListFromArgs(
const std::string& op_type, const std::string& arg_name, PyObject* args,
ssize_t arg_idx, bool dispensable = false);
Expand Down