@@ -135,38 +135,41 @@ def __init__(
135135
136136 def calculate_seasonalities (self ):
137137 self .seasonalities = OrderedDict ({})
138+ self .seasonality_indices = OrderedDict ({})
138139 dates = self .df_tensors ["ds" ]
139140 t = (dates - torch .tensor (datetime (1900 , 1 , 1 ).timestamp ())).float () / (3600 * 24.0 )
140141
141- def compute_fourier_features (t , period ):
142- factor = 2.0 * np .pi / period .period
143- sin_terms = torch .sin (factor * t [:, None ] * torch .arange (1 , period .resolution + 1 ))
144- cos_terms = torch .cos (factor * t [:, None ] * torch .arange (1 , period .resolution + 1 ))
145- return torch .cat ((sin_terms , cos_terms ), dim = 1 )
142+ torch_pi = torch .tensor (np .pi , dtype = t .dtype , device = t .device )
143+ base_time = t [:, None ]
144+ start_index = 0
145+ all_features = []
146146
147147 for name , period in self .config_seasonality .periods .items ():
148148 if period .resolution > 0 :
149- features = compute_fourier_features (t , period )
149+ factor = 2.0 * torch_pi / period .period
150+ terms = factor * base_time * torch .arrange (1 , period .resolution + 1 )
151+ features = torch .cat ((torch .sin (terms ), torch .cos (terms )), dim = 1 )
150152
151153 if period .condition_name is not None :
152154 condition_values = self .df_tensors [period .condition_name ].unsqueeze (1 )
153155 features *= condition_values
154- self .seasonalities [name ] = features
155156
156- def get_sample_seasonalities (self , df_tensors , origin_index , n_forecasts , max_lags , n_lags , config_seasonality ):
157- seasonalities = OrderedDict ({})
157+ end_index = start_index + features .shape [1 ]
158+ self .seasonality_indices [name ] = (start_index , end_index )
159+ all_features .append (features )
160+ start_index = end_index
161+
162+ # Concatenate all features into a single tensor
163+ self .stacked_seasonalities = torch .cat (all_features , dim = 1 )
158164
165+ def get_sample_seasonalities (self , df_tensors , origin_index , n_forecasts , max_lags , n_lags , config_seasonality ):
159166 # Determine the range of indices based on whether lags are used
160167 if max_lags == 0 :
161168 indices = slice (origin_index , origin_index + 1 )
162169 else :
163170 indices = slice (origin_index - n_lags + 1 , origin_index + n_forecasts + 1 )
164171
165- # Extract the precomputed seasonalities from self.seasonalities
166- for name , features in self .seasonalities .items ():
167- seasonalities [name ] = features [indices , :]
168-
169- return seasonalities
172+ return self .stacked_seasonalities [indices , :]
170173
171174 def __getitem__ (self , index ):
172175 """Overrides parent class method to get an item at index.
0 commit comments