Skip to content

Commit b64373b

Browse files
committed
updated dataset get_item
1 parent 8330bab commit b64373b

File tree

7 files changed

+367
-189
lines changed

7 files changed

+367
-189
lines changed

neuralprophet/configure.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,13 @@
2222

2323
@dataclass
2424
class Model:
25+
features_map: dict
2526
lagged_reg_layers: Optional[List[int]]
2627

2728

29+
ConfigModel = Model
30+
31+
2832
@dataclass
2933
class Normalization:
3034
normalize: str

neuralprophet/data/process.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -626,5 +626,6 @@ def _create_dataset(model, df, predict_mode, prediction_frequency=None):
626626
config_regressors=model.config_regressors,
627627
config_lagged_regressors=model.config_lagged_regressors,
628628
config_missing=model.config_missing,
629+
config_model=model.config_model,
629630
# config_train=model.config_train, # no longer needed since JIT tabularization.
630631
)

neuralprophet/forecaster.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
from neuralprophet.plot_model_parameters_plotly import plot_parameters as plot_parameters_plotly
3636
from neuralprophet.plot_utils import get_valid_configuration, log_warning_deprecation_plotly, select_plotting_backend
3737
from neuralprophet.uncertainty import Conformal
38+
from neuralprophet.utils import unpack_sliced_tensor
3839

3940
log = logging.getLogger("NP.forecaster")
4041

@@ -487,7 +488,7 @@ def __init__(
487488
self.max_lags = self.n_lags
488489

489490
# Model
490-
self.config_model = configure.Model(lagged_reg_layers=lagged_reg_layers)
491+
self.config_model = configure.Model(features_map={}, lagged_reg_layers=lagged_reg_layers)
491492

492493
# Trend
493494
self.config_trend = configure.Trend(
@@ -1893,13 +1894,23 @@ def predict_seasonal_components(self, df: pd.DataFrame, quantile: float = 0.5):
18931894
config_regressors=self.config_regressors,
18941895
config_lagged_regressors=self.config_lagged_regressors,
18951896
config_missing=self.config_missing,
1897+
config_model=self.config_model,
18961898
# config_train=self.config_train, # no longer needed since JIT tabularization.
18971899
)
18981900
loader = DataLoader(dataset, batch_size=min(4096, len(df)), shuffle=False, drop_last=False)
18991901
predicted = {}
19001902
for name in self.config_seasonality.periods:
19011903
predicted[name] = list()
1902-
for inputs, _, meta in loader:
1904+
for inputs_tensor, meta in loader:
1905+
inputs = unpack_sliced_tensor(
1906+
sliced_tensor=inputs_tensor,
1907+
n_lags=0,
1908+
n_forecasts=1,
1909+
max_lags=0,
1910+
feature_indices=self.config_model.features_map,
1911+
config_lagged_regressors=self.config_lagged_regressors,
1912+
config_seasonality=self.config_seasonality,
1913+
)
19031914
# Meta as a tensor for prediction
19041915
if self.model.config_seasonality is None:
19051916
meta_name_tensor = None
@@ -2631,6 +2642,7 @@ def _init_model(self):
26312642
config_events=self.config_events,
26322643
config_holidays=self.config_country_holidays,
26332644
config_normalization=self.config_normalization,
2645+
config_model=self.config_model,
26342646
n_forecasts=self.n_forecasts,
26352647
n_lags=self.n_lags,
26362648
max_lags=self.max_lags,

0 commit comments

Comments
 (0)