6
6
from torch .optim import Optimizer
7
7
8
8
import ignite .distributed as idist
9
+ from ignite import __version__
9
10
from ignite .contrib .handlers .base_logger import (
10
11
BaseLogger ,
11
12
BaseOptimizerParamsHandler ,
26
27
"global_step_from_engine" ,
27
28
]
28
29
30
+ _INTEGRATION_VERSION_KEY = "source_code/integrations/neptune-pytorch-ignite"
31
+
29
32
30
33
class NeptuneLogger (BaseLogger ):
31
34
"""
@@ -42,24 +45,7 @@ class NeptuneLogger(BaseLogger):
42
45
"namespace/project_name" for example "tom/minst-classification".
43
46
If None, the value of NEPTUNE_PROJECT environment variable will be taken.
44
47
You need to create the project in https://neptune.ai first.
45
- offline_mode: Optional default False. If offline_mode=True no logs will be send to neptune.
46
- Usually used for debug purposes.
47
- experiment_name: Optional. Editable name of the experiment.
48
- Name is displayed in the experiment’s Details (Metadata section) and in experiments view as a column.
49
- upload_source_files: Optional. List of source files to be uploaded.
50
- Must be list of str or single str. Uploaded sources are displayed in the experiment’s Source code tab.
51
- If None is passed, Python file from which experiment was created will be uploaded.
52
- Pass empty list (`[]`) to upload no files. Unix style pathname pattern expansion is supported.
53
- For example, you can pass `*.py` to upload all python source files from the current directory.
54
- For recursion lookup use `**/*.py` (for Python 3.5 and later). For more information see glob library.
55
- params: Optional. Parameters of the experiment. After experiment creation params are read-only.
56
- Parameters are displayed in the experiment’s Parameters section and each key-value pair can be
57
- viewed in experiments view as a column.
58
- properties: Optional default is `{}`. Properties of the experiment.
59
- They are editable after experiment is created. Properties are displayed in the experiment’s Details and
60
- each key-value pair can be viewed in experiments view as a column.
61
- tags: Optional default `[]`. Must be list of str. Tags of the experiment.
62
- Tags are displayed in the experiment’s Details and can be viewed in experiments view as a column.
48
+ **kwargs: Other arguments to be passed to Neptune's `init_run`.
63
49
64
50
Examples:
65
51
.. code-block:: python
@@ -71,8 +57,8 @@ class NeptuneLogger(BaseLogger):
71
57
72
58
npt_logger = NeptuneLogger(
73
59
api_token="ANONYMOUS",
74
- project_name ="shared/pytorch-ignite-integration",
75
- experiment_name ="cnn-mnist", # Optional,
60
+ project ="shared/pytorch-ignite-integration",
61
+ name ="cnn-mnist", # Optional,
76
62
params={"max_epochs": 10}, # Optional,
77
63
tags=["pytorch-ignite","minst"] # Optional
78
64
)
@@ -153,8 +139,8 @@ def score_function(engine):
153
139
# We are using the api_token for the anonymous user neptuner but you can use your own.
154
140
155
141
with NeptuneLogger(api_token="ANONYMOUS",
156
- project_name ="shared/pytorch-ignite-integration",
157
- experiment_name ="cnn-mnist", # Optional,
142
+ project ="shared/pytorch-ignite-integration",
143
+ name ="cnn-mnist", # Optional,
158
144
params={"max_epochs": 10}, # Optional,
159
145
tags=["pytorch-ignite","mnist"] # Optional
160
146
) as npt_logger:
@@ -171,39 +157,44 @@ def score_function(engine):
171
157
"""
172
158
173
159
def __getattr__ (self , attr : Any ) -> Any :
160
+ return getattr (self .experiment , attr )
174
161
175
- import neptune
162
+ def __getitem__ (self , key : str ) -> Any :
163
+ return self .experiment [key ]
176
164
177
- return getattr (neptune , attr )
165
+ def __setitem__ (self , key : str , val : Any ) -> Any :
166
+ self .experiment [key ] = val
167
+
168
+ def __init__ (self , api_token : Optional [str ] = None , project : Optional [str ] = None , ** kwargs : Any ) -> None :
169
+ import warnings
178
170
179
- def __init__ (self , * args : Any , ** kwargs : Any ) -> None :
180
171
try :
181
- import neptune
172
+ try :
173
+ # neptune-client<1.0.0 package structure
174
+ with warnings .catch_warnings ():
175
+ # ignore the deprecation warnings
176
+ warnings .simplefilter ("ignore" )
177
+ import neptune .new as neptune
178
+ except ImportError :
179
+ # neptune>=1.0.0 package structure
180
+ import neptune
182
181
except ImportError :
183
182
raise ModuleNotFoundError (
184
- "This contrib module requires neptune-client to be installed. "
185
- "You may install neptune with command: \n pip install neptune-client \n "
186
- )
187
-
188
- if kwargs .get ("offline_mode" , False ):
189
- self .mode = "offline"
190
- neptune .init (
191
- project_qualified_name = "dry-run/project" ,
192
- backend = neptune .OfflineBackend (),
183
+ "This contrib module requires neptune client to be installed. "
184
+ "You may install neptune with command: \n pip install neptune \n "
193
185
)
194
- else :
195
- self .mode = "online"
196
- neptune .init (api_token = kwargs .get ("api_token" ), project_qualified_name = kwargs .get ("project_name" ))
197
186
198
- kwargs ["name" ] = kwargs .pop ("experiment_name" , None )
199
- self ._experiment_kwargs = {
200
- k : v for k , v in kwargs .items () if k not in ["api_token" , "project_name" , "offline_mode" ]
201
- }
187
+ run = neptune .init_run (
188
+ api_token = api_token ,
189
+ project = project ,
190
+ ** kwargs ,
191
+ )
192
+ run [_INTEGRATION_VERSION_KEY ] = __version__
202
193
203
- self .experiment = neptune . create_experiment ( ** self . _experiment_kwargs )
194
+ self .experiment = run
204
195
205
196
def close (self ) -> None :
206
- self .stop ()
197
+ self .experiment . stop ()
207
198
208
199
def _create_output_handler (self , * args : Any , ** kwargs : Any ) -> "OutputHandler" :
209
200
return OutputHandler (* args , ** kwargs )
@@ -355,7 +346,7 @@ def __call__(self, engine: Engine, logger: NeptuneLogger, event_name: Union[str,
355
346
)
356
347
357
348
for key , value in metrics .items ():
358
- logger . log_metric ( key , x = global_step , y = value )
349
+ logger [ key ]. append ( value , step = global_step )
359
350
360
351
361
352
class OptimizerParamsHandler (BaseOptimizerParamsHandler ):
@@ -412,7 +403,7 @@ def __call__(self, engine: Engine, logger: NeptuneLogger, event_name: Union[str,
412
403
}
413
404
414
405
for k , v in params .items ():
415
- logger . log_metric ( k , x = global_step , y = v )
406
+ logger [ k ]. append ( v , step = global_step )
416
407
417
408
418
409
class WeightsScalarHandler (BaseWeightsScalarHandler ):
@@ -515,11 +506,8 @@ def __call__(self, engine: Engine, logger: NeptuneLogger, event_name: Union[str,
515
506
continue
516
507
517
508
name = name .replace ("." , "/" )
518
- logger .log_metric (
519
- f"{ tag_prefix } weights_{ self .reduction .__name__ } /{ name } " ,
520
- x = global_step ,
521
- y = self .reduction (p .data ),
522
- )
509
+ key = f"{ tag_prefix } weights_{ self .reduction .__name__ } /{ name } "
510
+ logger [key ].append (self .reduction (p .data ), step = global_step )
523
511
524
512
525
513
class GradsScalarHandler (BaseWeightsScalarHandler ):
@@ -622,9 +610,8 @@ def __call__(self, engine: Engine, logger: NeptuneLogger, event_name: Union[str,
622
610
continue
623
611
624
612
name = name .replace ("." , "/" )
625
- logger .log_metric (
626
- f"{ tag_prefix } grads_{ self .reduction .__name__ } /{ name } " , x = global_step , y = self .reduction (p .grad )
627
- )
613
+ key = f"{ tag_prefix } grads_{ self .reduction .__name__ } /{ name } "
614
+ logger [key ].append (self .reduction (p .grad ), step = global_step )
628
615
629
616
630
617
class NeptuneSaver (BaseSaveHandler ):
@@ -693,8 +680,8 @@ def __call__(self, checkpoint: Mapping, filename: str, metadata: Optional[Mappin
693
680
# we can not use tmp.name to open tmp.file twice on Win32
694
681
# https://docs.python.org/3/library/tempfile.html#tempfile.NamedTemporaryFile
695
682
torch .save (checkpoint , tmp .file )
696
- self ._logger . log_artifact (tmp .name , filename )
683
+ self ._logger [ filename ]. upload (tmp .name )
697
684
698
685
@idist .one_rank_only (with_barrier = True )
699
686
def remove (self , filename : str ) -> None :
700
- self ._logger .delete_artifacts (filename )
687
+ self ._logger .delete_files (filename )
0 commit comments