Skip to content

Commit

Permalink
notebooks
Browse files Browse the repository at this point in the history
  • Loading branch information
JHoelli committed Aug 8, 2023
1 parent 94b3258 commit e16c8cb
Show file tree
Hide file tree
Showing 8 changed files with 10,200 additions and 34,205 deletions.
Binary file not shown.
2 changes: 1 addition & 1 deletion TSInterpret/InterpretabilityModels/counterfactual/CF.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ def plot_in_one(
"""
if self.mode == "time":
item = item.reshape(item.shape[-1], item.shape[-2])
exp = exp.reshape(item.shape[-1], item.shape[-2])
exp = exp.reshape( exp.shape[-1], exp.shape[-2])
else:
item = item.reshape(item.shape[-2], item.shape[-1])
exp = exp.reshape(item.shape[-2], item.shape[-1])
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -422,8 +422,6 @@ def explain(
tr, _ = explanation
if tr is None:
print("Run Brute Force as Backup.")
import sys
sys.exit(1)
explanation = self.backup.explain(
x_test, num_features=num_features, to_maximize=to_maximize
)
Expand Down Expand Up @@ -483,8 +481,6 @@ def _get_explanation(self, x_test, to_maximize, num_features):

if not self.silent:
logging.info("Current probas: %s", probas)
print('probas', probas)
print('probas',np.argmax(probas))
if np.argmax(probas) == to_maximize:
current_best = np.max(probas)
if current_best > best_explanation_score:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import numpy as np
def random_hill_climb(problem, max_attempts=10, max_iters=np.inf, restarts=0,
init_state=None, curve=True, random_state=None):
init_state=None, curve=False, random_state=None):

# Set random seed
if isinstance(random_state, int) and random_state > 0:
Expand Down
8 changes: 5 additions & 3 deletions TSInterpret/InterpretabilityModels/counterfactual/COMTECF.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def explain(
([np.array], int): Tuple of Counterfactual and Label. Shape of CF : `mode = time` -> `(time, feat)` or `mode = time` -> `(feat, time)`
"""

org_shape=x.shape
if self.mode != "feat":
x = x.reshape(-1, x.shape[-1], x.shape[-2])
train_x, train_y = self.referenceset
Expand All @@ -107,7 +107,9 @@ def explain(
max_attempts=self.max_attemps,
maxiter=self.max_iter,
)
return opt.explain(x, to_maximize=target)
exp,label= opt.explain(x, to_maximize=target)
elif self.method == "brute":
opt = BruteForceSearch(self.predict, train_x, train_y, threads=1)
return opt.explain(x, to_maximize=target)
exp,label= opt.explain(x, to_maximize=target)
return exp.reshape(org_shape), label

37 changes: 28 additions & 9 deletions docs/Notebooks/Ates_sklearn.ipynb

Large diffs are not rendered by default.

Loading

0 comments on commit e16c8cb

Please sign in to comment.