14
14
import itertools
15
15
import os
16
16
import shutil
17
- from typing import (Any , Dict , Generator , Iterable , List , Literal , Optional ,
18
- Type , Union )
17
+ from typing import Any , Dict , Iterable , Iterator , List , Literal , Type , Union
19
18
20
19
from typing_extensions import deprecated
21
20
22
21
from ..base import Meta , Traceable , TraceableOnDisk
22
+ from ..data import T
23
23
from ..version import __version__
24
24
from .model import Model
25
25
from .model_line import ModelLine
@@ -55,7 +55,7 @@ def __len__(self) -> int:
55
55
"""
56
56
return len (self ._lines )
57
57
58
- def __iter__ (self ) -> Generator [ ModelLine , None , None ]:
58
+ def __iter__ (self ) -> Iterator [ T ]:
59
59
for line in self ._lines :
60
60
yield self .__getitem__ (line )
61
61
@@ -81,7 +81,9 @@ def reload(self) -> None:
81
81
pass
82
82
83
83
84
- @deprecated ("cascade.models.SingleLineRepo is deprecated, consider using cascade.repos.SingleLineRepo instead" )
84
+ @deprecated (
85
+ "cascade.models.SingleLineRepo is deprecated, consider using cascade.repos.SingleLineRepo instead"
86
+ )
85
87
class SingleLineRepo (Repo ):
86
88
def __init__ (
87
89
self ,
@@ -99,7 +101,9 @@ def __getitem__(self, key: str) -> ModelLine:
99
101
if key in self ._lines :
100
102
return self ._line
101
103
else :
102
- raise KeyError (f"The only line is { list (self ._lines .keys ())[0 ]} , { key } does not exist" )
104
+ raise KeyError (
105
+ f"The only line is { list (self ._lines .keys ())[0 ]} , { key } does not exist"
106
+ )
103
107
104
108
def __repr__ (self ) -> str :
105
109
return f"SingleLine in { self ._root } "
@@ -111,7 +115,9 @@ def reload(self) -> None:
111
115
self ._line .reload ()
112
116
113
117
114
- @deprecated ("cascade.models.ModelRepo is deprecated, consider using cascade.repos.Repo instead" )
118
+ @deprecated (
119
+ "cascade.models.ModelRepo is deprecated, consider using cascade.repos.Repo instead"
120
+ )
115
121
class ModelRepo (Repo , TraceableOnDisk ):
116
122
"""
117
123
An interface to manage experiments with several lines of models.
@@ -306,11 +312,16 @@ def load_model_meta(self, model: str) -> Meta:
306
312
continue
307
313
else :
308
314
return meta
309
- raise FileNotFoundError (f"Failed to find the model { model } in the repo at { self ._root } " )
315
+ raise FileNotFoundError (
316
+ f"Failed to find the model { model } in the repo at { self ._root } "
317
+ )
310
318
311
319
def _update_lines (self ) -> None :
312
320
for name in sorted (os .listdir (self ._root )):
313
- if os .path .isdir (os .path .join (self ._root , name )) and name not in self ._lines :
321
+ if (
322
+ os .path .isdir (os .path .join (self ._root , name ))
323
+ and name not in self ._lines
324
+ ):
314
325
self ._lines [name ] = {"args" : [], "kwargs" : dict ()}
315
326
316
327
@@ -319,7 +330,7 @@ def _update_lines(self) -> None:
319
330
" 0.14.0 and will be removed by 0.15.0"
320
331
" Use Workspaces instead" ,
321
332
category = DeprecationWarning ,
322
- stacklevel = 1
333
+ stacklevel = 1 ,
323
334
)
324
335
class ModelRepoConcatenator (Repo ):
325
336
"""
@@ -350,7 +361,7 @@ def __getitem__(self, key) -> ModelLine:
350
361
def __len__ (self ) -> int :
351
362
return sum ([len (repo ) for repo in self ._repos ])
352
363
353
- def __iter__ (self ) -> Generator [ ModelLine , None , None ]:
364
+ def __iter__ (self ) -> Iterator [ T ]:
354
365
# this flattens the list of lines
355
366
for line in itertools .chain (* [[line for line in repo ] for repo in self ._repos ]):
356
367
yield line
0 commit comments