Skip to content

Commit a0e8a80

Browse files
committed
[Test error fix]: Fixed the error caused by flake8
1 parent bea1d3e commit a0e8a80

File tree

3 files changed

+102
-1
lines changed

3 files changed

+102
-1
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,3 +142,4 @@ tmp/
142142
# private
143143
grep.py
144144
memo.txt
145+
private_test.py

autoPyTorch/datasets/base_dataset.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -184,7 +184,8 @@ def __getitem__(self, index: int, train: bool = True) -> Tuple[np.ndarray, ...]:
184184

185185
return X, Y
186186

187-
def __len__(self) -> int: return self.train_tensors[0].shape[0]
187+
def __len__(self) -> int:
188+
return self.train_tensors[0].shape[0]
188189

189190
def _get_indices(self) -> np.ndarray:
190191
return self.rng.permutation(len(self)) if self.shuffle else np.arange(len(self))

private_test.ipynb

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "code",
5+
"execution_count": 4,
6+
"metadata": {},
7+
"outputs": [
8+
{
9+
"data": {
10+
"text/plain": [
11+
"[(array([7, 6, 2, 5, 9, 1, 0, 8, 3]), array([4])),\n",
12+
" (array([1, 2, 5, 4, 6, 9, 0, 8, 7]), array([3])),\n",
13+
" (array([7, 0, 9, 3, 4, 6, 5, 1, 8]), array([2]))]"
14+
]
15+
},
16+
"execution_count": 4,
17+
"metadata": {},
18+
"output_type": "execute_result"
19+
}
20+
],
21+
"source": [
22+
"from typing import Protocol, List, Tuple, Dict\n",
23+
"import numpy as np\n",
24+
"from sklearn.model_selection import ShuffleSplit\n",
25+
"\n",
26+
"\n",
27+
"\"\"\"TEST 1 (The original code)\"\"\"\n",
28+
"class CROSS_VAL_FN(Protocol):\n",
29+
" def __call__(self,\n",
30+
" num_splits: int,\n",
31+
" indices: np.ndarray) -> List[Tuple[np.ndarray, np.ndarray]]:\n",
32+
" ...\n",
33+
" \n",
34+
"\n",
35+
"def cv1(num_splits: int, indices: np.ndarray) -> List[Tuple[np.ndarray, np.\n",
36+
" ndarray]]: \n",
37+
" cv = ShuffleSplit(n_splits=num_splits) \n",
38+
" splits = list(cv.split(indices)) \n",
39+
" return splits\n",
40+
"\n",
41+
"\n",
42+
"cvs: Dict[str, CROSS_VAL_FN] = {\"cv1\": cv1}\n",
43+
"cvs[\"cv1\"](3, np.arange(10))"
44+
]
45+
},
46+
{
47+
"cell_type": "code",
48+
"execution_count": 6,
49+
"metadata": {},
50+
"outputs": [
51+
{
52+
"data": {
53+
"text/plain": [
54+
"[(array([6, 0, 3, 7, 5, 9, 4, 1, 8]), array([2])),\n",
55+
" (array([3, 9, 0, 4, 1, 7, 5, 2, 6]), array([8])),\n",
56+
" (array([3, 8, 1, 4, 9, 5, 2, 7, 6]), array([0]))]"
57+
]
58+
},
59+
"execution_count": 6,
60+
"metadata": {},
61+
"output_type": "execute_result"
62+
}
63+
],
64+
"source": []
65+
},
66+
{
67+
"cell_type": "code",
68+
"execution_count": null,
69+
"metadata": {},
70+
"outputs": [],
71+
"source": [
72+
"\"\"\"TEST 3 (expected usage of Protocol)\"\"\"\n",
73+
"def cv_func(cv: CROSS_VAL_FN, **kwargs):\n",
74+
" cv(**kwargs)"
75+
]
76+
}
77+
],
78+
"metadata": {
79+
"kernelspec": {
80+
"display_name": "Python 3",
81+
"language": "python",
82+
"name": "python3"
83+
},
84+
"language_info": {
85+
"codemirror_mode": {
86+
"name": "ipython",
87+
"version": 3
88+
},
89+
"file_extension": ".py",
90+
"mimetype": "text/x-python",
91+
"name": "python",
92+
"nbconvert_exporter": "python",
93+
"pygments_lexer": "ipython3",
94+
"version": "3.8.3"
95+
}
96+
},
97+
"nbformat": 4,
98+
"nbformat_minor": 4
99+
}

0 commit comments

Comments
 (0)