Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Pir]Fix Value eq error when using set #58896

Merged
merged 53 commits into from
Dec 18, 2023
Merged
Changes from 1 commit
Commits
Show all changes
53 commits
Select commit Hold shift + click to select a range
6539a97
[Pir]Fix Value eq error when using set
Nov 10, 2023
3899a7d
Fix iter
Nov 10, 2023
7f6c298
Refine code
Nov 10, 2023
0793a38
add ValueDict
zrr1999 Nov 10, 2023
f8abfa3
fix
zrr1999 Nov 13, 2023
6e5947b
update valueset
zrr1999 Nov 13, 2023
09f5eac
update dict
zrr1999 Nov 13, 2023
ad7461d
fix bug
zrr1999 Nov 14, 2023
531621c
improve
zrr1999 Nov 14, 2023
6b252fc
fix
zrr1999 Nov 14, 2023
7eb2685
fix contains
zrr1999 Nov 14, 2023
d7792bd
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
Nov 15, 2023
bc174c9
Fix set(Value) or dict[Value] = xx code
Nov 15, 2023
8a6f0a7
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
Nov 16, 2023
d7b38e8
Fix ut
Nov 16, 2023
963d0f8
Fix double grad ut
Nov 16, 2023
d299a50
Fix ut
Nov 17, 2023
8bae46b
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
Nov 24, 2023
ba7a2cd
Forbid opresult hash
Nov 24, 2023
4a50818
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
Nov 26, 2023
ace1ae1
Fix set and dict
Nov 26, 2023
cd3e2ea
Fix dy2st
Nov 26, 2023
470a2aa
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
Dec 7, 2023
faefbe8
Refine code
Dec 7, 2023
526eecd
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
Dec 7, 2023
78cbe1b
Forbid eq
Dec 7, 2023
9be9129
Fix map and value_eq
Dec 7, 2023
176e09b
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
Dec 8, 2023
ea46d6b
Fix None hash
Dec 8, 2023
dcaf6c4
Fix decomp
Dec 8, 2023
1df44e2
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
Dec 11, 2023
55cbfcc
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
Dec 12, 2023
b8b6416
Refine value set/dict
Dec 12, 2023
16bc3c6
Add hash
Dec 12, 2023
3b0a7cd
fix clone program, return a assciated array.
2742195759 Dec 12, 2023
1a7767d
Merge commit 'refs/pull/58896/head' of https://github.com/PaddlePaddl…
2742195759 Dec 12, 2023
7ab860c
Fix iter
Dec 12, 2023
31686de
Fix code
Dec 13, 2023
294abe4
Fix backward
Dec 13, 2023
b168af9
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
Dec 13, 2023
96d8cf3
Fix decomp and add copy()
Dec 13, 2023
e3637da
Format code
Dec 13, 2023
ef65985
Fix ut
Dec 14, 2023
7e6422c
Fix layer params set
Dec 14, 2023
a43cd6e
Fix prim op test
Dec 14, 2023
cf1ce82
Fix named_parameters
Dec 14, 2023
88b3e9f
Support value __eq__
Dec 15, 2023
4ec0f71
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
Dec 15, 2023
2611c9f
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
Dec 15, 2023
7b8cf7e
Fix test cond ==
Dec 15, 2023
6af0553
Add ut for value set/dict
Dec 18, 2023
b8953f2
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
Dec 18, 2023
4464870
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
Dec 18, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Refine value set/dict
  • Loading branch information
0x45f committed Dec 12, 2023
commit b8b64164387ae6ae6390ab07b25a629450553dcd
86 changes: 42 additions & 44 deletions python/paddle/autograd/backward_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,61 +25,68 @@

class ValueWrapper:
def __init__(self, value) -> None:
self.value = value.value if isinstance(value, ValueWrapper) else value
if isinstance(value, ValueWrapper):
assert isinstance(value._value, (type(None), pir.Value))
else:
assert isinstance(value, (type(None), pir.Value))
self._value = value._value if isinstance(value, ValueWrapper) else value

