Skip to content

Commit ccc07cc

Browse files
authored
Added named_arrays.random.gamma() function to draw samples from the gamma distribution. (#106)
1 parent 9c97ba1 commit ccc07cc

File tree

5 files changed

+140
-0
lines changed

5 files changed

+140
-0
lines changed

named_arrays/_scalars/scalar_named_array_functions.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -625,6 +625,56 @@ def random_binomial(
625625
)
626626

627627

628+
@_implements(na.random.gamma)
629+
def random_gamma(
630+
shape: float | na.AbstractScalarArray,
631+
scale: float | u.Quantity | na.AbstractScalarArray = 1,
632+
shape_random: None | dict[str, int] = None,
633+
seed: None | int = None,
634+
) -> na.ScalarArray:
635+
alpha = shape
636+
theta = scale
637+
638+
try:
639+
alpha = scalars._normalize(alpha)
640+
theta = scalars._normalize(theta)
641+
except na.ScalarTypeError:
642+
return NotImplemented
643+
644+
if shape_random is None:
645+
shape_random = dict()
646+
647+
shape_base = na.shape_broadcasted(alpha, theta)
648+
shape = na.broadcast_shapes(shape_base, shape_random)
649+
650+
alpha = alpha.ndarray_aligned(shape)
651+
theta = theta.ndarray_aligned(shape)
652+
653+
unit = na.unit(theta)
654+
655+
if unit is not None:
656+
theta = theta.value
657+
658+
if seed is None:
659+
func = np.random.gamma
660+
else:
661+
func = np.random.default_rng(seed).gamma
662+
663+
value = func(
664+
shape=alpha,
665+
scale=theta,
666+
size=tuple(shape.values()),
667+
)
668+
669+
if unit is not None:
670+
value = value << unit
671+
672+
return na.ScalarArray(
673+
ndarray=value,
674+
axes=tuple(shape.keys()),
675+
)
676+
677+
628678
def plt_plot_like(
629679
func: Callable,
630680
*args: na.AbstractScalarArray,

named_arrays/_scalars/uncertainties/uncertainties_named_array_functions.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
na.random.normal,
2525
na.random.poisson,
2626
na.random.binomial,
27+
na.random.gamma,
2728
)
2829
PLT_PLOT_LIKE_FUNCTIONS = named_arrays._scalars.scalar_named_array_functions.PLT_PLOT_LIKE_FUNCTIONS
2930
NDFILTER_FUNCTIONS = named_arrays._scalars.scalar_named_array_functions.NDFILTER_FUNCTIONS

named_arrays/_vectors/vector_named_array_functions.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
na.random.normal,
2828
na.random.poisson,
2929
na.random.binomial,
30+
na.random.gamma,
3031
)
3132
PLT_PLOT_LIKE_FUNCTIONS = named_arrays._scalars.scalar_named_array_functions.PLT_PLOT_LIKE_FUNCTIONS
3233
NDFILTER_FUNCTIONS = named_arrays._scalars.scalar_named_array_functions.NDFILTER_FUNCTIONS

named_arrays/random.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
"normal",
1111
"poisson",
1212
"binomial",
13+
"gamma",
1314
]
1415

1516

@@ -19,6 +20,8 @@
1920
RandomWidthT = TypeVar("RandomWidthT", bound="float | complex | u.Quantity | na.AbstractArray")
2021
NumTrialsT = TypeVar("NumTrialsT", bound="int | na.AbstractArray")
2122
ProbabilityT = TypeVar("ProbabilityT", bound="float | na.AbstractArray")
23+
ShapeT = TypeVar("ShapeT", bound="float | na.AbstractArray")
24+
ScaleT = TypeVar("ScaleT", bound="float | u.Quantity | na.AbstractArray")
2225

2326

2427
def uniform(
@@ -164,3 +167,37 @@ def binomial(
164167
shape_random=shape_random,
165168
seed=seed,
166169
)
170+
171+
172+
def gamma(
173+
shape: ShapeT,
174+
scale: ScaleT = 1,
175+
shape_random: None | dict[str, int] = None,
176+
seed: None | int = None,
177+
) -> ShapeT | ScaleT:
178+
"""
179+
Draw samples from a gamma distribution.
180+
181+
Parameters
182+
----------
183+
shape
184+
The shape parameter of the distribution.
185+
scale
186+
The scale parameter of the distribution.
187+
shape_random
188+
Additional dimensions to be broadcast against `shape` and `scale`.
189+
seed
190+
Optional seed for the random number generator,
191+
can be provided for repeatability.
192+
193+
See Also
194+
--------
195+
:func:`numpy.random.gamma` : Equivalent numpy function
196+
"""
197+
return na._named_array_function(
198+
func=gamma,
199+
shape=na.as_named_array(shape),
200+
scale=scale,
201+
shape_random=shape_random,
202+
seed=seed,
203+
)

named_arrays/tests/test_random.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,3 +55,54 @@ def test_binomial(
5555

5656
assert np.all(result >= 0)
5757
assert np.all(result <= n)
58+
59+
60+
@pytest.mark.parametrize(
61+
argnames="shape",
62+
argvalues=[
63+
0.5,
64+
na.ScalarArray(0.51),
65+
na.linspace(0.4, 0.5, axis="p", num=5),
66+
na.UniformUncertainScalarArray(0.5, width=0.1),
67+
na.Cartesian2dVectorArray(0.5, 0.6),
68+
],
69+
)
70+
@pytest.mark.parametrize(
71+
argnames="scale",
72+
argvalues=[
73+
10,
74+
(11 * u.photon).astype(int),
75+
na.ScalarArray(12),
76+
(na.arange(1, 10, axis="x") << u.photon).astype(int),
77+
na.Cartesian2dVectorArray(10, 11),
78+
],
79+
)
80+
@pytest.mark.parametrize(
81+
argnames="shape_random",
82+
argvalues=[
83+
None,
84+
dict(_s=6),
85+
],
86+
)
87+
@pytest.mark.parametrize(
88+
argnames="seed",
89+
argvalues=[
90+
None,
91+
42,
92+
],
93+
)
94+
def test_gamma(
95+
shape: float | na.AbstractScalar | na.AbstractVectorArray,
96+
scale: float | u.Quantity | na.AbstractScalar | na.AbstractVectorArray,
97+
shape_random: None | dict[str, int],
98+
seed: None | int,
99+
):
100+
result = na.random.gamma(
101+
shape=shape,
102+
scale=scale,
103+
shape_random=shape_random,
104+
seed=seed,
105+
)
106+
107+
assert na.unit(result) == na.unit(scale)
108+
assert np.all(result >= 0)

0 commit comments

Comments
 (0)