Skip to content

Commit

Permalink
Merge pull request #26 from KarelZe/linting-formatter
Browse files Browse the repository at this point in the history
Add basic support for linting / format🐍
  • Loading branch information
iancovert authored Feb 19, 2024
2 parents 049e35e + 1dd7e1d commit 2c98a28
Show file tree
Hide file tree
Showing 13 changed files with 817 additions and 594 deletions.
7 changes: 7 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,10 @@ repos:
- id: debug-statements
- id: end-of-file-fixer
- id: mixed-line-ending
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.2.1
hooks:
- id: ruff
args:
- --fix
- id: ruff-format
32 changes: 32 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ dependencies = [
dev=[
"build",
"pre-commit",
"ruff"
]

notebook=[
Expand All @@ -54,3 +55,34 @@ build-backend = "setuptools.build_meta"

[tool.setuptools]
packages = ["sage"]

[tool.ruff]


include = ["*.py", "*.pyi", "**/pyproject.toml", "*.ipynb"]


[tool.ruff.lint]

# See rules: https://beta.ruff.rs/docs/rules/
select = [
"I", # isort
"N", # pep8-naming
"NPY", # numpy
"RUF", # ruff-specific rules
]

ignore = [
"N803", # argument name should be lowercase; fine for matrices
"N806", # variable name should be lowercase; fine for matrices
"NPY002", # allow calls to np.random; could cause slightly different results
]

preview = true

[tool.ruff.format]
preview = true

[tool.ruff.lint.isort]
known-first-party = ["sage"]
section-order = ["future", "standard-library", "third-party", "first-party", "local-folder"]
9 changes: 5 additions & 4 deletions sage/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from sage import utils, core, imputers, grouped_imputers, plotting, datasets
from sage import core, datasets, grouped_imputers, imputers, plotting, utils

from .core import Explanation, load
from .plotting import plot, comparison_plot
from .imputers import DefaultImputer, MarginalImputer
from .grouped_imputers import GroupedDefaultImputer, GroupedMarginalImputer
from .permutation_estimator import PermutationEstimator
from .imputers import DefaultImputer, MarginalImputer
from .iterated_estimator import IteratedEstimator
from .kernel_estimator import KernelEstimator
from .permutation_estimator import PermutationEstimator
from .plotting import comparison_plot, plot
from .sign_estimator import SignEstimator
210 changes: 130 additions & 80 deletions sage/core.py
Original file line number Diff line number Diff line change
@@ -1,40 +1,44 @@
import pickle

import numpy as np

from sage import plotting


class Explanation:
'''
"""
For storing and plotting Explanations.
Args:
values: explanation values.
std: standard deviation confidence intervals for explanation values.
explanation_type: 'SAGE' or 'Shapley Effects' (used only for plotting).
'''
"""

def __init__(self, values, std, explanation_type='SAGE'):
def __init__(self, values, std, explanation_type="SAGE"):
self.values = values
self.std = std
self.explanation_type = explanation_type

def plot(self,
feature_names=None,
sort_features=True,
max_features=np.inf,
orientation='horizontal',
error_bars=True,
confidence_level=0.95,
capsize=5,
color='tab:green',
title='Feature Importance',
title_size=20,
tick_size=16,
tick_rotation=None,
label_size=16,
figsize=(10, 7),
return_fig=False):
'''
def plot(
self,
feature_names=None,
sort_features=True,
max_features=np.inf,
orientation="horizontal",
error_bars=True,
confidence_level=0.95,
capsize=5,
color="tab:green",
title="Feature Importance",
title_size=20,
tick_size=16,
tick_rotation=None,
label_size=16,
figsize=(10, 7),
return_fig=False,
):
"""
Plot SAGE values.
Args:
Expand All @@ -53,32 +57,48 @@ def plot(self,
label_size: font size for label.
figsize: figure size (if fig is None).
return_fig: whether to return matplotlib figure object.
'''
"""
return plotting.plot(
self, feature_names, sort_features, max_features, orientation,
error_bars, confidence_level, capsize, color, title, title_size,
tick_size, tick_rotation, label_size, figsize, return_fig)

def comparison(self,
other_values,
comparison_names=None,
feature_names=None,
sort_features=True,
max_features=np.inf,
orientation='vertical',
error_bars=True,
confidence_level=0.95,
capsize=5,
colors=None,
title='Feature Importance Comparison',
title_size=20,
tick_size=16,
tick_rotation=None,
label_size=16,
legend_loc=None,
figsize=(10, 7),
return_fig=False):
'''
self,
feature_names,
sort_features,
max_features,
orientation,
error_bars,
confidence_level,
capsize,
color,
title,
title_size,
tick_size,
tick_rotation,
label_size,
figsize,
return_fig,
)

def comparison(
self,
other_values,
comparison_names=None,
feature_names=None,
sort_features=True,
max_features=np.inf,
orientation="vertical",
error_bars=True,
confidence_level=0.95,
capsize=5,
colors=None,
title="Feature Importance Comparison",
title_size=20,
tick_size=16,
tick_rotation=None,
label_size=16,
legend_loc=None,
figsize=(10, 7),
return_fig=False,
):
"""
Plot comparison with another set of SAGE values.
Args:
Expand All @@ -100,28 +120,45 @@ def comparison(self,
legend_loc: legend location.
figsize: figure size (if fig is None).
return_fig: whether to return matplotlib figure object.
'''
"""
return plotting.comparison_plot(
(self, other_values), comparison_names, feature_names,
sort_features, max_features, orientation, error_bars,
confidence_level, capsize, colors, title, title_size, tick_size,
tick_rotation, label_size, legend_loc, figsize, return_fig)

def plot_sign(self,
feature_names,
sort_features=True,
max_features=np.inf,
orientation='horizontal',
confidence_level=0.95,
capsize=5,
title='Feature Importance Sign',
title_size=20,
tick_size=16,
tick_rotation=None,
label_size=16,
figsize=(10, 7),
return_fig=False):
'''
(self, other_values),
comparison_names,
feature_names,
sort_features,
max_features,
orientation,
error_bars,
confidence_level,
capsize,
colors,
title,
title_size,
tick_size,
tick_rotation,
label_size,
legend_loc,
figsize,
return_fig,
)

def plot_sign(
self,
feature_names,
sort_features=True,
max_features=np.inf,
orientation="horizontal",
confidence_level=0.95,
capsize=5,
title="Feature Importance Sign",
title_size=20,
tick_size=16,
tick_rotation=None,
label_size=16,
figsize=(10, 7),
return_fig=False,
):
"""
Plot SAGE values, focusing on their sign.
Args:
Expand All @@ -138,31 +175,44 @@ def plot_sign(self,
label_size: font size for label.
figsize: figure size (if fig is None).
return_fig: whether to return matplotlib figure object.
'''
"""
return plotting.plot_sign(
self, feature_names, sort_features, max_features, orientation,
confidence_level, capsize, title, title_size, tick_size,
tick_rotation, label_size, figsize, return_fig)
self,
feature_names,
sort_features,
max_features,
orientation,
confidence_level,
capsize,
title,
title_size,
tick_size,
tick_rotation,
label_size,
figsize,
return_fig,
)

def save(self, filename):
'''Save Explanation object.'''
"""Save Explanation object."""
if isinstance(filename, str):
with open(filename, 'wb') as f:
with open(filename, "wb") as f:
pickle.dump(self, f)
else:
raise TypeError('filename must be str')
raise TypeError("filename must be str")

def __repr__(self):
with np.printoptions(precision=2, threshold=12, floatmode='fixed'):
return '{} Explanation(\n (Mean): {}\n (Std): {}\n)'.format(
self.explanation_type, self.values, self.std)
with np.printoptions(precision=2, threshold=12, floatmode="fixed"):
return "{} Explanation(\n (Mean): {}\n (Std): {}\n)".format(
self.explanation_type, self.values, self.std
)


def load(filename):
'''Load Explanation object.'''
with open(filename, 'rb') as f:
"""Load Explanation object."""
with open(filename, "rb") as f:
sage_values = pickle.load(f)
if isinstance(sage_values, Explanation):
return sage_values
else:
raise ValueError('object is not instance of Explanation class')
raise ValueError("object is not instance of Explanation class")
Loading

0 comments on commit 2c98a28

Please sign in to comment.