Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Shap fix #762

Merged
merged 12 commits into from
Jun 24, 2024
2 changes: 1 addition & 1 deletion flood_forecast/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def get_model_r2_score(
):
"""

model_evaluate_function should call any necessary preprocessing
model_evaluate_function should call any necessary preprocessing.
"""
test_river_data, baseline_mse = stream_baseline(river_flow_df, forecast_column)

Expand Down
15 changes: 7 additions & 8 deletions flood_forecast/explain_model_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import numpy as np
import shap
import torch

import wandb
from flood_forecast.plot_functions import (
plot_shap_value_heatmaps,
Expand All @@ -27,7 +26,7 @@ def handle_dl_output(dl, dl_class: str, datetime_start: datetime, device: str) -
:type datetime_start: datetime
:param device: Typical device should be either cpu or cuda
:type device: str
:return: Returns a tuple containing either a..
:return: Returns a tuple containing either a list of tensors or a single tensor, and an integer
:rtype: Tuple[torch.Tensor, int]
"""
if dl_class == "TemporalLoader":
Expand Down Expand Up @@ -105,11 +104,11 @@ def deep_explain_model_summary_plot(
if isinstance(history, list):
model.model = model.model.to("cpu")
deep_explainer = shap.DeepExplainer(model.model, history)
shap_values = deep_explainer.shap_values(history)
shap_values = deep_explainer.shap_values(history, check_additivity=False)
s_values_list.append(shap_values)
else:
deep_explainer = shap.DeepExplainer(model.model, background_tensor)
shap_values = deep_explainer.shap_values(background_tensor)
shap_values = deep_explainer.shap_values(background_tensor, check_additivity=False)
shap_values = fix_shap_values(shap_values, history)
shap_values = np.stack(shap_values)
# shap_values needs to be 4-dimensional
Expand Down Expand Up @@ -147,7 +146,7 @@ def deep_explain_model_summary_plot(
hist.cpu().numpy(), names=["batches", "observations", "features"]
)

shap_values = deep_explainer.shap_values(history)
shap_values = deep_explainer.shap_values(history, check_additivity=False)
shap_values = fix_shap_values(shap_values, history)
shap_values = np.stack(shap_values)
if len(shap_values.shape) != 4:
Expand Down Expand Up @@ -216,11 +215,11 @@ def deep_explain_model_heatmap(
s_values_list = []
if isinstance(history, list):
deep_explainer = shap.DeepExplainer(model.model, history)
shap_values = deep_explainer.shap_values(history)
shap_values = deep_explainer.shap_values(history, check_additivity=False)
s_values_list.append(shap_values)
else:
deep_explainer = shap.DeepExplainer(model.model, background_tensor)
shap_values = deep_explainer.shap_values(background_tensor)
shap_values = deep_explainer.shap_values(background_tensor, check_additivity=False)
shap_values = fix_shap_values(shap_values, history)
shap_values = np.stack(shap_values) # forecast_len x N x L x M
if len(shap_values.shape) != 4:
Expand All @@ -236,7 +235,7 @@ def deep_explain_model_heatmap(
# heatmap one prediction sequence at datetime_start
# (seq_len*forecast_len) per fop feature
to_explain = history
shap_values = deep_explainer.shap_values(to_explain)
shap_values = deep_explainer.shap_values(to_explain, check_additivity=False)
shap_values = fix_shap_values(shap_values, history)
shap_values = np.stack(shap_values)
if len(shap_values.shape) != 4:
Expand Down
3 changes: 1 addition & 2 deletions flood_forecast/transformer_xl/transformer_bottleneck.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,14 +32,13 @@
import torch
import torch.nn as nn
import math
# from torch.distributions.normal import Normal
import copy
from torch.nn.parameter import Parameter
from typing import Dict
from flood_forecast.transformer_xl.lower_upper_config import activation_dict


def gelu(x):
def gelu(x: torch.Tensor):
return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))


Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ google-cloud-storage
plotly~=5.20.0
pytz>=2022.1
setuptools~=69.5.1
numpy>=1.21
numpy==1.26
requests
torchvision>=0.6.0
mpld3>=0.5
Expand Down