|
35 | 35 | from neuralprophet.plot_model_parameters_plotly import plot_parameters as plot_parameters_plotly |
36 | 36 | from neuralprophet.plot_utils import get_valid_configuration, log_warning_deprecation_plotly, select_plotting_backend |
37 | 37 | from neuralprophet.uncertainty import Conformal |
| 38 | +from neuralprophet.utils import unpack_sliced_tensor |
38 | 39 |
|
39 | 40 | log = logging.getLogger("NP.forecaster") |
40 | 41 |
|
@@ -487,7 +488,7 @@ def __init__( |
487 | 488 | self.max_lags = self.n_lags |
488 | 489 |
|
489 | 490 | # Model |
490 | | - self.config_model = configure.Model(lagged_reg_layers=lagged_reg_layers) |
| 491 | + self.config_model = configure.Model(features_map={}, lagged_reg_layers=lagged_reg_layers) |
491 | 492 |
|
492 | 493 | # Trend |
493 | 494 | self.config_trend = configure.Trend( |
@@ -1893,13 +1894,23 @@ def predict_seasonal_components(self, df: pd.DataFrame, quantile: float = 0.5): |
1893 | 1894 | config_regressors=self.config_regressors, |
1894 | 1895 | config_lagged_regressors=self.config_lagged_regressors, |
1895 | 1896 | config_missing=self.config_missing, |
| 1897 | + config_model=self.config_model, |
1896 | 1898 | # config_train=self.config_train, # no longer needed since JIT tabularization. |
1897 | 1899 | ) |
1898 | 1900 | loader = DataLoader(dataset, batch_size=min(4096, len(df)), shuffle=False, drop_last=False) |
1899 | 1901 | predicted = {} |
1900 | 1902 | for name in self.config_seasonality.periods: |
1901 | 1903 | predicted[name] = list() |
1902 | | - for inputs, _, meta in loader: |
| 1904 | + for inputs_tensor, meta in loader: |
| 1905 | + inputs = unpack_sliced_tensor( |
| 1906 | + sliced_tensor=inputs_tensor, |
| 1907 | + n_lags=0, |
| 1908 | + n_forecasts=1, |
| 1909 | + max_lags=0, |
| 1910 | + feature_indices=self.config_model.features_map, |
| 1911 | + config_lagged_regressors=self.config_lagged_regressors, |
| 1912 | + config_seasonality=self.config_seasonality, |
| 1913 | + ) |
1903 | 1914 | # Meta as a tensor for prediction |
1904 | 1915 | if self.model.config_seasonality is None: |
1905 | 1916 | meta_name_tensor = None |
@@ -2631,6 +2642,7 @@ def _init_model(self): |
2631 | 2642 | config_events=self.config_events, |
2632 | 2643 | config_holidays=self.config_country_holidays, |
2633 | 2644 | config_normalization=self.config_normalization, |
| 2645 | + config_model=self.config_model, |
2634 | 2646 | n_forecasts=self.n_forecasts, |
2635 | 2647 | n_lags=self.n_lags, |
2636 | 2648 | max_lags=self.max_lags, |
|
0 commit comments