Skip to content

Commit

Permalink
[Fix] Fix fuse derivative corner case and add UT (PaddlePaddle#748)
Browse files Browse the repository at this point in the history
* fix corner case where FusedDerivativeNode is placed at the end of nodes group

* add UT for fuse derivatives

* fix bug
  • Loading branch information
HydrogenSulfate authored Jan 11, 2024
1 parent 8063a77 commit 4c8a961
Show file tree
Hide file tree
Showing 4 changed files with 92 additions and 13 deletions.
12 changes: 9 additions & 3 deletions ppsci/autodiff/ad.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,12 +134,17 @@ def __call__(
self.Js[key] = _Jacobian(ys, xs)
return self.Js[key](i, j, retain_graph, create_graph)
else:
grads = paddle.grad(
xs_require: List["paddle.Tensor"] = [
xs[i] for i in range(len(xs)) if (ys, xs[i]) not in self.Js
]
grads_require = paddle.grad(
ys,
xs,
xs_require,
create_graph=create_graph,
retain_graph=retain_graph,
)

idx = 0
Js_list = []
for k, xs_ in enumerate(xs):
key = (ys, xs_)
Expand All @@ -148,7 +153,8 @@ def __call__(
f"{xs_.shape}"
)
if key not in self.Js:
self.Js[key] = _Jacobian(ys, xs_, {0: grads[k]})
self.Js[key] = _Jacobian(ys, xs_, {0: grads_require[idx]})
idx += 1
Js_list.append(self.Js[key](i, j, retain_graph, create_graph))
return Js_list

Expand Down
3 changes: 2 additions & 1 deletion ppsci/equation/pde/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,9 @@ def __init__(self):

self.detach_keys: Optional[Tuple[str, ...]] = None

@staticmethod
def create_symbols(
self, symbol_str: str
symbol_str: str,
) -> Union[sympy.Symbol, Tuple[sympy.Symbol, ...]]:
"""Create symbols
Expand Down
32 changes: 23 additions & 9 deletions ppsci/utils/symbolic.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,6 +305,7 @@ def __init__(
self.create_graph = create_graph
self.retain_graph = retain_graph
self._apply_func = self._derivate_operator_func
self.merged = False

def forward(self, data_dict: DATA_DICT):
# use cache
Expand Down Expand Up @@ -893,6 +894,9 @@ def _expr_to_callable_nodes(
# skip sdf function since it is always already given in data_dict
if callable_nodes_group[i][j].expr.args[0].name == "sdf":
continue
# skip merged node
if callable_nodes_group[i][j].merged:
continue

candidate_pos = [[i, j]]
for ii in range(len(callable_nodes_group)):
Expand All @@ -904,6 +908,9 @@ def _expr_to_callable_nodes(
# skip same node
if i == ii and j == jj:
continue
# skip merged node
if callable_nodes_group[ii][jj].merged:
continue

# has same function item
if (
Expand Down Expand Up @@ -932,17 +939,24 @@ def _expr_to_callable_nodes(
f" fuse node sequence: {fused_node_seq} at position: ([{gid0}][{nid0}])"
)

# mark merged node
for i, (gid, nid) in enumerate(candidate_pos):
assert isinstance(callable_nodes_group[gid][nid], DerivativeNode)
callable_nodes_group[gid][nid].merged = True

# replace first mergable node with fused node sequence(packed in list)
# then mask the rest merged node to None(except [gid0, nid0])
for i, (gid, nid) in enumerate(candidate_pos):
if i == 0:
callable_nodes_group[gid0][nid0] = fused_node_seq
else:
# keep the end node of each group to avoid generating empty callable
# node sequence, this will not effect performance since cache strategy
# in Node.forward
if nid != len(callable_nodes_group[gid]) - 1:
callable_nodes_group[gid][nid] = None
for i, (gid, nid) in enumerate(candidate_pos[1:]):
# keep the end node of each group to avoid generating empty callable
# node sequence, this will not effect performance since cache strategy
# in Node.forward
if nid != len(callable_nodes_group[gid]) - 1:
callable_nodes_group[gid][nid] = None

if nid0 == len(callable_nodes_group[gid0]) - 1:
callable_nodes_group[gid0].insert(nid0, fused_node_seq)
else:
callable_nodes_group[gid0][nid0] = fused_node_seq

# re-organize callable_nodes_group, remove None element and unpack list
for i in range(len(callable_nodes_group)):
Expand Down
58 changes: 58 additions & 0 deletions test/utils/test_symbolic.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,5 +88,63 @@ def test_multi_model_and_sdf():
assert np.allclose(out_var_tensor.numpy(), out_var_reference.numpy())


def test_complicated_symbolic():
    """Fuzz-test derivative fusion in ``ppsci.lambdify``.

    Repeatedly builds batches of random 4th-order mixed derivatives of a
    scalar function f(x, y, z) (backed by an MLP) and checks that evaluating
    them with ``fuse_derivative=True`` produces the same numerical results as
    the unfused reference path.
    """
    x_ten = paddle.randn([32, 1])
    x_ten.stop_gradient = False
    y_ten = paddle.randn([32, 1])
    y_ten.stop_gradient = False
    z_ten = paddle.randn([32, 1])
    z_ten.stop_gradient = False

    input_data = {
        "x": x_ten,
        "y": y_ten,
        "z": z_ten,
    }
    x_sp, y_sp, z_sp = ppsci.equation.PDE.create_symbols("x y z")
    f = sp.Function("f")(x_sp, y_sp, z_sp)
    model_f = ppsci.arch.MLP((x_sp.name, y_sp.name, z_sp.name), (f.name,), 3, 6)

    for _ in range(100):

        def random_derivative(state):
            # Each of the 4 low bits of `state` selects whether the k-th
            # differentiation is taken w.r.t. x (bit set) or y (bit clear),
            # yielding a random 4th-order mixed derivative of f.
            ret = f
            for k in range(4):
                if state & (1 << k):
                    ret = ret.diff(x_sp)
                else:
                    ret = ret.diff(y_sp)
            return ret

        targets = [
            random_derivative(int(state))
            for state in np.random.randint(0, 1 << 4, size=4)
        ]
        eqs_fuse = ppsci.lambdify(
            targets,
            model_f,
            fuse_derivative=True,
        )
        # BUG FIX: the reference results must be computed WITHOUT fusion;
        # the original code passed fuse_derivative=True here as well, so the
        # test compared the fused path against itself and could never fail.
        eqs_expected = ppsci.lambdify(
            targets,
            model_f,
            fuse_derivative=False,
        )

        for i in range(len(targets)):
            output_fuse = eqs_fuse[i](input_data)
            output_expected = eqs_expected[i](input_data)
            assert np.allclose(output_fuse.numpy(), output_expected.numpy())
        # Clear the autodiff Jacobian/Hessian caches between iterations so a
        # stale cache entry cannot mask a fusion mismatch.
        ppsci.autodiff.clear()


# Allow running this test module directly (outside a `pytest` CLI invocation).
if __name__ == "__main__":
    pytest.main()

0 comments on commit 4c8a961

Please sign in to comment.