Skip to content

Commit 5270dff

Browse files
committed
stack seasonalities
1 parent 1ca2085 commit 5270dff

File tree

1 file changed

+17
-14
lines changed

1 file changed

+17
-14
lines changed

neuralprophet/time_dataset.py

Lines changed: 17 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -135,38 +135,41 @@ def __init__(
135135

136136
def calculate_seasonalities(self):
137137
self.seasonalities = OrderedDict({})
138+
self.seasonality_indices = OrderedDict({})
138139
dates = self.df_tensors["ds"]
139140
t = (dates - torch.tensor(datetime(1900, 1, 1).timestamp())).float() / (3600 * 24.0)
140141

141-
def compute_fourier_features(t, period):
142-
factor = 2.0 * np.pi / period.period
143-
sin_terms = torch.sin(factor * t[:, None] * torch.arange(1, period.resolution + 1))
144-
cos_terms = torch.cos(factor * t[:, None] * torch.arange(1, period.resolution + 1))
145-
return torch.cat((sin_terms, cos_terms), dim=1)
142+
torch_pi = torch.tensor(np.pi, dtype=t.dtype, device=t.device)
143+
base_time = t[:, None]
144+
start_index = 0
145+
all_features = []
146146

147147
for name, period in self.config_seasonality.periods.items():
148148
if period.resolution > 0:
149-
features = compute_fourier_features(t, period)
149+
factor = 2.0 * torch_pi / period.period
150+
terms = factor * base_time * torch.arrange(1, period.resolution + 1)
151+
features = torch.cat((torch.sin(terms), torch.cos(terms)), dim=1)
150152

151153
if period.condition_name is not None:
152154
condition_values = self.df_tensors[period.condition_name].unsqueeze(1)
153155
features *= condition_values
154-
self.seasonalities[name] = features
155156

156-
def get_sample_seasonalities(self, df_tensors, origin_index, n_forecasts, max_lags, n_lags, config_seasonality):
157-
seasonalities = OrderedDict({})
157+
end_index = start_index + features.shape[1]
158+
self.seasonality_indices[name] = (start_index, end_index)
159+
all_features.append(features)
160+
start_index = end_index
161+
162+
# Concatenate all features into a single tensor
163+
self.stacked_seasonalities = torch.cat(all_features, dim=1)
158164

165+
def get_sample_seasonalities(self, df_tensors, origin_index, n_forecasts, max_lags, n_lags, config_seasonality):
159166
# Determine the range of indices based on whether lags are used
160167
if max_lags == 0:
161168
indices = slice(origin_index, origin_index + 1)
162169
else:
163170
indices = slice(origin_index - n_lags + 1, origin_index + n_forecasts + 1)
164171

165-
# Extract the precomputed seasonalities from self.seasonalities
166-
for name, features in self.seasonalities.items():
167-
seasonalities[name] = features[indices, :]
168-
169-
return seasonalities
172+
return self.stacked_seasonalities[indices, :]
170173

171174
def __getitem__(self, index):
172175
"""Overrides parent class method to get an item at index.

0 commit comments

Comments
 (0)