Skip to content

Commit dc217e5

Browse files
authored
【Hackathon 6th No.31】paddle.distribution.Normal support complex normal distribution -part (#65103)
* update distribution normal * fix test * update * update init * update test
1 parent cc00a23 commit dc217e5

File tree

3 files changed

+832
-105
lines changed

3 files changed

+832
-105
lines changed

python/paddle/distribution/distribution.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -233,6 +233,9 @@ def _check_values_dtype_in_probs(self, param, value):
233233
Returns:
234234
value (Tensor): Change value's dtype if value's dtype is different from param.
235235
"""
236+
if paddle.is_complex(param):
237+
return value.astype(param.dtype)
238+
236239
if in_dynamic_or_pir_mode():
237240
if in_pir_mode():
238241
check_variable_and_dtype(
@@ -250,7 +253,10 @@ def _check_values_dtype_in_probs(self, param, value):
250253
return value
251254

252255
check_variable_and_dtype(
253-
value, 'value', ['float32', 'float64'], 'log_prob'
256+
value,
257+
'value',
258+
['float32', 'float64'],
259+
'log_prob',
254260
)
255261
if value.dtype != param.dtype:
256262
warnings.warn(

python/paddle/distribution/normal.py

Lines changed: 182 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ class Normal(distribution.Distribution):
3030
3131
Mathematical details
3232
33-
The probability density function (pdf) is
33+
If 'loc' is real number, the probability density function (pdf) is
3434
3535
.. math::
3636
@@ -40,14 +40,24 @@ class Normal(distribution.Distribution):
4040
4141
Z = (2 \pi \sigma^2)^{0.5}
4242
43-
In the above equation:
43+
If 'loc' is complex number, the probability density function (pdf) is
44+
45+
.. math::
46+
47+
pdf(x; \mu, \sigma) = \frac{1}{Z}e^{\frac {-(x - \mu)^2} {\sigma^2} }
48+
49+
.. math::
50+
51+
Z = \pi \sigma^2
52+
53+
In the above equations:
4454
4555
* :math:`loc = \mu`: is the mean.
4656
* :math:`scale = \sigma`: is the std.
4757
* :math:`Z`: is the normalization constant.
4858
4959
Args:
50-
loc(int|float|list|tuple|numpy.ndarray|Tensor): The mean of normal distribution.The data type is float32 and float64.
60+
loc(int|float|complex|list|tuple|numpy.ndarray|Tensor): The mean of normal distribution.The data type is float32, float64, complex64 and complex128.
5161
scale(int|float|list|tuple|numpy.ndarray|Tensor): The std of normal distribution.The data type is float32 and float64.
5262
name(str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
5363
@@ -102,6 +112,7 @@ def __init__(self, loc, scale, name=None):
102112
(
103113
int,
104114
float,
115+
complex,
105116
np.ndarray,
106117
Variable,
107118
paddle.pir.Value,
@@ -128,33 +139,82 @@ def __init__(self, loc, scale, name=None):
128139
self.all_arg_is_float = False
129140
self.name = name if name is not None else 'Normal'
130141
self.dtype = 'float32'
142+
self._complex_gaussian = False
131143

132144
if isinstance(loc, int):
133145
loc = float(loc)
134146
if isinstance(scale, int):
135147
scale = float(scale)
136148

137-
if self._validate_args(loc, scale):
138-
self.loc = loc
139-
self.scale = scale
140-
self.dtype = convert_dtype(loc.dtype)
141-
else:
142-
if isinstance(loc, float) and isinstance(scale, float):
149+
if isinstance(loc, (tuple, list)):
150+
loc = np.array(loc)
151+
if loc.dtype == np.float64:
152+
loc = loc.astype('float32')
153+
if loc.dtype == np.complex128:
154+
loc = loc.astype('complex64')
155+
156+
if isinstance(scale, (tuple, list)):
157+
scale = np.array(scale, dtype=np.float32)
158+
159+
if (
160+
isinstance(loc, complex)
161+
or (
162+
isinstance(loc, np.ndarray)
163+
and loc.dtype in [np.complex64, np.complex128]
164+
)
165+
or (self._validate_args(loc) and loc.is_complex())
166+
):
167+
self._complex_gaussian = True
168+
if isinstance(loc, complex) and isinstance(scale, float):
143169
self.all_arg_is_float = True
144-
if isinstance(loc, np.ndarray) and str(loc.dtype) in [
145-
'float32',
146-
'float64',
147-
]:
148-
self.dtype = loc.dtype
149-
elif isinstance(scale, np.ndarray) and str(scale.dtype) in [
150-
'float32',
151-
'float64',
152-
]:
153-
self.dtype = scale.dtype
154-
self.loc, self.scale = self._to_tensor(loc, scale)
155-
if self.dtype != convert_dtype(self.loc.dtype):
156-
self.loc = paddle.cast(self.loc, dtype=self.dtype)
157-
self.scale = paddle.cast(self.scale, dtype=self.dtype)
170+
171+
if isinstance(loc, np.ndarray):
172+
real_dtype = (
173+
'float32' if loc.dtype == np.complex64 else 'float64'
174+
)
175+
imag_dtype = (
176+
'float32' if loc.dtype == np.complex64 else 'float64'
177+
)
178+
real = paddle.to_tensor(loc.real, real_dtype)
179+
imag = paddle.to_tensor(loc.imag, imag_dtype)
180+
self.loc = paddle.complex(real, imag)
181+
elif isinstance(loc, complex):
182+
real = paddle.to_tensor(loc.real, dtype='float32')
183+
imag = paddle.to_tensor(loc.imag, dtype='float32')
184+
self.loc = paddle.complex(real, imag)
185+
else:
186+
self.loc = loc
187+
188+
if isinstance(scale, np.ndarray):
189+
self.scale = paddle.to_tensor(scale, dtype=scale.dtype)
190+
elif isinstance(scale, float):
191+
self.scale = paddle.to_tensor(scale, dtype='float32')
192+
else:
193+
self.scale = scale
194+
195+
self.dtype = convert_dtype(self.loc.dtype)
196+
else:
197+
if self._validate_args(loc, scale):
198+
self.loc = loc
199+
self.scale = scale
200+
self.dtype = convert_dtype(loc.dtype)
201+
else:
202+
if isinstance(loc, float) and isinstance(scale, float):
203+
self.all_arg_is_float = True
204+
if isinstance(loc, np.ndarray) and str(loc.dtype) in [
205+
'float32',
206+
'float64',
207+
]:
208+
self.dtype = loc.dtype
209+
elif isinstance(scale, np.ndarray) and str(scale.dtype) in [
210+
'float32',
211+
'float64',
212+
]:
213+
self.dtype = scale.dtype
214+
self.loc, self.scale = self._to_tensor(loc, scale)
215+
if self.dtype != convert_dtype(self.loc.dtype):
216+
self.loc = paddle.cast(self.loc, dtype=self.dtype)
217+
self.scale = paddle.cast(self.scale, dtype=self.dtype)
158218
super().__init__(self.loc.shape)
159219

160220
@property
@@ -204,15 +264,23 @@ def sample(self, shape=(), seed=0):
204264

205265
zero_tmp_shape = paddle.shape(zero_tmp_reshape)
206266
normal_random_tmp = random.gaussian(
207-
zero_tmp_shape, mean=0.0, std=1.0, seed=seed, dtype=self.dtype
267+
zero_tmp_shape,
268+
mean=(0.0 + 0.0j) if self._complex_gaussian else 0.0,
269+
std=1.0,
270+
seed=seed,
271+
dtype=self.dtype,
208272
)
209273
output = normal_random_tmp * (zero_tmp_reshape + self.scale)
210274
output = paddle.add(output, self.loc, name=name)
211275
return output
212276
else:
213277
output_shape = shape + batch_shape
214278
output = random.gaussian(
215-
output_shape, mean=0.0, std=1.0, seed=seed, dtype=self.dtype
279+
output_shape,
280+
mean=(0.0 + 0.0j) if self._complex_gaussian else 0.0,
281+
std=1.0,
282+
seed=seed,
283+
dtype=self.dtype,
216284
) * (paddle.zeros(output_shape, dtype=self.dtype) + self.scale)
217285
output = paddle.add(output, self.loc, name=name)
218286
if self.all_arg_is_float:
@@ -234,18 +302,26 @@ def rsample(self, shape=()):
234302
raise TypeError('sample shape must be Iterable object.')
235303

236304
shape = self._extend_shape(tuple(shape))
237-
eps = paddle.normal(shape=shape)
305+
eps = paddle.normal(
306+
mean=(0.0 + 0.0j) if self._complex_gaussian else 0.0, shape=shape
307+
)
238308
return self.loc + eps * self.scale
239309

240310
def entropy(self):
241311
r"""Shannon entropy in nats.
242312
243-
The entropy is
313+
If non-complex, the entropy is
244314
245315
.. math::
246316
247317
entropy(\sigma) = 0.5 \log (2 \pi e \sigma^2)
248318
319+
If complex gaussian, the entropy is
320+
321+
.. math::
322+
323+
entropy(\sigma) = \log (\pi e \sigma^2) + 1
324+
249325
In the above equation:
250326
251327
* :math:`scale = \sigma`: is the std.
@@ -256,18 +332,33 @@ def entropy(self):
256332
"""
257333
name = self.name + '_entropy'
258334
batch_shape = list((self.loc + self.scale).shape)
259-
if -1 in batch_shape:
260-
fill_shape = list(batch_shape)
261-
fill_shape[0] = paddle.shape(self.loc + self.scale)[0].item()
262-
fill_dtype = (self.loc + self.scale).dtype
263-
zero_tmp = paddle.full(fill_shape, 0.0, fill_dtype)
335+
336+
if self._complex_gaussian:
337+
if -1 in batch_shape:
338+
fill_shape = list(batch_shape)
339+
fill_shape[0] = paddle.shape(self.loc + self.scale)[0].item()
340+
fill_dtype = self.scale.dtype
341+
zero_tmp = paddle.full(fill_shape, 0.0, fill_dtype)
342+
else:
343+
zero_tmp = paddle.full(batch_shape, 0.0, self.scale.dtype)
344+
return paddle.add(
345+
1.0 + zero_tmp,
346+
math.log(math.pi) + 2.0 * paddle.log(self.scale + zero_tmp),
347+
name=name,
348+
)
264349
else:
265-
zero_tmp = paddle.full(batch_shape, 0.0, self.dtype)
266-
return paddle.add(
267-
0.5 + zero_tmp,
268-
0.5 * math.log(2 * math.pi) + paddle.log(self.scale + zero_tmp),
269-
name=name,
270-
)
350+
if -1 in batch_shape:
351+
fill_shape = list(batch_shape)
352+
fill_shape[0] = paddle.shape(self.loc + self.scale)[0].item()
353+
fill_dtype = (self.loc + self.scale).dtype
354+
zero_tmp = paddle.full(fill_shape, 0.0, fill_dtype)
355+
else:
356+
zero_tmp = paddle.full(batch_shape, 0.0, self.dtype)
357+
return paddle.add(
358+
0.5 + zero_tmp,
359+
0.5 * math.log(2 * math.pi) + paddle.log(self.scale + zero_tmp),
360+
name=name,
361+
)
271362

272363
def log_prob(self, value):
273364
"""Log probability density/mass function.
@@ -284,11 +375,18 @@ def log_prob(self, value):
284375

285376
var = self.scale * self.scale
286377
log_scale = paddle.log(self.scale)
287-
return paddle.subtract(
288-
-1.0 * ((value - self.loc) * (value - self.loc)) / (2.0 * var),
289-
log_scale + math.log(math.sqrt(2.0 * math.pi)),
290-
name=name,
291-
)
378+
if self._complex_gaussian:
379+
return paddle.subtract(
380+
-1.0 * ((value - self.loc).conj() * (value - self.loc)) / (var),
381+
2.0 * log_scale + math.log(math.pi),
382+
name=name,
383+
)
384+
else:
385+
return paddle.subtract(
386+
-1.0 * ((value - self.loc) * (value - self.loc)) / (2.0 * var),
387+
log_scale + math.log(math.sqrt(2.0 * math.pi)),
388+
name=name,
389+
)
292390

293391
def probs(self, value):
294392
"""Probability density/mass function.
@@ -304,23 +402,42 @@ def probs(self, value):
304402
value = self._check_values_dtype_in_probs(self.loc, value)
305403

306404
var = self.scale * self.scale
307-
return paddle.divide(
308-
paddle.exp(
309-
-1.0 * ((value - self.loc) * (value - self.loc)) / (2.0 * var)
310-
),
311-
(math.sqrt(2 * math.pi) * self.scale),
312-
name=name,
313-
)
405+
if self._complex_gaussian:
406+
return paddle.divide(
407+
paddle.exp(
408+
-1.0
409+
* ((value - self.loc).conj() * (value - self.loc))
410+
/ (var)
411+
),
412+
(math.pi * var),
413+
name=name,
414+
)
415+
else:
416+
return paddle.divide(
417+
paddle.exp(
418+
-1.0
419+
* ((value - self.loc) * (value - self.loc))
420+
/ (2.0 * var)
421+
),
422+
(math.sqrt(2 * math.pi) * self.scale),
423+
name=name,
424+
)
314425

315426
def kl_divergence(self, other):
316427
r"""The KL-divergence between two normal distributions.
317428
318-
The probability density function (pdf) is
429+
If non-complex, the KL-divergence is
319430
320431
.. math::
321432
322433
KL\_divergence(\mu_0, \sigma_0; \mu_1, \sigma_1) = 0.5 (ratio^2 + (\frac{diff}{\sigma_1})^2 - 1 - 2 \ln {ratio})
323434
435+
If complex gaussian:
436+
437+
.. math::
438+
439+
KL\_divergence(\mu_0, \sigma_0; \mu_1, \sigma_1) = ratio^2 + (\frac{diff}{\sigma_1})^2 - 1 - 2 \ln {ratio}
440+
324441
.. math::
325442
326443
ratio = \frac{\sigma_0}{\sigma_1}
@@ -348,11 +465,21 @@ def kl_divergence(self, other):
348465
if not in_dynamic_mode():
349466
check_type(other, 'other', Normal, 'kl_divergence')
350467

468+
if self._complex_gaussian != other._complex_gaussian:
469+
raise ValueError(
470+
"The kl divergence must be computed between two distributions in the same number field."
471+
)
351472
name = self.name + '_kl_divergence'
352473
var_ratio = self.scale / other.scale
353474
var_ratio = var_ratio * var_ratio
354475
t1 = (self.loc - other.loc) / other.scale
355-
t1 = t1 * t1
356-
return paddle.add(
357-
0.5 * var_ratio, 0.5 * (t1 - 1.0 - paddle.log(var_ratio)), name=name
358-
)
476+
if self._complex_gaussian:
477+
t1 = t1.conj() * t1
478+
return var_ratio + t1 - 1.0 - paddle.log(var_ratio)
479+
else:
480+
t1 = t1 * t1
481+
return paddle.add(
482+
0.5 * var_ratio,
483+
0.5 * (t1 - 1.0 - paddle.log(var_ratio)),
484+
name=name,
485+
)

0 commit comments

Comments
 (0)