Skip to content

Commit

Permalink
[Test error fix]: Fixed the error caused by flake8
Browse files Browse the repository at this point in the history
  • Loading branch information
nabenabe0928 committed Feb 22, 2021
1 parent bea1d3e commit a0e8a80
Show file tree
Hide file tree
Showing 3 changed files with 102 additions and 1 deletion.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -142,3 +142,4 @@ tmp/
# private
grep.py
memo.txt
private_test.py
3 changes: 2 additions & 1 deletion autoPyTorch/datasets/base_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,8 @@ def __getitem__(self, index: int, train: bool = True) -> Tuple[np.ndarray, ...]:

return X, Y

def __len__(self) -> int: return self.train_tensors[0].shape[0]
def __len__(self) -> int:
return self.train_tensors[0].shape[0]

def _get_indices(self) -> np.ndarray:
return self.rng.permutation(len(self)) if self.shuffle else np.arange(len(self))
Expand Down
99 changes: 99 additions & 0 deletions private_test.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[(array([7, 6, 2, 5, 9, 1, 0, 8, 3]), array([4])),\n",
" (array([1, 2, 5, 4, 6, 9, 0, 8, 7]), array([3])),\n",
" (array([7, 0, 9, 3, 4, 6, 5, 1, 8]), array([2]))]"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from typing import Protocol, List, Tuple, Dict\n",
"import numpy as np\n",
"from sklearn.model_selection import ShuffleSplit\n",
"\n",
"\n",
"\"\"\"TEST 1 (The original code)\"\"\"\n",
"class CROSS_VAL_FN(Protocol):\n",
" def __call__(self,\n",
" num_splits: int,\n",
" indices: np.ndarray) -> List[Tuple[np.ndarray, np.ndarray]]:\n",
" ...\n",
" \n",
"\n",
"def cv1(num_splits: int, indices: np.ndarray) -> List[Tuple[np.ndarray, np.\n",
" ndarray]]: \n",
" cv = ShuffleSplit(n_splits=num_splits) \n",
" splits = list(cv.split(indices)) \n",
" return splits\n",
"\n",
"\n",
"cvs: Dict[str, CROSS_VAL_FN] = {\"cv1\": cv1}\n",
"cvs[\"cv1\"](3, np.arange(10))"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[(array([6, 0, 3, 7, 5, 9, 4, 1, 8]), array([2])),\n",
" (array([3, 9, 0, 4, 1, 7, 5, 2, 6]), array([8])),\n",
" (array([3, 8, 1, 4, 9, 5, 2, 7, 6]), array([0]))]"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"\"\"\"TEST 3 (expected usage of Protocol)\"\"\"\n",
"def cv_func(cv: CROSS_VAL_FN, **kwargs):\n",
" cv(**kwargs)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.3"
}
},
"nbformat": 4,
"nbformat_minor": 4
}

0 comments on commit a0e8a80

Please sign in to comment.