Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Pir]Fix Value eq error when using set #58896

Merged
merged 53 commits into from
Dec 18, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
53 commits
Select commit Hold shift + click to select a range
6539a97
[Pir]Fix Value eq error when using set
Nov 10, 2023
3899a7d
Fix iter
Nov 10, 2023
7f6c298
Refine code
Nov 10, 2023
0793a38
add ValueDict
zrr1999 Nov 10, 2023
f8abfa3
fix
zrr1999 Nov 13, 2023
6e5947b
update valueset
zrr1999 Nov 13, 2023
09f5eac
update dict
zrr1999 Nov 13, 2023
ad7461d
fix bug
zrr1999 Nov 14, 2023
531621c
improve
zrr1999 Nov 14, 2023
6b252fc
fix
zrr1999 Nov 14, 2023
7eb2685
fix contains
zrr1999 Nov 14, 2023
d7792bd
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
Nov 15, 2023
bc174c9
Fix set(Value) or dict[Value] = xx code
Nov 15, 2023
8a6f0a7
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
Nov 16, 2023
d7b38e8
Fix ut
Nov 16, 2023
963d0f8
Fix double grad ut
Nov 16, 2023
d299a50
Fix ut
Nov 17, 2023
8bae46b
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
Nov 24, 2023
ba7a2cd
Forbid opresult hash
Nov 24, 2023
4a50818
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
Nov 26, 2023
ace1ae1
Fix set and dict
Nov 26, 2023
cd3e2ea
Fix dy2st
Nov 26, 2023
470a2aa
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
Dec 7, 2023
faefbe8
Refine code
Dec 7, 2023
526eecd
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
Dec 7, 2023
78cbe1b
Forbid eq
Dec 7, 2023
9be9129
Fix map and value_eq
Dec 7, 2023
176e09b
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
Dec 8, 2023
ea46d6b
Fix None hash
Dec 8, 2023
dcaf6c4
Fix decomp
Dec 8, 2023
1df44e2
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
Dec 11, 2023
55cbfcc
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
Dec 12, 2023
b8b6416
Refine value set/dict
Dec 12, 2023
16bc3c6
Add hash
Dec 12, 2023
3b0a7cd
fix clone program, return a assciated array.
2742195759 Dec 12, 2023
1a7767d
Merge commit 'refs/pull/58896/head' of https://github.com/PaddlePaddl…
2742195759 Dec 12, 2023
7ab860c
Fix iter
Dec 12, 2023
31686de
Fix code
Dec 13, 2023
294abe4
Fix backward
Dec 13, 2023
b168af9
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
Dec 13, 2023
96d8cf3
Fix decomp and add copy()
Dec 13, 2023
e3637da
Format code
Dec 13, 2023
ef65985
Fix ut
Dec 14, 2023
7e6422c
Fix layer params set
Dec 14, 2023
a43cd6e
Fix prim op test
Dec 14, 2023
cf1ce82
Fix named_parameters
Dec 14, 2023
88b3e9f
Support value __eq__
Dec 15, 2023
4ec0f71
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
Dec 15, 2023
2611c9f
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
Dec 15, 2023
7b8cf7e
Fix test cond ==
Dec 15, 2023
6af0553
Add ut for value set/dict
Dec 18, 2023
b8953f2
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
Dec 18, 2023
4464870
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
Dec 18, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
improve
  • Loading branch information
zrr1999 committed Nov 14, 2023
commit 531621c99a2cea6aebb97071ce7e05e812f0fa1d
59 changes: 53 additions & 6 deletions python/paddle/autograd/backward_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,38 +14,41 @@
from __future__ import annotations

import collections
from collections.abc import Sequence
from typing import Any


class ValueInDict:
class ValueWrapper:
def __init__(self, value) -> None:
if isinstance(value, ValueWrapper):
value = value.value
self.value = value

def __hash__(self) -> int:
return hash(self.value)

