1515Functions that generate data sets used in examples 
1616""" 
1717
18+ from  typing  import  Any 
19+ 
1820import  numpy  as  np 
1921import  pandas  as  pd 
2022from  scipy .stats  import  dirichlet , gamma , norm , uniform 
2123from  statsmodels .nonparametric .smoothers_lowess  import  lowess 
2224
23- default_lowess_kwargs  =  {"frac" : 0.2 , "it" : 0 }
24- RANDOM_SEED  =  8927 
25- rng  =  np .random .default_rng (RANDOM_SEED )
25+ default_lowess_kwargs :  dict [ str ,  float ]  =  {"frac" : 0.2 , "it" : 0 }
26+ RANDOM_SEED :  int  =  8927 
27+ rng :  np . random . Generator  =  np .random .default_rng (RANDOM_SEED )
2628
2729
2830def  _smoothed_gaussian_random_walk (
29-     gaussian_random_walk_mu , gaussian_random_walk_sigma , N , lowess_kwargs 
30- ):
31+     gaussian_random_walk_mu : float ,
32+     gaussian_random_walk_sigma : float ,
33+     N : int ,
34+     lowess_kwargs : dict [str , Any ],
35+ ) ->  tuple [np .ndarray , np .ndarray ]:
3136    """ 
3237    Generates Gaussian random walk data and applies LOWESS 
3338
@@ -48,12 +53,12 @@ def _smoothed_gaussian_random_walk(
4853
4954
5055def  generate_synthetic_control_data (
51-     N = 100 ,
52-     treatment_time = 70 ,
53-     grw_mu = 0.25 ,
54-     grw_sigma = 1 ,
55-     lowess_kwargs = default_lowess_kwargs ,
56- ):
56+     N :  int   =   100 ,
57+     treatment_time :  int   =   70 ,
58+     grw_mu :  float   =   0.25 ,
59+     grw_sigma :  float   =   1 ,
60+     lowess_kwargs :  dict [ str ,  Any ]  |   None   =   None ,
61+ )  ->   tuple [ pd . DataFrame ,  np . ndarray ] :
5762    """ 
5863    Generates data for synthetic control example. 
5964
@@ -73,6 +78,8 @@ def generate_synthetic_control_data(
7378    >>> from causalpy.data.simulate_data import generate_synthetic_control_data 
7479    >>> df, weightings_true = generate_synthetic_control_data(treatment_time=70) 
7580    """ 
81+     if  lowess_kwargs  is  None :
82+         lowess_kwargs  =  default_lowess_kwargs 
7683
7784    # 1. Generate non-treated variables 
7885    df  =  pd .DataFrame (
@@ -108,8 +115,12 @@ def generate_synthetic_control_data(
108115
109116
110117def  generate_time_series_data (
111-     N = 100 , treatment_time = 70 , beta_temp = - 1 , beta_linear = 0.5 , beta_intercept = 3 
112- ):
118+     N : int  =  100 ,
119+     treatment_time : int  =  70 ,
120+     beta_temp : float  =  - 1 ,
121+     beta_linear : float  =  0.5 ,
122+     beta_intercept : float  =  3 ,
123+ ) ->  pd .DataFrame :
113124    """ 
114125    Generates interrupted time series example data 
115126
@@ -155,7 +166,7 @@ def generate_time_series_data(
155166    return  df 
156167
157168
158- def  generate_time_series_data_seasonal (treatment_time ) :
169+ def  generate_time_series_data_seasonal (treatment_time :  pd . Timestamp )  ->   pd . DataFrame :
159170    """ 
160171    Generates 10 years of monthly data with seasonality 
161172    """ 
@@ -183,7 +194,9 @@ def generate_time_series_data_seasonal(treatment_time):
183194    return  df 
184195
185196
186- def  generate_time_series_data_simple (treatment_time , slope = 0.0 ):
197+ def  generate_time_series_data_simple (
198+     treatment_time : pd .Timestamp , slope : float  =  0.0 
199+ ) ->  pd .DataFrame :
187200    """Generate simple interrupted time series data, with no seasonality or temporal 
188201    structure. 
189202    """ 
@@ -205,7 +218,7 @@ def generate_time_series_data_simple(treatment_time, slope=0.0):
205218    return  df 
206219
207220
208- def  generate_did ():
221+ def  generate_did ()  ->   pd . DataFrame :
209222    """ 
210223    Generate Difference in Differences data 
211224
@@ -223,8 +236,14 @@ def generate_did():
223236
224237    # local functions 
225238    def  outcome (
226-         t , control_intercept , treat_intercept_delta , trend , Δ , group , post_treatment 
227-     ):
239+         t : np .ndarray ,
240+         control_intercept : float ,
241+         treat_intercept_delta : float ,
242+         trend : float ,
243+         Δ : float ,
244+         group : np .ndarray ,
245+         post_treatment : np .ndarray ,
246+     ) ->  np .ndarray :
228247        """Compute the outcome of each unit""" 
229248        return  (
230249            control_intercept 
@@ -257,8 +276,8 @@ def outcome(
257276
258277
259278def  generate_regression_discontinuity_data (
260-     N = 100 , true_causal_impact = 0.5 , true_treatment_threshold = 0.0 
261- ):
279+     N :  int   =   100 , true_causal_impact :  float   =   0.5 , true_treatment_threshold :  float   =   0.0 
280+ )  ->   pd . DataFrame :
262281    """ 
263282    Generate regression discontinuity example data 
264283
@@ -272,12 +291,12 @@ def generate_regression_discontinuity_data(
272291    ... )  # doctest: +SKIP 
273292    """ 
274293
275-     def  is_treated (x ) :
294+     def  is_treated (x :  np . ndarray )  ->   np . ndarray :
276295        """Check if x was treated""" 
277296        return  np .greater_equal (x , true_treatment_threshold )
278297
279-     def  impact (x ) :
280-         """Assign true_causal_impact to all treaated  entries""" 
298+     def  impact (x :  np . ndarray )  ->   np . ndarray :
299+         """Assign true_causal_impact to all treated  entries""" 
281300        y  =  np .zeros (len (x ))
282301        y [is_treated (x )] =  true_causal_impact 
283302        return  y 
@@ -289,8 +308,11 @@ def impact(x):
289308
290309
291310def  generate_ancova_data (
292-     N = 200 , pre_treatment_means = np .array ([10 , 12 ]), treatment_effect = 2 , sigma = 1 
293- ):
311+     N : int  =  200 ,
312+     pre_treatment_means : np .ndarray  =  np .array ([10 , 12 ]),
313+     treatment_effect : float  =  2 ,
314+     sigma : float  =  1 ,
315+ ) ->  pd .DataFrame :
294316    """ 
295317    Generate ANCOVA example data 
296318
@@ -310,7 +332,7 @@ def generate_ancova_data(
310332    return  df 
311333
312334
313- def  generate_geolift_data ():
335+ def  generate_geolift_data ()  ->   pd . DataFrame :
314336    """Generate synthetic data for a geolift example. This will consist of 6 untreated 
315337    countries. The treated unit `Denmark` is a weighted combination of the untreated 
316338    units. We additionally specify a treatment effect which takes effect after the 
@@ -360,7 +382,7 @@ def generate_geolift_data():
360382    return  df 
361383
362384
363- def  generate_multicell_geolift_data ():
385+ def  generate_multicell_geolift_data ()  ->   pd . DataFrame :
364386    """Generate synthetic data for a geolift example. This will consist of 6 untreated 
365387    countries. The treated unit `Denmark` is a weighted combination of the untreated 
366388    units. We additionally specify a treatment effect which takes effect after the 
@@ -422,7 +444,9 @@ def generate_multicell_geolift_data():
422444# ----------------- 
423445
424446
425- def  generate_seasonality (n = 12 , amplitude = 1 , length_scale = 0.5 ):
447+ def  generate_seasonality (
448+     n : int  =  12 , amplitude : float  =  1 , length_scale : float  =  0.5 
449+ ) ->  np .ndarray :
426450    """Generate monthly seasonality by sampling from a Gaussian process with a 
427451    Gaussian kernel, using numpy code""" 
428452    # Generate the covariance matrix 
@@ -436,14 +460,26 @@ def generate_seasonality(n=12, amplitude=1, length_scale=0.5):
436460    return  seasonality 
437461
438462
439- def  periodic_kernel (x1 , x2 , period = 1 , length_scale = 1 , amplitude = 1 ):
463+ def  periodic_kernel (
464+     x1 : np .ndarray ,
465+     x2 : np .ndarray ,
466+     period : float  =  1 ,
467+     length_scale : float  =  1 ,
468+     amplitude : float  =  1 ,
469+ ) ->  np .ndarray :
440470    """Generate a periodic kernel for a Gaussian process""" 
441471    return  amplitude ** 2  *  np .exp (
442472        - 2  *  np .sin (np .pi  *  np .abs (x1  -  x2 ) /  period ) **  2  /  length_scale ** 2 
443473    )
444474
445475
446- def  create_series (n = 52 , amplitude = 1 , length_scale = 2 , n_years = 4 , intercept = 3 ):
476+ def  create_series (
477+     n : int  =  52 ,
478+     amplitude : float  =  1 ,
479+     length_scale : float  =  2 ,
480+     n_years : int  =  4 ,
481+     intercept : float  =  3 ,
482+ ) ->  np .ndarray :
447483    """ 
448484    Returns numpy tile with generated seasonality data repeated over 
449485    multiple years 
0 commit comments