Skip to content

Commit

Permalink
[Pir]Fix Value eq error when using set (PaddlePaddle#58896)
Browse files Browse the repository at this point in the history
  • Loading branch information
0x45f authored and HermitSun committed Dec 21, 2023
1 parent fcc7597 commit 541e970
Show file tree
Hide file tree
Showing 17 changed files with 345 additions and 97 deletions.
17 changes: 10 additions & 7 deletions paddle/fluid/pybind/pir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -761,8 +761,8 @@ void BindValue(py::module *m) {
[](Value self) {
return paddle::dialect::scale(self, -1.0, 0.0, true);
})
.def("__eq__", &Value::operator==)
.def("__hash__", [](Value self) { return std::hash<pir::Value>{}(self); })
.def("is_same", &Value::operator==)
.def("hash", [](Value self) { return std::hash<pir::Value>{}(self); })
.def("__repr__", &Value2String);
// For basaic operators
OVERRIDE_OPERATOR_FOR_EACH(__add__, add, 1.0, other, true);
Expand Down Expand Up @@ -1029,7 +1029,8 @@ static auto GetNoNeedBufferValue(const ::pir::Block *whole_block,
no_need_buffer_values.end());
}

using OpResultMap = std::unordered_map<pir::OpResult, pir::OpResult>;
using OpResultMap =
std::pair<std::vector<pir::OpResult>, std::vector<pir::OpResult>>;
std::pair<std::shared_ptr<Program>, OpResultMap> CloneProgram(
const Program &program) {
// Limitation of this function:
Expand All @@ -1042,12 +1043,14 @@ std::pair<std::shared_ptr<Program>, OpResultMap> CloneProgram(
auto *cloned_op = BuildOpFrom(&op, value_map);
cloned_program->block()->push_back(cloned_op);
}
std::unordered_map<pir::OpResult, pir::OpResult> op_result_map;
std::vector<pir::OpResult> associated_array_key, associated_array_value;
for (auto &pair : value_map) {
op_result_map[pair.first.dyn_cast<pir::OpResult>()] =
pair.second.dyn_cast<pir::OpResult>();
associated_array_key.push_back(pair.first.dyn_cast<pir::OpResult>());
associated_array_value.push_back(pair.second.dyn_cast<pir::OpResult>());
}
return std::make_pair(cloned_program, op_result_map);
return std::make_pair(
cloned_program,
std::make_pair(associated_array_key, associated_array_value));
}

void AppendSetParameter(Program *forward_program,
Expand Down
149 changes: 142 additions & 7 deletions python/paddle/autograd/backward_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,149 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

import collections
import warnings
from collections.abc import Sequence
from typing import Any

from paddle import pir
from paddle.base import core
from paddle.base.wrapped_decorator import signature_safe_contextmanager


class ValueWrapper:
def __init__(self, value) -> None:
if isinstance(value, ValueWrapper):
assert isinstance(value._value, (type(None), pir.Value))
else:
assert isinstance(value, (type(None), pir.Value))
self._value = value._value if isinstance(value, ValueWrapper) else value

def __hash__(self) -> int:
if isinstance(self._value, pir.Value):
return self._value.hash()
else:
return hash(self._value)

def __eq__(self, other) -> bool:
if not isinstance(other, ValueWrapper):
warnings.warn(
f'In ValueWrapper.__eq__ expected type of `other` is ValueWrapper but received {other.__class__}.'
)
return False

if self._value is None or other._value is None:
return self._value is None and other._value is None
return self._value.is_same(other._value)


class ValueDict:
def __init__(
self,
iter=None,
*,
default_factory=None,
):
self._items: dict[ValueWrapper] = {}
self._default_factory = default_factory
if iter is not None:
for key, val in iter.items():
self[key] = val

def copy(self):
ret = ValueDict()
ret._items = self._items.copy()
ret._default_factory = self._default_factory
return ret

def update(self, other_dict):
for key, val in other_dict.items():
self[key] = val

def keys(self):
for key in self._items.keys():
yield key._value

def values(self):
return self._items.values()

def items(self):
for key, val in self._items.items():
yield key._value, val

def pop(self, key):
if not self.__contains__(key):
raise KeyError(f'{key} is not in ValueDict')
return self._items.pop(ValueWrapper(key))

def __setitem__(self, key, val: Any):
self._items[ValueWrapper(key)] = val

def __getitem__(self, key):
if not self.__contains__(key):
if self._default_factory is not None:
self[key] = self._default_factory()
else:
raise KeyError(f'{key} is not in ValueDict')
return self._items[ValueWrapper(key)]

def __bool__(self):
return bool(self._items)

def __len__(self):
return len(self._items)

def __iter__(self):
return self.keys()

def __contains__(self, key):
return ValueWrapper(key) in self._items


class ValueSet:
def __init__(
self, iter: Sequence[ValueWrapper] | set[ValueWrapper] | None = None
):
self._set: set[ValueWrapper] = set()
if iter is not None:
for val in iter:
self.add(val)

def copy(self):
ret = ValueSet()
ret._set = self._set.copy()
return ret

def add(self, val):
if not self.__contains__(val):
self._set.add(ValueWrapper(val))

def update(self, other: set):
for val in other:
self.add(val)

def __and__(self, other: ValueSet):
return ValueSet(self._set & other._set)

def __or__(self, other: ValueSet):
return ValueSet(self._set | other._set)

def __bool__(self):
return bool(self._set)

def __len__(self):
return len(self._set)

def __iter__(self):
for val in self._set:
yield val._value

def __contains__(self, val):
return ValueWrapper(val) in self._set


class State:
"""
record relationship of forward op/value and backward op/value
Expand All @@ -30,24 +165,24 @@ class State:
def __init__(self, block):
self.block = block
# opresult -> list(list(opresult))
self.value_to_valuegrad = collections.defaultdict(list)
self.value_to_sumvaluegrad = collections.defaultdict(list)
self.value_to_valuegrad = ValueDict(default_factory=list)
self.value_to_sumvaluegrad = ValueDict(default_factory=list)
# operation -> list(operation)
self.op_to_opgrad = collections.defaultdict(list)

# opresult -> list(opresult)
self.valuegrad_to_value = collections.defaultdict(list)
self.sumvaluegrad_to_value = collections.defaultdict(list)
self.valuegrad_to_value = ValueDict(default_factory=list)
self.sumvaluegrad_to_value = ValueDict(default_factory=list)
# operation -> list(operation)
self.opgrad_to_op = collections.defaultdict(list)
# only for controlflow
# inside_value is sub block value, which will yield to parent block,
# parant block value is outside_value
self.inside_value_to_outside_value_map = {}
self.inside_value_to_outside_value_map = ValueDict()

def turn_map(self) -> None:
self.valuegrad_to_value = collections.defaultdict(list)
self.sumvaluegrad_to_value = collections.defaultdict(list)
self.valuegrad_to_value = ValueDict(default_factory=list)
self.sumvaluegrad_to_value = ValueDict(default_factory=list)
self.opgrad_to_op = collections.defaultdict(list)

for k, v in self.value_to_valuegrad.items():
Expand Down
Loading

0 comments on commit 541e970

Please sign in to comment.