Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Fix] Fix fuse derivative corner case and add UT #748

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 9 additions & 3 deletions ppsci/autodiff/ad.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,12 +134,17 @@ def __call__(
self.Js[key] = _Jacobian(ys, xs)
return self.Js[key](i, j, retain_graph, create_graph)
else:
grads = paddle.grad(
xs_require: List["paddle.Tensor"] = [
xs[i] for i in range(len(xs)) if (ys, xs[i]) not in self.Js
]
grads_require = paddle.grad(
ys,
xs,
xs_require,
create_graph=create_graph,
retain_graph=retain_graph,
)

idx = 0
Js_list = []
for k, xs_ in enumerate(xs):
key = (ys, xs_)
Expand All @@ -148,7 +153,8 @@ def __call__(
f"{xs_.shape}"
)
if key not in self.Js:
self.Js[key] = _Jacobian(ys, xs_, {0: grads[k]})
self.Js[key] = _Jacobian(ys, xs_, {0: grads_require[idx]})
idx += 1
Js_list.append(self.Js[key](i, j, retain_graph, create_graph))
return Js_list

Expand Down
3 changes: 2 additions & 1 deletion ppsci/equation/pde/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,9 @@ def __init__(self):

self.detach_keys: Optional[Tuple[str, ...]] = None

@staticmethod
def create_symbols(
self, symbol_str: str
symbol_str: str,
) -> Union[sympy.Symbol, Tuple[sympy.Symbol, ...]]:
"""Create symbols

Expand Down
32 changes: 23 additions & 9 deletions ppsci/utils/symbolic.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,6 +305,7 @@ def __init__(
self.create_graph = create_graph
self.retain_graph = retain_graph
self._apply_func = self._derivate_operator_func
self.merged = False

def forward(self, data_dict: DATA_DICT):
# use cache
Expand Down Expand Up @@ -893,6 +894,9 @@ def _expr_to_callable_nodes(
# skip sdf function since it is always already given in data_dict
if callable_nodes_group[i][j].expr.args[0].name == "sdf":
continue
# skip merged node
if callable_nodes_group[i][j].merged:
continue

candidate_pos = [[i, j]]
for ii in range(len(callable_nodes_group)):
Expand All @@ -904,6 +908,9 @@ def _expr_to_callable_nodes(
# skip same node
if i == ii and j == jj:
continue
# skip merged node
if callable_nodes_group[ii][jj].merged:
continue

# has same function item
if (
Expand Down Expand Up @@ -932,17 +939,24 @@ def _expr_to_callable_nodes(
f" fuse node sequence: {fused_node_seq} at position: ([{gid0}][{nid0}])"
)

# mark merged node
for i, (gid, nid) in enumerate(candidate_pos):
assert isinstance(callable_nodes_group[gid][nid], DerivativeNode)
callable_nodes_group[gid][nid].merged = True

# replace the first mergeable node with the fused node sequence (packed in a list),
# then mask the remaining merged nodes to None (except [gid0, nid0])
for i, (gid, nid) in enumerate(candidate_pos):
if i == 0:
callable_nodes_group[gid0][nid0] = fused_node_seq
else:
# keep the end node of each group to avoid generating empty callable
# node sequence, this will not effect performance since cache strategy
# in Node.forward
if nid != len(callable_nodes_group[gid]) - 1:
callable_nodes_group[gid][nid] = None
for i, (gid, nid) in enumerate(candidate_pos[1:]):
# keep the end node of each group to avoid generating an empty callable
# node sequence; this will not affect performance thanks to the cache strategy
# in Node.forward
if nid != len(callable_nodes_group[gid]) - 1:
callable_nodes_group[gid][nid] = None

if nid0 == len(callable_nodes_group[gid0]) - 1:
callable_nodes_group[gid0].insert(nid0, fused_node_seq)
else:
callable_nodes_group[gid0][nid0] = fused_node_seq

# re-organize callable_nodes_group, remove None element and unpack list
for i in range(len(callable_nodes_group)):
Expand Down
58 changes: 58 additions & 0 deletions test/utils/test_symbolic.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,5 +88,63 @@ def test_multi_model_and_sdf():
assert np.allclose(out_var_tensor.numpy(), out_var_reference.numpy())


def test_complicated_symbolic():
    """Check that fused derivative evaluation matches the unfused reference.

    For 100 random trials, four random 4th-order mixed derivatives of ``f``
    with respect to ``x`` and ``y`` are built and compiled twice with
    ``ppsci.lambdify``: once with ``fuse_derivative=True`` (the path under
    test) and once with ``fuse_derivative=False`` (the reference). Both sets
    of callables are evaluated on the same inputs and must agree.
    """
    x_ten = paddle.randn([32, 1])
    x_ten.stop_gradient = False
    y_ten = paddle.randn([32, 1])
    y_ten.stop_gradient = False
    z_ten = paddle.randn([32, 1])
    z_ten.stop_gradient = False

    input_data = {
        "x": x_ten,
        "y": y_ten,
        "z": z_ten,
    }
    x_sp, y_sp, z_sp = ppsci.equation.PDE.create_symbols("x y z")
    f = sp.Function("f")(x_sp, y_sp, z_sp)
    model_f = ppsci.arch.MLP((x_sp.name, y_sp.name, z_sp.name), (f.name,), 3, 6)

    num_diff = 4  # order of every generated derivative

    def random_derivative(state):
        """Differentiate ``f`` ``num_diff`` times; bit k of ``state`` picks x (1) or y (0)."""
        ret = f
        for k in range(num_diff):
            if state & (1 << k):
                ret = ret.diff(x_sp)
            else:
                ret = ret.diff(y_sp)
        return ret

    for test_id in range(100):
        # Draw the four states one by one, in the same order as before.
        targets = [
            random_derivative(np.random.randint(0, 1 << num_diff))
            for _ in range(4)
        ]
        eqs_fuse = ppsci.lambdify(
            targets,
            model_f,
            fuse_derivative=True,
        )
        # The reference must NOT use fusion, otherwise the fused path would
        # only ever be compared against itself.
        eqs_expected = ppsci.lambdify(
            targets,
            model_f,
            fuse_derivative=False,
        )
        assert len(eqs_fuse) == len(eqs_expected) == len(targets)

        for i, (eq_fuse, eq_expected) in enumerate(zip(eqs_fuse, eqs_expected)):
            output_fuse = eq_fuse(input_data)
            output_expected = eq_expected(input_data)
            assert np.allclose(
                output_fuse.numpy(), output_expected.numpy()
            ), f"fused result differs from reference (trial {test_id}, target {i})"
        # Drop cached derivatives so trials do not share autodiff state.
        ppsci.autodiff.clear()


# Allow running this test module directly as a script via pytest's entry point.
if __name__ == "__main__":
    pytest.main()