Skip to content

Commit

Permalink
Fix electricity forecasting tutorial (Lightning-AI#127)
Browse files Browse the repository at this point in the history
  • Loading branch information
ethanwharris authored Dec 16, 2021
1 parent 23e6a0b commit c9b61fa
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 7 deletions.
1 change: 0 additions & 1 deletion docs/source/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ Lightning-Sandbox documentation
:caption: Start here
:glob:

notebooks/*
notebooks/**/*

.. raw:: html
Expand Down
5 changes: 3 additions & 2 deletions flash_tutorials/electricity_forecasting/.meta.yml
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
title: Electricity Price Forecasting with N-BEATS
author: Ethan Harris (ethan@pytorchlightning.ai)
created: 2021-11-23
updated: 2021-11-23
updated: 2021-12-16
license: CC BY-SA
build: 3
tags:
- Tabular
- Forecasting
- Timeseries
description: |
This tutorial covers using Lightning Flash and it's integration with PyTorch Forecasting to train an autoregressive
model (N-BEATS) on hourly electricity pricing data. We show how the built-in interpretability tools from PyTorch
Expand All @@ -15,7 +16,7 @@ description: |
bonus, we show hat we can resample daily observations from the data to discover weekly trends instead.
requirements:
- pandas==1.1.5
- lightning-flash[tabular]>=0.5.2
- lightning-flash[tabular]>=0.6.0
accelerator:
- GPU
- CPU
15 changes: 11 additions & 4 deletions flash_tutorials/electricity_forecasting/electricity_forecasting.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# %%

import os
from typing import Any, Dict

import flash
import matplotlib.pyplot as plt
Expand Down Expand Up @@ -196,9 +197,15 @@ def preprocess(df: pd.DataFrame, frequency: str = "1H") -> pd.DataFrame:
# %%


def plot_interpretation(model_path: str, predict_df: pd.DataFrame):
def plot_interpretation(model_path: str, predict_df: pd.DataFrame, parameters: Dict[str, Any]):
model = TabularForecaster.load_from_checkpoint(model_path)
predictions = model.predict(predict_df)
datamodule = TabularForecastingData.from_data_frame(
parameters=parameters,
predict_data_frame=predict_df,
batch_size=256,
)
trainer = flash.Trainer(gpus=int(torch.cuda.is_available()))
predictions = trainer.predict(model, datamodule=datamodule)
predictions, inputs = convert_predictions(predictions)
model.pytorch_forecasting_model.plot_interpretation(inputs, predictions, idx=0)
plt.show()
Expand All @@ -208,7 +215,7 @@ def plot_interpretation(model_path: str, predict_df: pd.DataFrame):
# And now we run the function to plot the trend and seasonality curves:

# %%
plot_interpretation(trainer.checkpoint_callback.best_model_path, df_energy_hourly)
plot_interpretation(trainer.checkpoint_callback.best_model_path, df_energy_hourly, datamodule.parameters)

# %% [markdown]
# It worked! The plot shows that the `TabularForecaster` does a reasonable job of modelling the time series and also
Expand Down Expand Up @@ -281,7 +288,7 @@ def plot_interpretation(model_path: str, predict_df: pd.DataFrame):
# Now let's look at what it learned:

# %%
plot_interpretation(trainer.checkpoint_callback.best_model_path, df_energy_daily)
plot_interpretation(trainer.checkpoint_callback.best_model_path, df_energy_daily, datamodule.parameters)

# %% [markdown]
# Success! We can now also see weekly trends / seasonality uncovered by our new model.
Expand Down

0 comments on commit c9b61fa

Please sign in to comment.