def __hash__(self) -> int:
if isinstance(self.value, pir.Value):
return self.value.hash()
if isinstance(self._value, pir.Value):
return self._value.hash()
else:
return hash(self.value)
return hash(self._value)

def __eq__(self, other) -> bool:
tmp_other = other.value if isinstance(other, ValueWrapper) else other
if self.value is None:
return tmp_other is None
return self.value.is_same(tmp_other)
if not isinstance(ValueWrapper):
warnings.warn(
f'In ValueWrapper.__eq__ expected type of `other` is ValueWrapper but received {other.__class__}.'
)
return False

if self._value is None or other._value is None:
return self._value is None and other._value is None
return self._value.is_same(other._value)


class ValueDict:
def __init__(
self,
iter: dict[ValueWrapper, Any] | None = None,
iter,
*,
default_factory=None,
):
self._items: dict[ValueWrapper, Any] = {}
self._items: dict[ValueWrapper] = {}
self._default_factory = default_factory
if iter is not None:
for key, val in iter.items():
self[key] = val

def update(self, other_dict):
for key, val in other_dict:
self[ValueWrapper(key)] = val
self[key] = val

def keys(self):
for key in self._items.keys():
yield key.value
yield key._value

def values(self):
return self._items.values()

def items(self):
for key, val in self._items.items():
yield key.value, val
yield key._value, val

def __setitem__(self, key, val: Any):
tmp_key = key if isinstance(key, ValueWrapper) else ValueWrapper(key)
self._items[tmp_key] = val
self._items[ValueWrapper(key)] = val

def __getitem__(self, key):
tmp_key = key if isinstance(key, ValueWrapper) else ValueWrapper(key)
if not self.__contains__(tmp_key):
if not self.__contains__(key):
if self._default_factory is not None:
self[tmp_key] = self._default_factory()
self[key] = self._default_factory()
else:
raise KeyError(f'{key} not in ValueDict')
return self._items[tmp_key]
raise KeyError(f'{key} is not in ValueDict')
return self._items[ValueWrapper(key)]

def __bool__(self):
return bool(self._items)
Expand All @@ -91,53 +98,44 @@ def __iter__(self):
return self.keys()

def __contains__(self, key):
if isinstance(key, ValueWrapper):
return key in self._items
return ValueWrapper(key) in self._items


class ValueSet:
def __init__(
self, iter: Sequence[ValueWrapper] | set[ValueWrapper] | None = None
):
self._values: set[ValueWrapper] = set()
self._set: set[ValueWrapper] = set()
if iter is not None:
for val in iter:
self.add(val)

def add(self, val):
tmp_val = ValueWrapper(val)
if not self.__contains__(tmp_val):
self._values.add(tmp_val)
if not self.__contains__(val):
self._set.add(ValueWrapper(val))

def update(self, other_set: set):
for val in other_set:
self.add(ValueWrapper(val))
def update(self, other: set):
for val in other:
self.add(val)

def __and__(self, other_set: ValueSet):
ret = ValueSet()
for val in self._values:
if val in other_set:
ret.add(val)
return ret
def __and__(self, other: ValueSet):
return ValueSet(self._set & other._set)

def __or__(self, other_set: ValueSet):
return ValueSet(self._values | other_set._values)
def __or__(self, other: ValueSet):
return ValueSet(self._set | other._set)

def __bool__(self):
return bool(self._values)
return bool(self._set)

def __len__(self):
return len(self._values)
return len(self._set)

def __iter__(self):
for val in self._values:
yield val.value
for val in self._set:
yield val._value

def __contains__(self, val):
if isinstance(val, ValueWrapper):
return val in self._values
return ValueWrapper(val) in self._values
return ValueWrapper(val) in self._set


class State:
Expand Down