Skip to content

Commit

Permalink
Add __getitem__ method to select train/test indices
Browse files Browse the repository at this point in the history
  • Loading branch information
BirkhoffG committed Nov 8, 2023
1 parent dca3e3e commit 5134914
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 1 deletion.
14 changes: 14 additions & 0 deletions nbs/03_explain.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,20 @@
" return f\"Explanation(data_name={self.data_name}, cf_name={self.cf_name}, \" \\\n",
" f\"total_time={self.total_time}, xs={self.xs}, ys={self.ys}, cfs={self.cfs})\"\n",
"\n",
" def __getitem__(self, name: Literal['train', 'val', 'test']) -> Dict[str, Array]:\n",
" if name == 'train':\n",
" indices = self.train_indices\n",
" elif name in ['val', 'test']:\n",
" indices = self.test_indices\n",
" else:\n",
" raise ValueError(f\"Unknown data name: {name}. Should be one of ['train', 'val', 'test']\")\n",
"\n",
" return {\n",
" 'xs': self.xs[indices],\n",
" 'ys': self.ys[indices],\n",
" 'cfs': self.cfs[indices],\n",
" }\n",
" \n",
" @property\n",
" def data(self):\n",
" return self._data\n",
Expand Down
1 change: 1 addition & 0 deletions relax/_modidx.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,6 +304,7 @@
'relax.evaluate.l2_ann': ('evaluate.html#l2_ann', 'relax/evaluate.py'),
'relax.evaluate.pairwise_distances': ('evaluate.html#pairwise_distances', 'relax/evaluate.py')},
'relax.explain': { 'relax.explain.Explanation': ('explain.html#explanation', 'relax/explain.py'),
'relax.explain.Explanation.__getitem__': ('explain.html#explanation.__getitem__', 'relax/explain.py'),
'relax.explain.Explanation.__init__': ('explain.html#explanation.__init__', 'relax/explain.py'),
'relax.explain.Explanation.__repr__': ('explain.html#explanation.__repr__', 'relax/explain.py'),
'relax.explain.Explanation.apply_constraints': ( 'explain.html#explanation.apply_constraints',
Expand Down
14 changes: 14 additions & 0 deletions relax/explain.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,20 @@ def __repr__(self):
return f"Explanation(data_name={self.data_name}, cf_name={self.cf_name}, " \
f"total_time={self.total_time}, xs={self.xs}, ys={self.ys}, cfs={self.cfs})"

def __getitem__(self, name: Literal['train', 'val', 'test']) -> Dict[str, Array]:
if name == 'train':
indices = self.train_indices
elif name in ['val', 'test']:
indices = self.test_indices
else:
raise ValueError(f"Unknown data name: {name}. Should be one of ['train', 'val', 'test']")

return {
'xs': self.xs[indices],
'ys': self.ys[indices],
'cfs': self.cfs[indices],
}

@property
def data(self):
return self._data
Expand Down
2 changes: 1 addition & 1 deletion relax/import_essentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# https://github.com/fastai/fastai/blob/master/fastai/imports.py
from __future__ import annotations
import matplotlib.pyplot as plt,numpy as np,pandas as pd,scipy
from typing import Union,Optional,Dict,List,Tuple,Sequence,Mapping,Callable,Iterable,Any,NamedTuple
from typing import Union,Optional,Dict,List,Tuple,Sequence,Mapping,Callable,Iterable,Any,NamedTuple,Literal
import io,operator,sys,os,re,mimetypes,csv,itertools,json,shutil,glob,pickle,tarfile,collections
import hashlib,itertools,types,inspect,functools,time,math,bz2,typing,numbers,string
import multiprocessing,threading,urllib,tempfile,concurrent.futures,matplotlib,warnings,zipfile
Expand Down

0 comments on commit 5134914

Please sign in to comment.