Skip to content

Commit 77fdd7f

Browse files
committed
Unpack incrementally when needed
1 parent 6074baf commit 77fdd7f

File tree

3 files changed

+332
-290
lines changed

3 files changed

+332
-290
lines changed

neuralprophet/time_dataset.py

Lines changed: 11 additions & 103 deletions
Original file line numberDiff line numberDiff line change
@@ -224,17 +224,6 @@ def stack_all_features(self):
224224
)
225225
current_idx += additive_regressors_tensor.size(1)
226226

227-
if self.config_seasonality and self.config_seasonality.periods:
228-
for seasonality_name, features in self.seasonalities.items():
229-
seasonal_tensor = features
230-
print(f"Seasonality tensor shape for {seasonality_name}: {seasonal_tensor.shape}")
231-
feature_list.append(seasonal_tensor)
232-
self.feature_indices[f"seasonality_{seasonality_name}"] = (
233-
current_idx,
234-
current_idx + seasonal_tensor.size(1),
235-
)
236-
current_idx += seasonal_tensor.size(1)
237-
238227
# Stack multiplicative regressor features
239228
if self.multiplicative_regressors_names:
240229
multiplicative_regressors_tensor = torch.cat(
@@ -247,6 +236,17 @@ def stack_all_features(self):
247236
)
248237
current_idx += len(self.multiplicative_regressors_names)
249238

239+
if self.config_seasonality and self.config_seasonality.periods:
240+
for seasonality_name, features in self.seasonalities.items():
241+
seasonal_tensor = features
242+
print(f"Seasonality tensor shape for {seasonality_name}: {seasonal_tensor.shape}")
243+
feature_list.append(seasonal_tensor)
244+
self.feature_indices[f"seasonality_{seasonality_name}"] = (
245+
current_idx,
246+
current_idx + seasonal_tensor.size(1),
247+
)
248+
current_idx += seasonal_tensor.size(1)
249+
250250
# Concatenate all features into one big tensor
251251
self.all_features = torch.cat(feature_list, dim=1) # Concatenating along the third dimension
252252
if self.config_model is not None:
@@ -272,21 +272,6 @@ def compute_fourier_features(t, period):
272272
features *= condition_values
273273
self.seasonalities[name] = features
274274

275-
def get_sample_seasonalities(self, df_tensors, origin_index, n_forecasts, max_lags, n_lags, config_seasonality):
276-
seasonalities = OrderedDict({})
277-
278-
# Determine the range of indices based on whether lags are used
279-
if max_lags == 0:
280-
indices = [origin_index]
281-
else:
282-
indices = list(range(origin_index - n_lags + 1, origin_index + n_forecasts + 1))
283-
284-
# Extract the precomputed seasonalities from self.seasonalities
285-
for name, features in self.seasonalities.items():
286-
seasonalities[name] = features[indices, :]
287-
288-
return seasonalities
289-
290275
def __getitem__(self, index):
291276
"""Overrides parent class method to get an item at index.
292277
Parameters
@@ -672,83 +657,6 @@ def sort_regressor_names(self, config):
672657
multiplicative_regressors_names.append(reg)
673658
return additive_regressors_names, multiplicative_regressors_names
674659

675-
def get_sample_targets(self, df_tensors, origin_index, n_forecasts, max_lags, predict_mode):
676-
if "y_scaled" in self.df_tensors:
677-
if max_lags == 0:
678-
targets = df_tensors["y_scaled"][origin_index].unsqueeze(0).unsqueeze(1)
679-
else:
680-
targets = df_tensors["y_scaled"][origin_index + 1 : origin_index + n_forecasts + 1]
681-
targets = targets.unsqueeze(1)
682-
return targets
683-
return torch.zeros((n_forecasts, 1), dtype=torch.float32)
684-
685-
def get_sample_lagged_regressors(self, df_tensors, origin_index, config_lagged_regressors):
686-
lagged_regressors = OrderedDict({})
687-
# Future TODO: optimize this computation for many lagged_regressors
688-
for name, lagged_regressor in config_lagged_regressors.items():
689-
covar_lags = lagged_regressor.n_lags
690-
assert covar_lags > 0
691-
# Indexing tensors instead of DataFrame
692-
lagged_regressors[name] = df_tensors[name][origin_index - covar_lags + 1 : origin_index + 1]
693-
return lagged_regressors
694-
695-
def get_sample_future_regressors(
696-
self,
697-
df_tensors,
698-
origin_index,
699-
n_forecasts,
700-
max_lags,
701-
n_lags,
702-
additive_regressors_names,
703-
multiplicative_regressors_names,
704-
):
705-
regressors = OrderedDict({})
706-
if max_lags == 0:
707-
if additive_regressors_names:
708-
regressors["additive"] = df_tensors["additive_regressors"][origin_index, :].unsqueeze(0)
709-
710-
if multiplicative_regressors_names:
711-
regressors["multiplicative"] = df_tensors["multiplicative_regressors"][origin_index, :].unsqueeze(0)
712-
713-
else:
714-
if additive_regressors_names:
715-
regressors["additive"] = df_tensors["additive_regressors"][
716-
origin_index + 1 - n_lags : origin_index + n_forecasts + 1, :
717-
]
718-
if multiplicative_regressors_names:
719-
regressors["multiplicative"] = df_tensors["multiplicative_regressors"][
720-
origin_index + 1 - n_lags : origin_index + n_forecasts + 1, :
721-
]
722-
723-
return regressors
724-
725-
def get_sample_future_events(
726-
self,
727-
df_tensors,
728-
origin_index,
729-
n_forecasts,
730-
max_lags,
731-
n_lags,
732-
additive_event_and_holiday_names,
733-
multiplicative_event_and_holiday_names,
734-
):
735-
events = OrderedDict({})
736-
if max_lags == 0:
737-
if additive_event_and_holiday_names:
738-
events["additive"] = df_tensors["additive_event_and_holiday"][origin_index, :].unsqueeze(0)
739-
if multiplicative_event_and_holiday_names:
740-
events["multiplicative"] = df_tensors["multiplicative_event_and_holiday"][origin_index, :].unsqueeze(0)
741-
else:
742-
if additive_event_and_holiday_names:
743-
events["additive"] = df_tensors["additive_event_and_holiday"][
744-
origin_index + 1 - n_lags : origin_index + n_forecasts + 1, :
745-
]
746-
if multiplicative_event_and_holiday_names:
747-
events["multiplicative"] = df_tensors["multiplicative_event_and_holiday"][
748-
origin_index + 1 - n_lags : origin_index + n_forecasts + 1, :
749-
]
750-
return events
751-
752660

753661
class GlobalTimeDataset(TimeDataset):
754662
def __init__(

0 commit comments

Comments
 (0)