[Pir] Fix Value eq error when using set #58896

Merged Dec 18, 2023 · 53 commits (changes shown are from 4 commits)

Commits
6539a97
[Pir]Fix Value eq error when using set
Nov 10, 2023
3899a7d
Fix iter
Nov 10, 2023
7f6c298
Refine code
Nov 10, 2023
0793a38
add ValueDict
zrr1999 Nov 10, 2023
f8abfa3
fix
zrr1999 Nov 13, 2023
6e5947b
update valueset
zrr1999 Nov 13, 2023
09f5eac
update dict
zrr1999 Nov 13, 2023
ad7461d
fix bug
zrr1999 Nov 14, 2023
531621c
improve
zrr1999 Nov 14, 2023
6b252fc
fix
zrr1999 Nov 14, 2023
7eb2685
fix contains
zrr1999 Nov 14, 2023
d7792bd
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
Nov 15, 2023
bc174c9
Fix set(Value) or dict[Value] = xx code
Nov 15, 2023
8a6f0a7
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
Nov 16, 2023
d7b38e8
Fix ut
Nov 16, 2023
963d0f8
Fix double grad ut
Nov 16, 2023
d299a50
Fix ut
Nov 17, 2023
8bae46b
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
Nov 24, 2023
ba7a2cd
Forbid opresult hash
Nov 24, 2023
4a50818
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
Nov 26, 2023
ace1ae1
Fix set and dict
Nov 26, 2023
cd3e2ea
Fix dy2st
Nov 26, 2023
470a2aa
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
Dec 7, 2023
faefbe8
Refine code
Dec 7, 2023
526eecd
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
Dec 7, 2023
78cbe1b
Forbid eq
Dec 7, 2023
9be9129
Fix map and value_eq
Dec 7, 2023
176e09b
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
Dec 8, 2023
ea46d6b
Fix None hash
Dec 8, 2023
dcaf6c4
Fix decomp
Dec 8, 2023
1df44e2
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
Dec 11, 2023
55cbfcc
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
Dec 12, 2023
b8b6416
Refine value set/dict
Dec 12, 2023
16bc3c6
Add hash
Dec 12, 2023
3b0a7cd
fix clone program, return an associated array.
2742195759 Dec 12, 2023
1a7767d
Merge commit 'refs/pull/58896/head' of https://github.com/PaddlePaddl…
2742195759 Dec 12, 2023
7ab860c
Fix iter
Dec 12, 2023
31686de
Fix code
Dec 13, 2023
294abe4
Fix backward
Dec 13, 2023
b168af9
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
Dec 13, 2023
96d8cf3
Fix decomp and add copy()
Dec 13, 2023
e3637da
Format code
Dec 13, 2023
ef65985
Fix ut
Dec 14, 2023
7e6422c
Fix layer params set
Dec 14, 2023
a43cd6e
Fix prim op test
Dec 14, 2023
cf1ce82
Fix named_parameters
Dec 14, 2023
88b3e9f
Support value __eq__
Dec 15, 2023
4ec0f71
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
Dec 15, 2023
2611c9f
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
Dec 15, 2023
7b8cf7e
Fix test cond ==
Dec 15, 2023
6af0553
Add ut for value set/dict
Dec 18, 2023
b8953f2
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
Dec 18, 2023
4464870
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
Dec 18, 2023
8 changes: 4 additions & 4 deletions paddle/fluid/pybind/pir.cc
@@ -476,8 +476,8 @@ void BindValue(py::module *m) {
            [](Value &self, Value &op_value) {
              self.ReplaceAllUsesWith(op_value);
            })
-      .def("__eq__", &Value::operator==)
-      .def("__eq__",
+      .def("is_same", &Value::operator==)
+      .def("is_same",
            [](Value &self, OpResult &other) {
              return self.impl() == other.Value::impl();
            })
@@ -664,8 +664,8 @@ void BindOpResult(py::module *m) {
   OVERRIDE_COMPARE_OP_FOR_EACH(__gt__, greater_than);
   OVERRIDE_COMPARE_OP_FOR_EACH(__ge__, greater_equal);
 
-  op_result.def("__eq__", &OpResult::operator==)
-      .def("__eq__",
+  op_result.def("is_same", &OpResult::operator==)
+      .def("is_same",
           [](OpResult &self, Value &other) {
             return self.Value::impl() == other.impl();
           })
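The two hunks above move identity comparison off __eq__ and onto a dedicated is_same binding; a later commit in this PR (88b3e9f, "Support value __eq__") then gives __eq__ different semantics. A minimal sketch of why the split matters, using a hypothetical stand-in class rather than the real pir.Value binding:

# Illustration only: FakeValue is a hypothetical stand-in for pir.Value.
# Identity ("is this the same underlying SSA value?") lives in is_same(),
# so __eq__ can be redefined later without silently changing every
# set/dict membership test that used to rely on it.
class FakeValue:
    def __init__(self, impl):
        self._impl = impl  # stand-in for the wrapped C++ impl pointer

    def is_same(self, other):
        # mirrors Value::operator== above: compare the underlying impls
        return self._impl is other._impl

shared = object()
a, b = FakeValue(shared), FakeValue(shared)  # two wrappers, one value
assert a.is_same(b)  # identity is still available, but explicit
assert a != b        # default __eq__ is object identity, not value identity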
90 changes: 80 additions & 10 deletions python/paddle/autograd/backward_utils.py
@@ -12,7 +12,77 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import collections
+from typing import Any
+
+
+class ValueDict:
+    def __init__(self, iter=None, *, default_factory=None):
+        self._items: list[tuple[Any, Any]] = []
+        self._default_factory = default_factory
+        if iter is not None:
+            for key, val in iter:
+                self[key] = val
+
+    def update(self, other_dict):
+        for key, val in other_dict:
+            self[key] = val
+
+    def keys(self):
+        for key, _ in self._items:
+            yield key
+
+    def values(self):
+        for _, val in self._items:
+            yield val
+
+    def items(self):
+        yield from self._items
+
+    def __setitem__(self, other_key, other_val: Any):
+        if self.__contains__(other_key):
+            for i, (key, val) in enumerate(self._items):
+                if hash(key) == hash(other_key) and key.is_same(other_key):
+                    self._items[i] = (other_key, other_val)
+                    break
+        else:
+            self._items.append((other_key, other_val))
+
+    def __getitem__(self, other_key):
+        for key, val in self._items:
+            if hash(key) == hash(other_key) and key.is_same(other_key):
+                return val
+
+        if self._default_factory is not None:
+            val = self._default_factory()
+            self._items.append((other_key, val))
+            return val
+        else:
+            return None
+
+    def __and__(self, other_dict: "ValueDict"):
+        ret = ValueDict()
+        for key, val in self._items:
+            if key in other_dict:
+                ret[key] = val
+        return ret
+
+    def __or__(self, other_dict: "ValueDict"):
+        return ValueDict(self._items + other_dict._items)
+
+    def __bool__(self):
+        return bool(self._items)
+
+    def __len__(self):
+        return len(self._items)
+
+    def __iter__(self):
+        return self.keys()
+
+    def __contains__(self, other_key):
+        for key, _ in self._items:
+            if hash(key) == hash(other_key) and key.is_same(other_key):
+                return True
+        return False
 
 
 class State:
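A usage sketch for the container above, assuming only that keys expose a stable __hash__ plus the is_same() method added in pir.cc (FakeKey below is hypothetical). The defaultdict-style materialization on missing keys is what lets the State class below probe maps such as value_to_valuegrad without inserting first:

# Illustration only: any hashable type with is_same() works as a key.
class FakeKey:
    def __init__(self, impl):
        self._impl = impl

    def __hash__(self):
        return hash(self._impl)

    def is_same(self, other):
        return self._impl is other._impl

impl = object()
k1, k2 = FakeKey(impl), FakeKey(impl)  # distinct wrappers, same underlying value

d = ValueDict(default_factory=list)
d[k1].append("dout")      # missing key materializes [] (defaultdict-style)
assert k2 in d            # membership uses hash() plus is_same(), not __eq__
assert d[k2] == ["dout"]  # k2 reaches the entry stored under k1
assert len(d) == 1

Note the trade-off: lookups scan _items linearly, so they are O(n) rather than O(1); the hash(key) == hash(other_key) guard short-circuits most is_same() calls, which keeps the scan cheap in practice.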
@@ -25,21 +95,21 @@ class State:
     def __init__(self, program):
         self.program = program
         # opresult -> list(list(opresult))
-        self.value_to_valuegrad = collections.defaultdict(list)
-        self.value_to_sumvaluegrad = collections.defaultdict(list)
+        self.value_to_valuegrad = ValueDict(default_factory=list)
+        self.value_to_sumvaluegrad = ValueDict(default_factory=list)
         # operation -> list(operation)
-        self.op_to_opgrad = collections.defaultdict(list)
+        self.op_to_opgrad = ValueDict(default_factory=list)
 
         # opresult -> list(opresult)
-        self.valuegrad_to_value = collections.defaultdict(list)
-        self.sumvaluegrad_to_value = collections.defaultdict(list)
+        self.valuegrad_to_value = ValueDict(default_factory=list)
+        self.sumvaluegrad_to_value = ValueDict(default_factory=list)
         # operation -> list(operation)
-        self.opgrad_to_op = collections.defaultdict(list)
+        self.opgrad_to_op = ValueDict(default_factory=list)
 
     def turn_map(self) -> None:
-        self.valuegrad_to_value = collections.defaultdict(list)
-        self.sumvaluegrad_to_value = collections.defaultdict(list)
-        self.opgrad_to_op = collections.defaultdict(list)
+        self.valuegrad_to_value = ValueDict(default_factory=list)
+        self.sumvaluegrad_to_value = ValueDict(default_factory=list)
+        self.opgrad_to_op = ValueDict(default_factory=list)
 
         for k, v in self.value_to_valuegrad.items():
             if v != []:
84 changes: 63 additions & 21 deletions python/paddle/autograd/ir_backward.py
@@ -27,6 +27,47 @@
 __all__ = ['grad', 'calc_gradient', 'calc_gradient_helper']
 
 
+class ValueSet:
+    def __init__(self, iter=None):
+        self._values: list = []
+        if iter is not None:
+            for val in iter:
+                self.add(val)
+
+    def add(self, other_val):
+        if not self.__contains__(other_val):
+            self._values.append(other_val)
+
+    def update(self, other_set):
+        for val in other_set:
+            self.add(val)
+
+    def __and__(self, other_set):
+        ret = ValueSet()
+        for val in self._values:
+            if val in other_set:
+                ret.add(val)
+        return ret
+
+    def __or__(self, other_set):
+        return ValueSet(self._values + other_set._values)
+
+    def __bool__(self):
+        return bool(self._values)
+
+    def __len__(self):
+        return len(self._values)
+
+    def __iter__(self):
+        return iter(self._values)
+
+    def __contains__(self, other_val):
+        for value in self._values:
+            if hash(value) == hash(other_val) and value.is_same(other_val):
+                return True
+        return False
+
+
 def check_type(input, input_name, expected_type, op_name, extra_message=''):
     if not isinstance(input, expected_type):
         raise TypeError(
@@ -124,7 +165,7 @@ def prepare_grad_outputs(grad_outputs, outputs, state):
     complete_outputs = outputs
     complete_gradoutputs = grad_outputs
 
-    visited_output = set()
+    visited_output = ValueSet()
     for output in outputs:
         if output in visited_output:
             continue
@@ -157,7 +198,7 @@ def prepare_grad_outputs(grad_outputs, outputs, state):
 
 def some_in_set(value_list, value_set):
     def operand2value(values):
-        value_set = set()
+        value_set = ValueSet()
         for item in values:
             if isinstance(item, paddle.pir.OpOperand):
                 value_set.add(item.source())
@@ -245,7 +286,7 @@ def update_no_grad_set_after_prune(
     from inputs to outputs add value not in the path to no_grad_set,
     from outputs to inputs add value not in the path to no_grad_set,
     '''
-    inputs_set = set(inputs)
+    inputs_set = ValueSet(inputs)
     if inputs_set:
         for op in block.ops:
             if some_in_set(op.operands_source(), inputs_set):
@@ -258,12 +299,12 @@ def update_no_grad_set_after_prune(
                 if value not in inputs_set:
                     no_grad_set.add(value)
 
-    outputs_set = set(outputs)
-    no_grad_set_tmp = set()
+    outputs_set = ValueSet(outputs)
+    no_grad_set_tmp = ValueSet()
     for op in reversed(effective_forward_ops):
         for output in op.results():
             if output not in outputs_set and not some_in_set(
-                [output], set(op.operands_source())
+                [output], ValueSet(op.operands_source())
             ):
                 no_grad_set_tmp.add(output)
 
@@ -317,7 +358,7 @@ def inverse_sort_op(ops):
 
 
 def append_backward_ops(
-    block, effective_forward_ops, no_grad_set, backward_ops, state
+    block, effective_forward_ops, no_grad_set, backward_ops, state: State
 ):
     '''
     add grad_op in order of topological inverse sort
@@ -577,7 +618,7 @@ def update_input_grad_map(op, input_grads):
 
 
 def create_backward_prune_set(inputs, outputs, no_grad_set, state):
-    outputs_set = set()
+    outputs_set = ValueSet()
     for input_ in inputs:
         if not input_.use_empty():
             for item in input_.first_use().owner().operands_source():
@@ -586,18 +627,18 @@ def create_backward_prune_set(inputs, outputs, no_grad_set, state):
         else:
             logging.warning("input privided by inputs has no use")
 
-    inputs_set = set()
+    inputs_set = ValueSet()
     for output in outputs:
         if state.value_to_valuegrad[output] != []:
             inputs_set.add(state.value_to_valuegrad[output][0][0])
-    inputs_set_tmp = set()
+    inputs_set_tmp = ValueSet()
     for out_grad in inputs_set:
         if not out_grad.use_empty():
             for item in out_grad.first_use().owner().operands_source():
                 inputs_set_tmp.add(item)
     inputs_set.update(inputs_set_tmp)
 
-    no_gradvar_set = set()  # grad_value of value in no_grad_set
+    no_gradvar_set = ValueSet()  # grad_value of value in no_grad_set
     for key in state.value_to_valuegrad:
         if key in no_grad_set and state.value_to_valuegrad[key] != []:
             no_gradvar_set.add(state.value_to_valuegrad[key][0][0])
@@ -640,8 +681,8 @@ def calc_gradient_helper(outputs, inputs, grad_outputs, no_grad_set):
         grad_outputs, outputs, state
     )
 
-    inputs_set = set(inputs)
-    outputs_set = set(complete_outputs)
+    inputs_set = ValueSet(inputs)
+    outputs_set = ValueSet(complete_outputs)
     effective_forward_ops, _ = prune_ops(
         block.ops, inputs_set, outputs_set, no_grad_set
     )
@@ -690,7 +731,7 @@ def calc_gradient(outputs, inputs, grad_outputs, no_grad_set):
             be: (1) a Value filled with 1 when the i-th element of `grad_outputs`
             is None; (2) the i-th element of `grad_outputs` when the i-th element of
             `grad_outputs` is a Value. Default None.
-        no_grad_set (set(Value), optional):
+        no_grad_set (list(Value)|tuple(Value), optional):
             the Values whose gradients are not needed to compute. Default None.
 
     Return:
Expand All @@ -701,7 +742,10 @@ def calc_gradient(outputs, inputs, grad_outputs, no_grad_set):
"""
# record input value and its gradient (Value to Value)
input_to_inputgrad_map = calc_gradient_helper(
outputs, inputs, grad_outputs=grad_outputs, no_grad_set=no_grad_set
outputs,
inputs,
grad_outputs=grad_outputs,
no_grad_set=ValueSet(no_grad_set),
)

inputgrad = []
@@ -764,7 +808,7 @@ def grad(
         `inputs` are unreachable in the graph (i.e., their gradients are None),
         error would be raised if allow_unused=False, or None would be returned as
         their gradients if allow_unused=True. Default False.
-        no_grad_vars (Value|list(Value)|tuple(Value)|set(Value), optional):
+        no_grad_vars (Value|list(Value)|tuple(Value), optional):
             the Values whose gradients are not needed to compute. Default None.
 
     Returns:
@@ -794,18 +838,16 @@ def grad(
     check_type(
         no_grad_vars,
         'no_grad_vars',
-        ((paddle.pir.Value, paddle.pir.OpResult), list, tuple, set, type(None)),
+        ((paddle.pir.Value, paddle.pir.OpResult), list, tuple, type(None)),
         'paddle.autograd.ir_backward.grad',
     )
     outputs = _as_list(outputs)
     inputs = _as_list(inputs)
     grad_outputs = _as_list(grad_outputs)
     if no_grad_vars is None:
-        no_grad_set = set()
-    elif no_grad_vars is not set:
-        no_grad_set = set(no_grad_vars)
+        no_grad_set = ValueSet()
     else:
-        no_grad_set = no_grad_vars
+        no_grad_set = ValueSet(no_grad_vars)
 
     input_grad = calc_gradient(outputs, inputs, grad_outputs, no_grad_set)
 
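A companion sketch for ValueSet, mirroring how grad() above now wraps no_grad_vars in ValueSet(no_grad_vars) (FakeValue below is hypothetical; the real elements are pir.Value/pir.OpResult):

# Illustration only: hashable elements with is_same() for identity.
class FakeValue:
    def __init__(self, impl):
        self._impl = impl

    def __hash__(self):
        return hash(self._impl)

    def is_same(self, other):
        return self._impl is other._impl

impl = object()
v1, v2 = FakeValue(impl), FakeValue(impl)  # one underlying value, two wrappers

s = ValueSet([v1])
assert v2 in s   # dedup goes through hash() + is_same(), never __eq__
s.add(v2)        # no-op: v2 is "the same value" as v1
assert len(s) == 1

t = s | ValueSet([FakeValue(object())])
assert len(t) == 2  # union keeps distinct underlying values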
7 changes: 1 addition & 6 deletions python/paddle/base/backward.py
@@ -2653,7 +2653,6 @@ def gradients(targets, inputs, target_gradients=None, no_grad_set=None):
             (paddle.pir.Value, paddle.pir.OpResult),
             list,
             tuple,
-            set,
             type(None),
         ),
         'paddle.autograd.ir_backward.grad',
@@ -2662,11 +2661,7 @@
     inputs = _as_list(inputs)
     target_gradients = _as_list(target_gradients)
     if no_grad_set is None:
-        no_grad_set = set()
-    elif no_grad_set is not set:
-        no_grad_set = set(no_grad_set)
-    else:
-        no_grad_set = no_grad_set
+        no_grad_set = []
     from paddle.autograd.ir_backward import (
         calc_gradient as pir_calc_gradient,
     )