【PIR API adaptor No.31】Migrate paddle.distribution.Normal into pir (PaddlePaddle#59910)

* fix

* update

* update

* fix
ooooo-create authored Jan 16, 2024
1 parent 5b3bc50 commit d431622
Showing 2 changed files with 168 additions and 42 deletions.
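For reference, the end state this commit targets is that paddle.distribution.Normal can be built and executed under the new PIR static-graph mode. Below is a minimal sketch of that usage, mirroring the new test code further down; the program, variable, and executor names are illustrative and not part of the commit:

import numpy as np
import paddle
from paddle.distribution import Normal

paddle.enable_static()

# Build a Normal graph under PIR (new IR) and execute it, as the new tests do.
with paddle.pir_utils.IrGuard():
    main_program = paddle.static.Program()
    startup_program = paddle.static.Program()
    with paddle.static.program_guard(main_program, startup_program):
        loc = paddle.static.data(name='loc', shape=[3], dtype='float32')
        scale = paddle.static.data(name='scale', shape=[3], dtype='float32')
        normal = Normal(loc, scale)
        sample = normal.sample([100])
        entropy = normal.entropy()

    executor = paddle.static.Executor()
    executor.run(startup_program)
    sample_np, entropy_np = executor.run(
        main_program,
        feed={
            'loc': np.zeros([3], dtype='float32'),
            'scale': np.ones([3], dtype='float32'),
        },
        fetch_list=[sample, entropy],
    )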
16 changes: 14 additions & 2 deletions python/paddle/distribution/distribution.py
@@ -27,7 +27,7 @@
from paddle import _C_ops
from paddle.base.data_feeder import check_variable_and_dtype, convert_dtype
from paddle.base.framework import Variable
from paddle.framework import in_dynamic_mode
from paddle.framework import in_dynamic_mode, in_pir_mode


class Distribution:
@@ -185,6 +185,10 @@ def _to_tensor(self, *args):
type(arg)
)
)
if isinstance(arg, paddle.pir.Value):
# pir.Value does not need to be converted to numpy.ndarray, so we skip here
numpy_args.append(arg)
continue

arg_np = np.array(arg)
arg_dtype = arg_np.dtype
@@ -202,8 +206,16 @@ def _to_tensor(self, *args):

dtype = tmp.dtype
for arg in numpy_args:
if isinstance(arg, paddle.pir.Value):
# pir.Value does not need to be converted to numpy.ndarray, so we skip here
variable_args.append(arg)
continue

arg_broadcasted, _ = np.broadcast_arrays(arg, tmp)
arg_variable = paddle.tensor.create_tensor(dtype=dtype)
if in_pir_mode():
arg_variable = paddle.zeros(arg_broadcasted.shape)
else:
arg_variable = paddle.tensor.create_tensor(dtype=dtype)
paddle.assign(arg_broadcasted, arg_variable)
variable_args.append(arg_variable)

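For reference, here is a simplified standalone sketch of the _to_tensor flow that the hunk above changes. The helper name is hypothetical and the broadcasting setup is condensed: paddle.pir.Value arguments pass through untouched, and under PIR the broadcast result is materialized with paddle.zeros plus paddle.assign instead of paddle.tensor.create_tensor:

import numpy as np
import paddle
from paddle.framework import in_pir_mode

def broadcast_args_to_tensors(*args):
    # Hypothetical, simplified version of Distribution._to_tensor.
    numpy_args = []
    for arg in args:
        if isinstance(arg, paddle.pir.Value):
            # pir.Value is kept as-is; no numpy round trip is needed.
            numpy_args.append(arg)
            continue
        numpy_args.append(np.array(arg))

    # Assumes at least one non-Value argument to derive the common shape/dtype.
    numeric = [a for a in numpy_args if not isinstance(a, paddle.pir.Value)]
    tmp = np.broadcast_arrays(*numeric)[0]
    dtype = str(tmp.dtype)  # e.g. 'float32'

    variable_args = []
    for arg in numpy_args:
        if isinstance(arg, paddle.pir.Value):
            variable_args.append(arg)
            continue
        arg_broadcasted, _ = np.broadcast_arrays(arg, tmp)
        if in_pir_mode():
            # Under PIR, allocate the output via zeros (mirrors the change above) ...
            arg_variable = paddle.zeros(arg_broadcasted.shape)
        else:
            # ... otherwise keep the legacy create_tensor path.
            arg_variable = paddle.tensor.create_tensor(dtype=dtype)
        paddle.assign(arg_broadcasted, arg_variable)
        variable_args.append(arg_variable)
    return variable_args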
194 changes: 154 additions & 40 deletions test/distribution/test_distribution_normal.py
@@ -28,6 +28,29 @@
np.random.seed(2022)


class InitDataContextManager:
def __init__(self, in_pir, prog):
self.in_pir = in_pir
self.prog = prog

def __enter__(self):
if self.in_pir:
self.guard = paddle.pir_utils.IrGuard()
self.guard.__enter__()
self.program_guard = paddle.static.program_guard(self.prog)
self.program_guard.__enter__()
else:
self.program_guard = base.program_guard(self.prog)
self.program_guard.__enter__()

def __exit__(self, exc_type, exc_value, traceback):
if self.in_pir:
self.program_guard.__exit__(exc_type, exc_value, traceback)
self.guard.__exit__(exc_type, exc_value, traceback)
else:
self.program_guard.__exit__(exc_type, exc_value, traceback)


class NormalNumpy(DistributionNumpy):
def __init__(self, loc, scale):
self.loc = np.array(loc)
@@ -80,15 +103,19 @@ def setUp(self, use_gpu=False, batch_size=2, dims=3):
self.place = base.CUDAPlace(0)
self.gpu_id = 0

self.init_numpy_data(batch_size, dims)
self.batch_size = batch_size
self.dims = dims
self.init_numpy_data(self.batch_size, self.dims)

paddle.disable_static(self.place)
self.init_dynamic_data(batch_size, dims)
self.init_dynamic_data(self.batch_size, self.dims)

paddle.enable_static()
self.test_program = base.Program()
with paddle.pir_utils.IrGuard():
self.test_pir_program = paddle.static.Program()

self.executor = base.Executor(self.place)
self.init_static_data(batch_size, dims)

def init_numpy_data(self, batch_size, dims):
# loc and scale are 'float'
@@ -110,12 +137,15 @@ def init_dynamic_data(self, batch_size, dims):
self.dynamic_other_scale = self.other_scale_np
self.dynamic_values = paddle.to_tensor(self.values_np)

def init_static_data(self, batch_size, dims):
def init_static_data(self, batch_size, dims, in_pir=False):
self.static_loc = self.loc_np
self.static_scale = self.scale_np
self.static_other_loc = self.other_loc_np
self.static_other_scale = self.other_scale_np
with base.program_guard(self.test_program):
manager = InitDataContextManager(
in_pir, self.test_pir_program if in_pir else self.test_program
)
with manager as mgr:
self.static_values = paddle.static.data(
name='values', shape=[-1], dtype='float32'
)
@@ -165,9 +195,8 @@ def test_normal_distribution_dygraph(self, sample_shape=7, tolerance=1e-6):
fetch_list = [sample, entropy, log_prob, probs, kl]
self.compare_with_numpy(fetch_list)

def test_normal_distribution_static(self, sample_shape=7, tolerance=1e-6):
paddle.enable_static()
with base.program_guard(self.test_program):
def run_old_ir_normal_distribution_static(self, sample_shape):
with base.program_guard(self.test_program, paddle.static.Program()):
normal = Normal(self.static_loc, self.static_scale)

sample = normal.sample([sample_shape])
@@ -181,20 +210,62 @@ def test_normal_distribution_static(self, sample_shape=7, tolerance=1e-6):

fetch_list = [sample, entropy, log_prob, probs, kl]

feed_vars = {
'loc': self.loc_np,
'scale': self.scale_np,
'values': self.values_np,
'other_loc': self.other_loc_np,
'other_scale': self.other_scale_np,
}
feed_vars = {
'loc': self.loc_np,
'scale': self.scale_np,
'values': self.values_np,
'other_loc': self.other_loc_np,
'other_scale': self.other_scale_np,
}

self.executor.run(base.default_startup_program())
fetch_list = self.executor.run(
program=self.test_program, feed=feed_vars, fetch_list=fetch_list
)

self.compare_with_numpy(fetch_list)

def run_pir_normal_distribution_static(self, sample_shape):
with paddle.pir_utils.IrGuard():
with paddle.static.program_guard(
self.test_pir_program, paddle.static.Program()
):
normal = Normal(self.static_loc, self.static_scale)

sample = normal.sample([sample_shape])
entropy = normal.entropy()
log_prob = normal.log_prob(self.static_values)
probs = normal.probs(self.static_values)
other_normal = Normal(
self.static_other_loc, self.static_other_scale
)
kl = normal.kl_divergence(other_normal)

fetch_list = [sample, entropy, log_prob, probs, kl]

feed_vars = {
'loc': self.loc_np,
'scale': self.scale_np,
'values': self.values_np,
'other_loc': self.other_loc_np,
'other_scale': self.other_scale_np,
}
self.executor.run(paddle.static.default_startup_program())
fetch_list = self.executor.run(
program=self.test_pir_program,
feed=feed_vars,
fetch_list=fetch_list,
)

self.compare_with_numpy(fetch_list)

self.executor.run(base.default_startup_program())
fetch_list = self.executor.run(
program=self.test_program, feed=feed_vars, fetch_list=fetch_list
)
def test_normal_distribution_static(self, sample_shape=7, tolerance=1e-6):
paddle.enable_static()
self.init_static_data(self.batch_size, self.dims, in_pir=False)
self.run_old_ir_normal_distribution_static(sample_shape)

self.compare_with_numpy(fetch_list)
self.init_static_data(self.batch_size, self.dims, in_pir=True)
self.run_pir_normal_distribution_static(sample_shape)


class NormalTest2(NormalTest):
@@ -230,12 +301,15 @@ def init_numpy_data(self, batch_size, dims):
'float32'
)

def init_static_data(self, batch_size, dims):
def init_static_data(self, batch_size, dims, in_pir=False):
self.static_loc = self.loc_np
self.static_scale = self.scale_np
self.static_other_loc = self.other_loc_np
self.static_other_scale = self.other_scale_np
with base.program_guard(self.test_program):
manager = InitDataContextManager(
in_pir, self.test_pir_program if in_pir else self.test_program
)
with manager as mgr:
self.static_values = paddle.static.data(
name='values', shape=[-1, dims], dtype='float32'
)
@@ -259,12 +333,15 @@ def init_numpy_data(self, batch_size, dims):
'float32'
)

def init_static_data(self, batch_size, dims):
def init_static_data(self, batch_size, dims, in_pir=False):
self.static_loc = self.loc_np
self.static_scale = self.scale_np
self.static_other_loc = self.other_loc_np
self.static_other_scale = self.other_scale_np
with base.program_guard(self.test_program):
manager = InitDataContextManager(
in_pir, self.test_pir_program if in_pir else self.test_program
)
with manager as mgr:
self.static_values = paddle.static.data(
name='values', shape=[-1, dims], dtype='float32'
)
@@ -295,12 +372,16 @@ def init_dynamic_data(self, batch_size, dims):
self.dynamic_other_scale = self.other_scale_np
self.dynamic_values = paddle.to_tensor(self.values_np, dtype='float64')

def init_static_data(self, batch_size, dims):
def init_static_data(self, batch_size, dims, in_pir=False):
self.static_loc = self.loc_np
self.static_scale = self.scale_np
self.static_other_loc = self.other_loc_np
self.static_other_scale = self.other_scale_np
with base.program_guard(self.test_program):

manager = InitDataContextManager(
in_pir, self.test_pir_program if in_pir else self.test_program
)
with manager as mgr:
self.static_values = paddle.static.data(
name='values', shape=[-1, dims], dtype='float64'
)
@@ -331,8 +412,11 @@ def init_dynamic_data(self, batch_size, dims):
self.dynamic_other_loc = paddle.to_tensor(self.other_loc_np)
self.dynamic_other_scale = paddle.to_tensor(self.other_scale_np)

def init_static_data(self, batch_size, dims):
with base.program_guard(self.test_program):
def init_static_data(self, batch_size, dims, in_pir=False):
manager = InitDataContextManager(
in_pir, self.test_pir_program if in_pir else self.test_program
)
with manager as mgr:
self.static_loc = paddle.static.data(
name='loc', shape=[-1, dims], dtype='float32'
)
@@ -379,8 +463,11 @@ def init_dynamic_data(self, batch_size, dims):
self.other_scale_np, dtype='float64'
)

def init_static_data(self, batch_size, dims):
with base.program_guard(self.test_program):
def init_static_data(self, batch_size, dims, in_pir=False):
manager = InitDataContextManager(
in_pir, self.test_pir_program if in_pir else self.test_program
)
with manager as mgr:
self.static_loc = paddle.static.data(
name='loc', shape=[-1, dims], dtype='float64'
)
@@ -427,8 +514,11 @@ def init_dynamic_data(self, batch_size, dims):
self.other_scale_np, dtype='float64'
)

def init_static_data(self, batch_size, dims):
with base.program_guard(self.test_program):
def init_static_data(self, batch_size, dims, in_pir=False):
manager = InitDataContextManager(
in_pir, self.test_pir_program if in_pir else self.test_program
)
with manager as mgr:
self.static_loc = paddle.static.data(
name='loc', shape=[-1, dims], dtype='float64'
)
@@ -470,12 +560,15 @@ def init_numpy_data(self, batch_size, dims):
)
self.other_scale_np = self.other_scale_np.tolist()

def init_static_data(self, batch_size, dims):
def init_static_data(self, batch_size, dims, in_pir=False):
self.static_loc = self.loc_np
self.static_scale = self.scale_np
self.static_other_loc = self.other_loc_np
self.static_other_scale = self.other_scale_np
with base.program_guard(self.test_program):
manager = InitDataContextManager(
in_pir, self.test_pir_program if in_pir else self.test_program
)
with manager as mgr:
self.static_values = paddle.static.data(
name='values', shape=[-1, dims], dtype='float32'
)
@@ -505,12 +598,15 @@ def init_numpy_data(self, batch_size, dims):
)
self.other_scale_np = tuple(self.other_scale_np.tolist())

def init_static_data(self, batch_size, dims):
def init_static_data(self, batch_size, dims, in_pir=False):
self.static_loc = self.loc_np
self.static_scale = self.scale_np
self.static_other_loc = self.other_loc_np
self.static_other_scale = self.other_scale_np
with base.program_guard(self.test_program):
manager = InitDataContextManager(
in_pir, self.test_pir_program if in_pir else self.test_program
)
with manager as mgr:
self.static_values = paddle.static.data(
name='values', shape=[-1, dims], dtype='float32'
)
@@ -559,10 +655,12 @@ def test_sample(self):

@place(config.DEVICES)
@parameterize_cls(
(TEST_CASE_NAME, 'loc', 'scale'), [('sample', xrand((4,)), xrand((4,)))]
(TEST_CASE_NAME, 'loc', 'scale'),
[('sample', xrand((4,)), xrand((4,)))],
test_pir=True,
)
class TestNormalSampleStaic(unittest.TestCase):
def setUp(self):
def build_program(self):
paddle.enable_static()
startup_program = paddle.static.Program()
main_program = paddle.static.Program()
@@ -586,6 +684,13 @@ def setUp(self):
main_program, feed=self.feeds, fetch_list=fetch_list
)

def setUp(self):
if self.test_pir:
with paddle.pir_utils.IrGuard():
self.build_program()
else:
self.build_program()

def test_sample(self):
samples_mean = self.samples.mean(axis=0)
samples_var = self.samples.var(axis=0)
@@ -646,10 +751,12 @@ def test_backpropagation(self):

@place(config.DEVICES)
@parameterize_cls(
(TEST_CASE_NAME, 'loc', 'scale'), [('rsample', xrand((4,)), xrand((4,)))]
(TEST_CASE_NAME, 'loc', 'scale'),
[('rsample', xrand((4,)), xrand((4,)))],
test_pir=True,
)
class TestNormalRSampleStaic(unittest.TestCase):
def setUp(self):
def build_program(self):
paddle.enable_static()
startup_program = paddle.static.Program()
main_program = paddle.static.Program()
@@ -673,6 +780,13 @@ def setUp(self):
main_program, feed=self.feeds, fetch_list=fetch_list
)

def setUp(self):
if self.test_pir:
with paddle.pir_utils.IrGuard():
self.build_program()
else:
self.build_program()

def test_rsample(self):
rsamples_mean = self.rsamples.mean(axis=0)
rsamples_var = self.rsamples.var(axis=0)
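The new InitDataContextManager in the test file simply selects the right guard stack: base.program_guard for the legacy IR, or IrGuard plus paddle.static.program_guard for PIR. A short usage sketch follows, assuming the class defined above is in scope; the program names are illustrative:

import paddle
from paddle import base

paddle.enable_static()

legacy_program = base.Program()
with paddle.pir_utils.IrGuard():
    pir_program = paddle.static.Program()

# Legacy IR path: equivalent to `with base.program_guard(legacy_program):`.
with InitDataContextManager(in_pir=False, prog=legacy_program):
    values = paddle.static.data(name='values', shape=[-1], dtype='float32')

# PIR path: enters IrGuard first, then guards the PIR program.
with InitDataContextManager(in_pir=True, prog=pir_program):
    values = paddle.static.data(name='values', shape=[-1], dtype='float32')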