def __eq__(self, other) -> bool:
if isinstance(other, ValueInDict):
if isinstance(other, ValueWrapper):
other = other.value
return self.value.is_same(other)


class ValueDict:
def __init__(
self,
iter: dict[ValueInDict, Any] | None = None,
iter: dict[ValueWrapper, Any] | None = None,
*,
default_factory=None,
):
self._items: dict[ValueInDict, Any] = {}
self._items: dict[ValueWrapper, Any] = {}
self._default_factory = default_factory
if iter is not None:
for key, val in iter.items():
self[key] = val

def update(self, other_dict):
for key, val in other_dict:
self[ValueInDict(key)] = val
self[ValueWrapper(key)] = val

def keys(self):
for key in self._items.keys():
Expand All @@ -59,6 +62,8 @@ def items(self):
yield key.value, val

def __setitem__(self, other_key, other_val: Any):
if not isinstance(other_key, ValueWrapper):
other_key = ValueWrapper(other_key)
self._items[other_key] = other_val

def __getitem__(self, other_key):
Expand Down Expand Up @@ -89,7 +94,49 @@ def __iter__(self):
return self.keys()

def __contains__(self, other_key):
return ValueInDict(other_key) in self._items
return other_key in self._items


class ValueSet:
def __init__(
self, iter: Sequence[ValueWrapper] | set[ValueWrapper] | None = None
):
self._values: set[ValueWrapper] = set()
if iter is not None:
for val in iter:
self.add(val)

def add(self, other_val):
other_val = ValueWrapper(other_val)
if not self.__contains__(other_val):
self._values.add(other_val)

def update(self, other_set: set):
for val in other_set:
self.add(ValueWrapper(val))

def __and__(self, other_set: ValueSet):
ret = ValueSet()
for val in self._values:
if val in other_set:
ret.add(val)
return ret

def __or__(self, other_set: ValueSet):
return ValueSet(self._values | other_set._values)

def __bool__(self):
return bool(self._values)

def __len__(self):
return len(self._values)

def __iter__(self):
for val in self._values:
yield val.value

def __contains__(self, other_val):
return other_val in self._values


class State:
Expand Down
57 changes: 1 addition & 56 deletions python/paddle/autograd/ir_backward.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from collections.abc import Sequence

import paddle.pir
from paddle.autograd.backward_utils import State
from paddle.autograd.backward_utils import State, ValueSet

"""
grad: for templete test, will combine in paddle.grad .
Expand All @@ -29,61 +29,6 @@
__all__ = ['grad', 'calc_gradient', 'calc_gradient_helper']


class ValueInSet:
def __init__(self, value) -> None:
self.value = value

def __hash__(self) -> int:
return hash(self.value)

def __eq__(self, other) -> bool:
if isinstance(other, ValueInSet):
other = other.value
return self.value.is_same(other)


class ValueSet:
def __init__(
self, iter: Sequence[ValueInSet] | set[ValueInSet] | None = None
):
self._values: set[ValueInSet] = set()
if iter is not None:
for val in iter:
self.add(val)

def add(self, other_val):
other_val = ValueInSet(other_val)
if not self.__contains__(other_val):
self._values.add(other_val)

def update(self, other_set: set):
for val in other_set:
self.add(ValueInSet(val))

def __and__(self, other_set: ValueSet):
ret = ValueSet()
for val in self._values:
if val in other_set:
ret.add(val)
return ret

def __or__(self, other_set: ValueSet):
return ValueSet(self._values | other_set._values)

def __bool__(self):
return bool(self._values)

def __len__(self):
return len(self._values)

def __iter__(self):
for val in self._values:
yield val.value

def __contains__(self, other_val):
return ValueInSet(other_val) in self._values


def check_type(input, input_name, expected_type, op_name, extra_message=''):
if not isinstance(input, expected_type):
raise TypeError(
Expand Down