1717import time
1818from typing import TYPE_CHECKING
1919from typing import Dict
20+ from typing import Optional
2021from typing import Tuple
2122from typing import Union
2223
@@ -59,7 +60,7 @@ def _get_dataset_length(
5960
6061
6162def _eval_by_dataset (
62- solver : "solver.Solver" , epoch_id : int , log_freq : int
63+ solver : "solver.Solver" , epoch_id : Optional [ int ] , log_freq : int
6364) -> Tuple [float , Dict [str , Dict [str , float ]]]:
6465 """Evaluate with computing metric on total samples(default process).
6566
@@ -68,7 +69,7 @@ def _eval_by_dataset(
6869
6970 Args:
7071 solver (solver.Solver): Main Solver.
71- epoch_id (int): Epoch id.
72+ epoch_id (Optional[ int] ): Epoch id.
7273 log_freq (int): Log evaluation information every `log_freq` steps.
7374
7475 Returns:
@@ -189,7 +190,7 @@ def _eval_by_dataset(
189190
190191
191192def _eval_by_batch (
192- solver : "solver.Solver" , epoch_id : int , log_freq : int
193+ solver : "solver.Solver" , epoch_id : Optional [ int ] , log_freq : int
193194) -> Tuple [float , Dict [str , Dict [str , float ]]]:
194195 """Evaluate with computing metric by batch, which is memory-efficient.
195196
@@ -199,7 +200,7 @@ def _eval_by_batch(
199200
200201 Args:
201202 solver (solver.Solver): Main Solver.
202- epoch_id (int): Epoch id.
203+ epoch_id (Optional[ int] ): Epoch id.
203204 log_freq (int): Log evaluation information every `log_freq` steps.
204205
205206 Returns:
@@ -303,13 +304,13 @@ def _eval_by_batch(
303304
304305
305306def eval_func (
306- solver : "solver.Solver" , epoch_id : int , log_freq : int
307+ solver : "solver.Solver" , epoch_id : Optional [ int ] , log_freq : int
307308) -> Tuple [float , Dict [str , Dict [str , float ]]]:
308309 """Evaluation function.
309310
310311 Args:
311312 solver (solver.Solver): Main Solver.
312- epoch_id (int): Epoch id.
313+ epoch_id (Optional[ int] ): Epoch id.
313314 log_freq (int): Log evaluation information every `log_freq` steps.
314315
315316 Returns:
0 commit comments