@@ -30,7 +30,7 @@ class Normal(distribution.Distribution):
30
30
31
31
Mathematical details
32
32
33
- The probability density function (pdf) is
33
+ If 'loc' is a real number, the probability density function (pdf) is
34
34
35
35
.. math::
36
36
@@ -40,14 +40,24 @@ class Normal(distribution.Distribution):
40
40
41
41
Z = (2 \pi \sigma^2)^{0.5}
42
42
43
- In the above equation:
43
+ If 'loc' is a complex number, the probability density function (pdf) is
44
+
45
+ .. math::
46
+
47
+ pdf(x; \mu, \sigma) = \frac{1}{Z}e^{\frac {-|x - \mu|^2} {\sigma^2} }
48
+
49
+ .. math::
50
+
51
+ Z = \pi \sigma^2
52
+
53
+ In the above equations:
44
54
45
55
* :math:`loc = \mu`: is the mean.
46
56
* :math:`scale = \sigma`: is the std.
47
57
* :math:`Z`: is the normalization constant.
48
58
49
59
Args:
50
- loc(int|float|list|tuple|numpy.ndarray|Tensor): The mean of normal distribution.The data type is float32 and float64 .
60
+ loc(int|float|complex|list|tuple|numpy.ndarray|Tensor): The mean of normal distribution. The data type is float32, float64, complex64 and complex128.
51
61
scale(int|float|list|tuple|numpy.ndarray|Tensor): The std of normal distribution.The data type is float32 and float64.
52
62
name(str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
53
63
@@ -102,6 +112,7 @@ def __init__(self, loc, scale, name=None):
102
112
(
103
113
int ,
104
114
float ,
115
+ complex ,
105
116
np .ndarray ,
106
117
Variable ,
107
118
paddle .pir .Value ,
@@ -128,33 +139,82 @@ def __init__(self, loc, scale, name=None):
128
139
self .all_arg_is_float = False
129
140
self .name = name if name is not None else 'Normal'
130
141
self .dtype = 'float32'
142
+ self ._complex_gaussian = False
131
143
132
144
if isinstance (loc , int ):
133
145
loc = float (loc )
134
146
if isinstance (scale , int ):
135
147
scale = float (scale )
136
148
137
- if self ._validate_args (loc , scale ):
138
- self .loc = loc
139
- self .scale = scale
140
- self .dtype = convert_dtype (loc .dtype )
141
- else :
142
- if isinstance (loc , float ) and isinstance (scale , float ):
149
+ if isinstance (loc , (tuple , list )):
150
+ loc = np .array (loc )
151
+ if loc .dtype == np .float64 :
152
+ loc = loc .astype ('float32' )
153
+ if loc .dtype == np .complex128 :
154
+ loc = loc .astype ('complex64' )
155
+
156
+ if isinstance (scale , (tuple , list )):
157
+ scale = np .array (scale , dtype = np .float32 )
158
+
159
+ if (
160
+ isinstance (loc , complex )
161
+ or (
162
+ isinstance (loc , np .ndarray )
163
+ and loc .dtype in [np .complex64 , np .complex128 ]
164
+ )
165
+ or (self ._validate_args (loc ) and loc .is_complex ())
166
+ ):
167
+ self ._complex_gaussian = True
168
+ if isinstance (loc , complex ) and isinstance (scale , float ):
143
169
self .all_arg_is_float = True
144
- if isinstance (loc , np .ndarray ) and str (loc .dtype ) in [
145
- 'float32' ,
146
- 'float64' ,
147
- ]:
148
- self .dtype = loc .dtype
149
- elif isinstance (scale , np .ndarray ) and str (scale .dtype ) in [
150
- 'float32' ,
151
- 'float64' ,
152
- ]:
153
- self .dtype = scale .dtype
154
- self .loc , self .scale = self ._to_tensor (loc , scale )
155
- if self .dtype != convert_dtype (self .loc .dtype ):
156
- self .loc = paddle .cast (self .loc , dtype = self .dtype )
157
- self .scale = paddle .cast (self .scale , dtype = self .dtype )
170
+
171
+ if isinstance (loc , np .ndarray ):
172
+ real_dtype = (
173
+ 'float32' if loc .dtype == np .complex64 else 'float64'
174
+ )
175
+ imag_dtype = (
176
+ 'float32' if loc .dtype == np .complex64 else 'float64'
177
+ )
178
+ real = paddle .to_tensor (loc .real , real_dtype )
179
+ imag = paddle .to_tensor (loc .imag , imag_dtype )
180
+ self .loc = paddle .complex (real , imag )
181
+ elif isinstance (loc , complex ):
182
+ real = paddle .to_tensor (loc .real , dtype = 'float32' )
183
+ imag = paddle .to_tensor (loc .imag , dtype = 'float32' )
184
+ self .loc = paddle .complex (real , imag )
185
+ else :
186
+ self .loc = loc
187
+
188
+ if isinstance (scale , np .ndarray ):
189
+ self .scale = paddle .to_tensor (scale , dtype = scale .dtype )
190
+ elif isinstance (scale , float ):
191
+ self .scale = paddle .to_tensor (scale , dtype = 'float32' )
192
+ else :
193
+ self .scale = scale
194
+
195
+ self .dtype = convert_dtype (self .loc .dtype )
196
+ else :
197
+ if self ._validate_args (loc , scale ):
198
+ self .loc = loc
199
+ self .scale = scale
200
+ self .dtype = convert_dtype (loc .dtype )
201
+ else :
202
+ if isinstance (loc , float ) and isinstance (scale , float ):
203
+ self .all_arg_is_float = True
204
+ if isinstance (loc , np .ndarray ) and str (loc .dtype ) in [
205
+ 'float32' ,
206
+ 'float64' ,
207
+ ]:
208
+ self .dtype = loc .dtype
209
+ elif isinstance (scale , np .ndarray ) and str (scale .dtype ) in [
210
+ 'float32' ,
211
+ 'float64' ,
212
+ ]:
213
+ self .dtype = scale .dtype
214
+ self .loc , self .scale = self ._to_tensor (loc , scale )
215
+ if self .dtype != convert_dtype (self .loc .dtype ):
216
+ self .loc = paddle .cast (self .loc , dtype = self .dtype )
217
+ self .scale = paddle .cast (self .scale , dtype = self .dtype )
158
218
super ().__init__ (self .loc .shape )
159
219
160
220
@property
@@ -204,15 +264,23 @@ def sample(self, shape=(), seed=0):
204
264
205
265
zero_tmp_shape = paddle .shape (zero_tmp_reshape )
206
266
normal_random_tmp = random .gaussian (
207
- zero_tmp_shape , mean = 0.0 , std = 1.0 , seed = seed , dtype = self .dtype
267
+ zero_tmp_shape ,
268
+ mean = (0.0 + 0.0j ) if self ._complex_gaussian else 0.0 ,
269
+ std = 1.0 ,
270
+ seed = seed ,
271
+ dtype = self .dtype ,
208
272
)
209
273
output = normal_random_tmp * (zero_tmp_reshape + self .scale )
210
274
output = paddle .add (output , self .loc , name = name )
211
275
return output
212
276
else :
213
277
output_shape = shape + batch_shape
214
278
output = random .gaussian (
215
- output_shape , mean = 0.0 , std = 1.0 , seed = seed , dtype = self .dtype
279
+ output_shape ,
280
+ mean = (0.0 + 0.0j ) if self ._complex_gaussian else 0.0 ,
281
+ std = 1.0 ,
282
+ seed = seed ,
283
+ dtype = self .dtype ,
216
284
) * (paddle .zeros (output_shape , dtype = self .dtype ) + self .scale )
217
285
output = paddle .add (output , self .loc , name = name )
218
286
if self .all_arg_is_float :
@@ -234,18 +302,26 @@ def rsample(self, shape=()):
234
302
raise TypeError ('sample shape must be Iterable object.' )
235
303
236
304
shape = self ._extend_shape (tuple (shape ))
237
- eps = paddle .normal (shape = shape )
305
+ eps = paddle .normal (
306
+ mean = (0.0 + 0.0j ) if self ._complex_gaussian else 0.0 , shape = shape
307
+ )
238
308
return self .loc + eps * self .scale
239
309
240
310
def entropy (self ):
241
311
r"""Shannon entropy in nats.
242
312
243
- The entropy is
313
+ If non-complex, the entropy is
244
314
245
315
.. math::
246
316
247
317
entropy(\sigma) = 0.5 \log (2 \pi e \sigma^2)
248
318
319
+ If complex gaussian, the entropy is
320
+
321
+ .. math::
322
+
323
+ entropy(\sigma) = \log (\pi e \sigma^2)
324
+
249
325
In the above equation:
250
326
251
327
* :math:`scale = \sigma`: is the std.
@@ -256,18 +332,33 @@ def entropy(self):
256
332
"""
257
333
name = self .name + '_entropy'
258
334
batch_shape = list ((self .loc + self .scale ).shape )
259
- if - 1 in batch_shape :
260
- fill_shape = list (batch_shape )
261
- fill_shape [0 ] = paddle .shape (self .loc + self .scale )[0 ].item ()
262
- fill_dtype = (self .loc + self .scale ).dtype
263
- zero_tmp = paddle .full (fill_shape , 0.0 , fill_dtype )
335
+
336
+ if self ._complex_gaussian :
337
+ if - 1 in batch_shape :
338
+ fill_shape = list (batch_shape )
339
+ fill_shape [0 ] = paddle .shape (self .loc + self .scale )[0 ].item ()
340
+ fill_dtype = self .scale .dtype
341
+ zero_tmp = paddle .full (fill_shape , 0.0 , fill_dtype )
342
+ else :
343
+ zero_tmp = paddle .full (batch_shape , 0.0 , self .scale .dtype )
344
+ return paddle .add (
345
+ 1.0 + zero_tmp ,
346
+ math .log (math .pi ) + 2.0 * paddle .log (self .scale + zero_tmp ),
347
+ name = name ,
348
+ )
264
349
else :
265
- zero_tmp = paddle .full (batch_shape , 0.0 , self .dtype )
266
- return paddle .add (
267
- 0.5 + zero_tmp ,
268
- 0.5 * math .log (2 * math .pi ) + paddle .log (self .scale + zero_tmp ),
269
- name = name ,
270
- )
350
+ if - 1 in batch_shape :
351
+ fill_shape = list (batch_shape )
352
+ fill_shape [0 ] = paddle .shape (self .loc + self .scale )[0 ].item ()
353
+ fill_dtype = (self .loc + self .scale ).dtype
354
+ zero_tmp = paddle .full (fill_shape , 0.0 , fill_dtype )
355
+ else :
356
+ zero_tmp = paddle .full (batch_shape , 0.0 , self .dtype )
357
+ return paddle .add (
358
+ 0.5 + zero_tmp ,
359
+ 0.5 * math .log (2 * math .pi ) + paddle .log (self .scale + zero_tmp ),
360
+ name = name ,
361
+ )
271
362
272
363
def log_prob (self , value ):
273
364
"""Log probability density/mass function.
@@ -284,11 +375,18 @@ def log_prob(self, value):
284
375
285
376
var = self .scale * self .scale
286
377
log_scale = paddle .log (self .scale )
287
- return paddle .subtract (
288
- - 1.0 * ((value - self .loc ) * (value - self .loc )) / (2.0 * var ),
289
- log_scale + math .log (math .sqrt (2.0 * math .pi )),
290
- name = name ,
291
- )
378
+ if self ._complex_gaussian :
379
+ return paddle .subtract (
380
+ - 1.0 * ((value - self .loc ).conj () * (value - self .loc )) / (var ),
381
+ 2.0 * log_scale + math .log (math .pi ),
382
+ name = name ,
383
+ )
384
+ else :
385
+ return paddle .subtract (
386
+ - 1.0 * ((value - self .loc ) * (value - self .loc )) / (2.0 * var ),
387
+ log_scale + math .log (math .sqrt (2.0 * math .pi )),
388
+ name = name ,
389
+ )
292
390
293
391
def probs (self , value ):
294
392
"""Probability density/mass function.
@@ -304,23 +402,42 @@ def probs(self, value):
304
402
value = self ._check_values_dtype_in_probs (self .loc , value )
305
403
306
404
var = self .scale * self .scale
307
- return paddle .divide (
308
- paddle .exp (
309
- - 1.0 * ((value - self .loc ) * (value - self .loc )) / (2.0 * var )
310
- ),
311
- (math .sqrt (2 * math .pi ) * self .scale ),
312
- name = name ,
313
- )
405
+ if self ._complex_gaussian :
406
+ return paddle .divide (
407
+ paddle .exp (
408
+ - 1.0
409
+ * ((value - self .loc ).conj () * (value - self .loc ))
410
+ / (var )
411
+ ),
412
+ (math .pi * var ),
413
+ name = name ,
414
+ )
415
+ else :
416
+ return paddle .divide (
417
+ paddle .exp (
418
+ - 1.0
419
+ * ((value - self .loc ) * (value - self .loc ))
420
+ / (2.0 * var )
421
+ ),
422
+ (math .sqrt (2 * math .pi ) * self .scale ),
423
+ name = name ,
424
+ )
314
425
315
426
def kl_divergence (self , other ):
316
427
r"""The KL-divergence between two normal distributions.
317
428
318
- The probability density function (pdf) is
429
+ If non-complex, the KL-divergence is
319
430
320
431
.. math::
321
432
322
433
KL\_divergence(\mu_0, \sigma_0; \mu_1, \sigma_1) = 0.5 (ratio^2 + (\frac{diff}{\sigma_1})^2 - 1 - 2 \ln {ratio})
323
434
435
+ If complex gaussian:
436
+
437
+ .. math::
438
+
439
+ KL\_divergence(\mu_0, \sigma_0; \mu_1, \sigma_1) = ratio^2 + \left|\frac{diff}{\sigma_1}\right|^2 - 1 - 2 \ln {ratio}
440
+
324
441
.. math::
325
442
326
443
ratio = \frac{\sigma_0}{\sigma_1}
@@ -348,11 +465,21 @@ def kl_divergence(self, other):
348
465
if not in_dynamic_mode ():
349
466
check_type (other , 'other' , Normal , 'kl_divergence' )
350
467
468
+ if self ._complex_gaussian != other ._complex_gaussian :
469
+ raise ValueError (
470
+ "The kl divergence must be computed between two distributions in the same number field."
471
+ )
351
472
name = self .name + '_kl_divergence'
352
473
var_ratio = self .scale / other .scale
353
474
var_ratio = var_ratio * var_ratio
354
475
t1 = (self .loc - other .loc ) / other .scale
355
- t1 = t1 * t1
356
- return paddle .add (
357
- 0.5 * var_ratio , 0.5 * (t1 - 1.0 - paddle .log (var_ratio )), name = name
358
- )
476
+ if self ._complex_gaussian :
477
+ t1 = t1 .conj () * t1
478
+ return var_ratio + t1 - 1.0 - paddle .log (var_ratio )
479
+ else :
480
+ t1 = t1 * t1
481
+ return paddle .add (
482
+ 0.5 * var_ratio ,
483
+ 0.5 * (t1 - 1.0 - paddle .log (var_ratio )),
484
+ name = name ,
485
+ )
0 commit comments