Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

【PPSCI Doc No.12、13、14、15、16、17】ppsci.arch.Arch #752

Merged
merged 6 commits into from
Jan 16, 2024
Merged
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
102 changes: 100 additions & 2 deletions ppsci/arch/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,23 @@ def concat_to_tensor(

Returns:
Tuple[paddle.Tensor, ...]: Concatenated tensor.

Examples:
>>> import paddle
>>> import ppsci
>>> model = ppsci.arch.Arch()
>>> # fetch one tensor
>>> out = model.concat_to_tensor({'x':paddle.to_tensor(123)}, ('x',))
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  1. 这里的输入建议使用2D+float32 Tensor,比如[8,1]这种,0 DTensor不太有代表性
  2. 构造输入用paddle.randn(shape)更好

>>> print(out.dtype, out.shape)
paddle.int64 []
>>> # fetch more tensors
>>> out = model.concat_to_tensor(
... {'x1':paddle.to_tensor([[1, 2], [2, 3]]), 'x2':paddle.to_tensor([[3, 4], [4, 5]])},
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

同上

... ('x1', 'x2'),
... axis=0)
>>> print(out.dtype, out.shape)
paddle.int64 [4, 2]

"""
if len(keys) == 1:
return data_dict[keys[0]]
Expand All @@ -90,6 +107,23 @@ def split_to_dict(

Returns:
Dict[str, paddle.Tensor]: Dict contains tensor.

Examples:
>>> import paddle
>>> import ppsci
>>> model = ppsci.arch.Arch()
>>> # split one tensor
>>> out = model.split_to_dict(paddle.to_tensor(123), ('x',))
>>> for k, v in out.items():
... print(f"{k} {v.dtype} {v.shape}")
x paddle.int64 []
>>> # split more tensors
>>> out = model.split_to_dict(paddle.to_tensor([[1, 2], [2, 3]]), ('x1', 'x2'), axis=0)
>>> for k, v in out.items():
... print(f"{k} {v.dtype} {v.shape}")
x1 paddle.int64 [1, 2]
x2 paddle.int64 [1, 2]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

同上,类型和形状可以改成二维+flaot32


"""
if len(keys) == 1:
return {keys[0]: data_tensor}
Expand All @@ -105,6 +139,27 @@ def register_input_transform(
Args:
transform (Callable[[Dict[str, paddle.Tensor]], Dict[str, paddle.Tensor]]):
Input transform of network, receive a single tensor dict and return a single tensor dict.

Examples:
>>> import ppsci
>>> def transform_in(in_):
... x = in_["x"]
... # transform input
... x_ = 2.0 * x
... input_trans = {"2x": x_}
... return input_trans
>>> # `MLP` inherits from `Arch`
>>> model = ppsci.arch.MLP(
... input_keys=("2x",),
... output_keys=("y",),
... num_layers=5,
... hidden_size=32)
>>> model.register_input_transform(transform_in)
>>> out = model({"x":paddle.rand([64, 64, 1])})
>>> for k, v in out.items():
... print(f"{k} {v.dtype} {v.shape}")
y paddle.float32 [64, 64, 1]

"""
self._input_transform = transform

Expand All @@ -121,18 +176,61 @@ def register_output_transform(
transform (Callable[[Dict[str, paddle.Tensor], Dict[str, paddle.Tensor]], Dict[str, paddle.Tensor]]):
Output transform of network, receive two single tensor dict(raw input
and raw output) and return a single tensor dict(transformed output).

Examples:
>>> import ppsci
>>> def transform_out(in_, out):
... x = in_["x"]
... y = out["y"]
... u = 2.0 * x * y
... output_trans = {"u": u}
... return output_trans
>>> # `MLP` inherits from `Arch`
>>> model = ppsci.arch.MLP(
... input_keys=("x",),
... output_keys=("y",),
... num_layers=5,
... hidden_size=32)
>>> model.register_output_transform(transform_out)
>>> out = model({"x":paddle.rand([64, 64, 1])})
>>> for k, v in out.items():
... print(f"{k} {v.dtype} {v.shape}")
u paddle.float32 [64, 64, 1]

"""
self._output_transform = transform

def freeze(self):
"""Freeze all parameters."""
"""Freeze all parameters.

Examples:
>>> import ppsci
>>> model = ppsci.arch.Arch()
>>> # freeze all parameters and make model `eval`
>>> model.freeze()
Comment on lines +206 to +210
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

末尾加上如下验证代码:

assert p.training is False

for p in model.parameters():
    assert p.stop_gradient is True

>>> assert not model.training
>>> for p in model.parameters():
... assert p.stop_gradient

"""
for param in self.parameters():
param.stop_gradient = True

self.eval()

def unfreeze(self):
"""Unfreeze all parameters."""
"""Unfreeze all parameters.

Examples:
>>> import ppsci
>>> model = ppsci.arch.Arch()
>>> # unfreeze all parameters and make model `train`
>>> model.unfreeze()
>>> assert model.training
>>> for p in model.parameters():
... assert not p.stop_gradient

"""
for param in self.parameters():
param.stop_gradient = False

Expand Down