Commit e4a93b0

Committed by 2742195759 (co-authors: Aurelius84, thisjiang, cxxly)
Cxx prim custom vjp (#8)
* [CINN]Enhance CacheKey hash logic by considering input dtypes (PaddlePaddle#50557)

  ---------

  Co-authored-by: jiangcheng <thisjiang@qq.com>

* [prim] enable dygraph_to_static to support custom_vjp

* Pr 50885 (#7)

  * [CINN]Enhance CacheKey hash logic by considering input dtypes (PaddlePaddle#50557)
  * [CINN]Enhance CacheKey hash logic by considering input dtypes

  ---------

  Co-authored-by: jiangcheng <thisjiang@qq.com>

  * [prim] enable dygraph_to_static to support custom_vjp
  * fix code in a dy2static-friendly way.
  * [dystatic] add hooker for prim

  ---------

  Co-authored-by: Aurelius84 <zhangliujie@baidu.com>
  Co-authored-by: jiangcheng <thisjiang@qq.com>
  Co-authored-by: cxxly <chenxx_id@163.com>

* [prim] enable dygraph_to_static to support custom_vjp

* fix cast prim and vjp dtype mapping error bug

* [dy2static-ci] fix dy2static ci errors.

---------

Co-authored-by: Aurelius84 <zhangliujie@baidu.com>
Co-authored-by: jiangcheng <thisjiang@qq.com>
Co-authored-by: cxxly <chenxx_id@163.com>
1 parent 5dda91a commit e4a93b0

File tree

8 files changed (+62 −18 lines)

python/paddle/fluid/framework.py

Lines changed: 0 additions & 1 deletion
@@ -3751,7 +3751,6 @@ def __init__(self, program, idx):
         self.vars = collections.OrderedDict()  # var_name --> var
         self.ops = list()  # operator list
         self.program = program
-        self.removed_vars = collections.OrderedDict()
 
     def __str__(self):
         return self._to_readable_code()

python/paddle/fluid/tests/unittests/dygraph_to_static/test_cinn_prim.py

Lines changed: 12 additions & 2 deletions
@@ -77,7 +77,12 @@ def train(self, use_prim):
     def check_prim(self, net, use_prim):
         if not use_prim:
             return
-        fwd_ops = [op.type for op in net.forward.main_program.block(0).ops]
+        fwd_ops = [
+            op.type
+            for op in net.forward.get_concrete_program(self.x)[1]
+            .train_program.block(0)
+            .ops
+        ]
         # Ensure that softmax is splitted into small ops
         self.assertTrue('softmax' not in fwd_ops)
 
@@ -128,7 +133,12 @@ def train(self, use_prim):
     def check_prim(self, net, use_prim):
         if not use_prim:
             return
-        fwd_ops = [op.type for op in net.forward.main_program.block(0).ops]
+        fwd_ops = [
+            op.type
+            for op in net.forward.get_concrete_program(self.x)[1]
+            .train_program.block(0)
+            .ops
+        ]
         all_ops = [
             op.type
             for op in net.forward.program_cache.last()[-1][-1]
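
The same op-inspection change recurs in the gelu, layer_norm, and mean tests below: instead of reading net.forward.main_program, the tests now pull the cached partial program layer and inspect its train_program. A minimal sketch of the pattern, assuming net.forward has been converted with paddle.jit.to_static (the tests do this through their apply_to_static helper) and x is an input tensor; forward_op_types is a hypothetical helper, not part of the test suite:

    def forward_op_types(net, *inputs):
        # get_concrete_program returns (concrete_program, partial_program_layer);
        # element [1] is the cached partial program layer, whose train_program
        # holds the (possibly prim-decomposed) forward plus the appended backward.
        partial_layer = net.forward.get_concrete_program(*inputs)[1]
        return [op.type for op in partial_layer.train_program.block(0).ops]

    # Usage in a test body, mirroring the assertions above:
    # self.assertTrue('softmax' not in forward_op_types(net, x))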

python/paddle/fluid/tests/unittests/dygraph_to_static/test_cinn_prim_gelu.py

Lines changed: 7 additions & 1 deletion
@@ -77,6 +77,7 @@ def _train(self, use_prim, approximate, data):
         net = apply_to_static(net, use_prim)
 
         res = []
+        self.x = data
         for _ in range(10):
             out = net(data)
             loss = paddle.mean(out)
@@ -92,7 +93,12 @@ def _train(self, use_prim, approximate, data):
     def check_prim(self, net, use_prim):
         if not use_prim:
             return
-        fwd_ops = [op.type for op in net.forward.main_program.block(0).ops]
+        fwd_ops = [
+            op.type
+            for op in net.forward.get_concrete_program(self.x)[1]
+            .train_program.block(0)
+            .ops
+        ]
         # Ensure that gelu is splitted into small ops
         self.assertTrue('gelu' not in fwd_ops)
 

python/paddle/fluid/tests/unittests/dygraph_to_static/test_cinn_prim_layer_norm.py

Lines changed: 16 additions & 2 deletions
@@ -89,7 +89,14 @@ def train(self, use_prim):
     def check_prim(self, net, use_prim):
         if not use_prim:
             return
-        fwd_ops = [op.type for op in net.forward.main_program.block(0).ops]
+        fwd_ops = [
+            op.type
+            for op in net.forward.get_concrete_program(self.x, self.w, self.b)[
+                1
+            ]
+            .train_program.block(0)
+            .ops
+        ]
         # Ensure that layer_norm is splitted into small ops
         self.assertTrue('layer_norm' not in fwd_ops)
 
@@ -150,7 +157,14 @@ def train(self, use_prim):
     def check_prim(self, net, use_prim):
         if not use_prim:
             return
-        fwd_ops = [op.type for op in net.forward.main_program.block(0).ops]
+        fwd_ops = [
+            op.type
+            for op in net.forward.get_concrete_program(self.x, self.w, self.b)[
+                1
+            ]
+            .train_program.block(0)
+            .ops
+        ]
         # Ensure that layer_norm is splitted into small ops
         self.assertTrue('layer_norm' not in fwd_ops)
 

python/paddle/fluid/tests/unittests/dygraph_to_static/test_cinn_prim_mean.py

Lines changed: 14 additions & 2 deletions
@@ -83,6 +83,7 @@ def _train(self, use_prim, data, axis, keep_dim):
         net = apply_to_static(net, use_prim)
 
         res = []
+        self.x = data
         for _ in range(10):
             out = net(data)
             loss = paddle.mean(out, axis, keep_dim)
@@ -99,7 +100,12 @@ def _train(self, use_prim, data, axis, keep_dim):
     def check_prim(self, net, use_prim):
         if not use_prim:
             return
-        fwd_ops = [op.type for op in net.forward.main_program.block(0).ops]
+        fwd_ops = [
+            op.type
+            for op in net.forward.get_concrete_program(self.x)[1]
+            .train_program.block(0)
+            .ops
+        ]
         # Ensure that reduce_mean is splitted into small ops
         self.assertTrue('reduce_mean' not in fwd_ops)
 
@@ -150,6 +156,7 @@ def _train(self, use_prim, data, axis, keep_dim):
         net = apply_to_static(net, use_prim)
 
         res = []
+        self.x = data
         for _ in range(10):
             out = net(data)
             loss = paddle.mean(out, axis, keep_dim)
@@ -166,7 +173,12 @@ def _train(self, use_prim, data, axis, keep_dim):
     def check_prim(self, net, use_prim):
         if not use_prim:
             return
-        fwd_ops = [op.type for op in net.forward.main_program.block(0).ops]
+        fwd_ops = [
+            op.type
+            for op in net.forward.get_concrete_program(self.x)[1]
+            .train_program.block(0)
+            .ops
+        ]
         # Ensure that reduce_mean is splitted into small ops
         self.assertTrue('reduce_mean' not in fwd_ops)
 

python/paddle/jit/dy2static/partial_program.py

Lines changed: 12 additions & 8 deletions
@@ -315,9 +315,7 @@ def _create_pure_fp16_program(self, is_infer_mode=False):
     def _create_forward_backward_train_program(self):
         whole_program = self._train_program
         # _, forward_end_op_index = self._infer_info('fp32', self._create_program)
-        forward_end_op_index = self._forward_end_index_map[
-            _hash_with_id(whole_program, self)
-        ]
+        forward_end_op_index = self.get_forward_end_op_idx(whole_program)
         assert forward_end_op_index >= 0
 
         return self._get_forward_backward_program_form(
@@ -438,11 +436,14 @@ def _infer_pure_fp16_program_id(self):
     def _param_grad_names(self):
         return _param_grad_names(self._train_program.desc, self._params)
 
+    def get_forward_end_op_idx(self, program):
+        return self._forward_end_index_map[_hash_with_id(program, self)]
+
     @LazyInitialized
     def _out_grad_names(self):
         return _out_grad_names(
             self._train_program.desc,
-            self._create_program(is_infer_mode=True).desc.block(0).op_size(),
+            self.get_forward_end_op_idx(self._train_program),
             len(self._outputs.var_ids),
         )
 
@@ -642,6 +643,7 @@ def _append_backward_desc(self, main_program):
             if isinstance(out, framework.Variable):
                 targets.append(program.global_block().var(out.name))
 
+        start_idx = len(program.block(0).ops) + len(self._outputs.tolist())
         if targets:
             # TODO(CZ): later when use cinn, set_prim_all_enabled and check_and_set_prim_all_enabled will be set at else branch.
             core.check_and_set_prim_all_enabled()
@@ -652,12 +654,11 @@ def _append_backward_desc(self, main_program):
                 program, start_idx = self._hooker.after_append_backward(
                     self, program, start_idx
                 )
-            self._forward_end_index_map[
-                _hash_with_id(program, self)
-            ] = start_idx - len(self._outputs.tolist())
-            # TODO: prim make this complicate
             self.prepare_gradient_aggregation(start_idx, main_program, program)
 
+        self._forward_end_index_map[
+            _hash_with_id(program, self)
+        ] = start_idx - len(self._outputs.tolist())
         return program
 
     def _prune_unused_params(self, program):
@@ -1155,5 +1156,8 @@ def add_build_strategy_for(
        if hasattr(compiled_program._program, 'lr_sheduler'):
            builded_program.lr_sheduler = compiled_program._program.lr_sheduler
     else:
+        # can't just create a new program, we need copy the vardesc.
         builded_program = paddle.static.Program()
+        for var in program.block(0).vars.values():
+            builded_program.block(0)._clone_variable(var, False)
     return builded_program
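
The substantive change in this file is that the forward/backward split point is now recorded once at the end of _append_backward_desc, keyed by _hash_with_id(program, self) in _forward_end_index_map, and read back later through the new get_forward_end_op_idx helper (also reused by _out_grad_names). A minimal, self-contained sketch of that bookkeeping pattern; ProgramSplitter, record_forward_end, and num_outputs are illustrative stand-ins, not Paddle APIs:

    class ProgramSplitter:
        def __init__(self):
            # maps an identity hash of (program, owner) to the op index at
            # which the forward section of the whole train program ends
            self._forward_end_index_map = {}

        def _hash_with_id(self, program):
            # assumption: identity-based key, analogous to Paddle's _hash_with_id
            return hash((id(program), id(self)))

        def record_forward_end(self, program, start_idx, num_outputs):
            # start_idx points just past the forward ops plus the ops appended
            # for the outputs, so subtract the output count to get the boundary
            self._forward_end_index_map[self._hash_with_id(program)] = (
                start_idx - num_outputs
            )

        def get_forward_end_op_idx(self, program):
            return self._forward_end_index_map[self._hash_with_id(program)]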

python/paddle/jit/dy2static/program_translator.py

Lines changed: 0 additions & 1 deletion
@@ -1226,7 +1226,6 @@ def after_infer(self, partial_program_layer, infer_program):
             partial_program.set_hooker(PrimHooker())
         return concrete_program, partial_program
 
-
     def __getitem__(self, item):
         if not isinstance(item, CacheKey):
             raise ValueError(

python/paddle/jit/dy2static/utils.py

Lines changed: 1 addition & 1 deletion
@@ -1568,7 +1568,7 @@ def _out_grad_names(program_desc, fwd_end_op_index, out_size):
         min(fwd_end_op_index + out_size, program_desc.block(0).op_size()),
     ):
         op = program_desc.block(0).op(i)
-        if op.type() == 'fill_any_like':
+        if op.type() in ['fill_any_like', "fill_constant"]:
             var_name = op.output('Out')[0]
             names.append(var_name)
     return names
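
The only change here widens the op types that mark an output gradient: presumably, once the prim/composite backward is enabled, the initial output gradients appended after the forward block can be created by fill_constant rather than fill_any_like, so both must be recognized when collecting the gradient variable names. A hedged, self-contained sketch of the same scan over plain (op_type, out_name) pairs; the list-of-tuples input is illustrative, not Paddle's program_desc API:

    def out_grad_names(ops, fwd_end_op_index, out_size):
        # ops: list of (op_type, output_var_name) tuples standing in for the
        # block's op descs; scan only the ops appended for output gradients
        names = []
        for op_type, out_name in ops[fwd_end_op_index:fwd_end_op_index + out_size]:
            if op_type in ('fill_any_like', 'fill_constant'):
                names.append(out_name)
        return names

    # Example: out_grad_names([('matmul_v2', 'y'), ('fill_constant', 'y@GRAD')], 1, 1)
    # returns ['y@GRAD']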
