rm dict in module apply #9137

Merged · 3 commits · Sep 23, 2022
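In short, this PR removes the `applied_dict` parameter that `Module._apply` previously threaded through its recursion to remember tensors it had already converted. The same memo now travels through a temporary instance attribute, `_oneflow_internal_module_tensor_applied_dict__`, so the method keeps the plain `_apply(fn)` signature while tensors shared across submodules are still converted only once. Below is a minimal sketch of the pattern, assuming a hypothetical `MiniModule` class; it is not the real `oneflow.nn.Module`, which additionally handles gradients and assign-copy.

```python
class MiniModule:
    def __init__(self):
        self._children = []          # submodules
        self._tensors = {}           # name -> tensor-like object
        self._applied_memo__ = None  # stand-in for the internal dict

    def _apply(self, fn):
        # The root call creates the memo; recursive calls inherit it
        # from the parent via the attribute set just before recursion.
        if self._applied_memo__ is None:
            self._applied_memo__ = {}

        for child in self._children:
            child._applied_memo__ = self._applied_memo__  # share downward
            child._apply(fn)
            child._applied_memo__ = None  # detach once the child is done

        # Apply fn exactly once per unique object. The real code keys the
        # dict on the tensors themselves; id() is used here so the sketch
        # works with arbitrary stand-in objects.
        for name, t in self._tensors.items():
            if id(t) not in self._applied_memo__:
                self._applied_memo__[id(t)] = fn(t)
            self._tensors[name] = self._applied_memo__[id(t)]

        self._applied_memo__ = None  # memo lives for one traversal only
        return self
```

The parent shares the memo attribute with each child just before the recursive call and resets it right after, so the dict exists only for the duration of one top-level traversal and never lingers in module state.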
43 changes: 31 additions & 12 deletions python/oneflow/nn/module.py
@@ -101,6 +101,7 @@ def __init__(self):
         self._load_state_dict_pre_hooks = OrderedDict()
         self._modules = OrderedDict()
         self._is_ddp_module = False
+        self._oneflow_internal_module_tensor_applied_dict__ = None

     def __getstate__(self):
         if not self._is_ddp_module:
@@ -968,14 +969,18 @@ def register_forward_hook(self, hook: Callable[..., None]) -> None:
         """
         self._forward_hooks[len(self._forward_hooks)] = hook

-    def _apply(self, fn, applied_dict=None):
+    def _apply(self, fn):
         # A dict to store tensors that has already been applied.
         # There is no need to apply multiple times on a same tensor.
-        if applied_dict is None:
-            applied_dict = dict()
+        if self._oneflow_internal_module_tensor_applied_dict__ is None:
+            self._oneflow_internal_module_tensor_applied_dict__ = dict()

         for module in self.children():
-            module._apply(fn, applied_dict)
+            module._oneflow_internal_module_tensor_applied_dict__ = (
+                self._oneflow_internal_module_tensor_applied_dict__
+            )
+            module._apply(fn)
+            module._oneflow_internal_module_tensor_applied_dict__ = None

         def can_use_assign_copy(tensor, tensor_applied):
             return tensor.is_local == tensor_applied.is_local
@@ -985,7 +990,7 @@ def can_use_assign_copy(tensor, tensor_applied):
                 continue

             need_apply = False
-            if param not in applied_dict:
+            if param not in self._oneflow_internal_module_tensor_applied_dict__:
                 need_apply = True
                 assert isinstance(param, Parameter)
                 assert param.is_leaf
@@ -1000,31 +1005,45 @@ def can_use_assign_copy(tensor, tensor_applied):
                     grad_applied.requires_grad = param.grad.requires_grad
                     param_applied.grad = grad_applied
             else:
-                param_applied = applied_dict[param]
+                param_applied = self._oneflow_internal_module_tensor_applied_dict__[
+                    param
+                ]

             if can_use_assign_copy(param_applied, param):
                 if need_apply:
                     self._parameters[key].data = param_applied
-                    applied_dict[param] = param_applied
+                    self._oneflow_internal_module_tensor_applied_dict__[
+                        param
+                    ] = param_applied
                 else:
                     # The parameter's data has already been set when it can use assign copy.
                     pass
             else:
                 if need_apply:
                     new_param = Parameter(param_applied, param.requires_grad)
                     self._parameters[key] = new_param
-                    applied_dict[param] = new_param
+                    self._oneflow_internal_module_tensor_applied_dict__[
+                        param
+                    ] = new_param
                 else:
-                    self._parameters[key] = applied_dict[param]
+                    self._parameters[
+                        key
+                    ] = self._oneflow_internal_module_tensor_applied_dict__[param]

         for (key, buf) in self._buffers.items():
             if buf is not None:
-                if buf not in applied_dict:
+                if buf not in self._oneflow_internal_module_tensor_applied_dict__:
                     buf_applied = fn(buf)
                     self._buffers[key] = buf_applied
-                    applied_dict[buf] = buf_applied
+                    self._oneflow_internal_module_tensor_applied_dict__[
+                        buf
+                    ] = buf_applied
                 else:
-                    self._buffers[key] = applied_dict[buf]
+                    self._buffers[
+                        key
+                    ] = self._oneflow_internal_module_tensor_applied_dict__[buf]

+        self._oneflow_internal_module_tensor_applied_dict__ = None
         return self

     def apply(self: T, fn: Callable[["Module"], None]) -> T:
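For context, a hypothetical use of the `MiniModule` sketch above, showing why the memo matters: when two modules are tied to the same tensor object, the conversion function must run only once, and both modules must end up referencing the same converted result. This mirrors what `Module._apply` guarantees for parameters and buffers shared between submodules (e.g. under `module.to(...)`-style conversions).

```python
calls = []

def fn(t):
    # Stand-in "conversion": record the call and return a new object.
    calls.append(t)
    return list(t)

shared = (1.0, 2.0)  # stand-in for a parameter tied between two modules

a, b = MiniModule(), MiniModule()
a._tensors["w"] = shared
b._tensors["w"] = shared  # the very same object in both modules

root = MiniModule()
root._children = [a, b]
root._apply(fn)

assert len(calls) == 1  # fn ran once despite two references
assert a._tensors["w"] is b._tensors["w"]  # both see the same result
```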