|  | 
| 10 | 10 | import shutil | 
| 11 | 11 | import tempfile | 
| 12 | 12 | import unittest | 
|  | 13 | +from unittest.mock import patch | 
| 13 | 14 | 
 | 
| 14 | 15 | import torch | 
| 15 | 16 | 
 | 
| 16 | 17 | import torch.distributed as dist | 
| 17 | 18 | from torch import nn | 
| 18 | 19 | from torchsnapshot import Snapshot | 
| 19 | 20 | from torchsnapshot.snapshot import SNAPSHOT_METADATA_FNAME | 
|  | 21 | +from torchtnt.framework._test_utils import Batch | 
|  | 22 | +from torchtnt.framework.state import State | 
|  | 23 | +from torchtnt.framework.unit import TrainUnit | 
| 20 | 24 | from torchtnt.utils import get_global_rank, init_from_env | 
| 21 | 25 | 
 | 
| 22 | 26 | from torchtnt.utils.checkpoint import ( | 
|  | 
| 25 | 29 |     _retrieve_checkpoint_dirpaths, | 
| 26 | 30 |     _sort_by_metric_value, | 
| 27 | 31 |     _sort_by_recency, | 
|  | 32 | +    BestCheckpointConfig, | 
|  | 33 | +    CheckpointManager, | 
| 28 | 34 |     CheckpointPath, | 
| 29 | 35 |     get_best_checkpoint_path, | 
| 30 | 36 |     get_checkpoint_dirpaths, | 
| @@ -190,6 +196,349 @@ def test_pickling(self) -> None: | 
| 190 | 196 |             self.assertEqual(unpickled, ckpt) | 
| 191 | 197 | 
 | 
| 192 | 198 | 
 | 
|  | 199 | +class CheckpointManagerTest(unittest.TestCase): | 
|  | 200 | +    def test_create_checkpoint_manager(self) -> None: | 
|  | 201 | +        with tempfile.TemporaryDirectory() as temp_dir: | 
|  | 202 | +            paths = [ | 
|  | 203 | +                f"{temp_dir}/epoch_1_step_3", | 
|  | 204 | +                f"{temp_dir}/epoch_0_step_1", | 
|  | 205 | +                f"{temp_dir}/epoch_0_step_5_loss=-0.3", | 
|  | 206 | +                f"{temp_dir}/epoch_1_step_1", | 
|  | 207 | +                f"{temp_dir}/epoch_1_step_2_loss=0.5", | 
|  | 208 | +                f"{temp_dir}/epoch_2_step_5_loss=0.3", | 
|  | 209 | +                f"{temp_dir}/epoch_0_step_2_acc=0.7", | 
|  | 210 | +            ] | 
|  | 211 | +            for path in paths: | 
|  | 212 | +                os.mkdir(path) | 
|  | 213 | + | 
|  | 214 | +            # without last_n_checkpoints | 
|  | 215 | +            ckpt_manager = CheckpointManager(temp_dir) | 
|  | 216 | +            self.assertEqual(ckpt_manager._ckpt_paths, []) | 
|  | 217 | + | 
|  | 218 | +            # with last_n_checkpoints but without metric | 
|  | 219 | +            ckpt_manager = CheckpointManager(temp_dir, keep_last_n_checkpoints=2) | 
|  | 220 | +            self.assertEqual( | 
|  | 221 | +                [x.path for x in ckpt_manager._ckpt_paths], | 
|  | 222 | +                [ | 
|  | 223 | +                    f"{temp_dir}/epoch_0_step_1", | 
|  | 224 | +                    f"{temp_dir}/epoch_0_step_2_acc=0.7", | 
|  | 225 | +                    f"{temp_dir}/epoch_0_step_5_loss=-0.3", | 
|  | 226 | +                    f"{temp_dir}/epoch_1_step_1", | 
|  | 227 | +                    f"{temp_dir}/epoch_1_step_2_loss=0.5", | 
|  | 228 | +                    f"{temp_dir}/epoch_1_step_3", | 
|  | 229 | +                    f"{temp_dir}/epoch_2_step_5_loss=0.3", | 
|  | 230 | +                ], | 
|  | 231 | +            ) | 
|  | 232 | + | 
|  | 233 | +            # with last_n_checkpoints and metric min | 
|  | 234 | +            ckpt_manager = CheckpointManager( | 
|  | 235 | +                temp_dir, | 
|  | 236 | +                keep_last_n_checkpoints=3, | 
|  | 237 | +                best_checkpoint_config=BestCheckpointConfig( | 
|  | 238 | +                    monitored_metric="loss", mode="min" | 
|  | 239 | +                ), | 
|  | 240 | +            ) | 
|  | 241 | +            self.assertEqual( | 
|  | 242 | +                [x.path for x in ckpt_manager._ckpt_paths], | 
|  | 243 | +                [ | 
|  | 244 | +                    f"{temp_dir}/epoch_1_step_2_loss=0.5", | 
|  | 245 | +                    f"{temp_dir}/epoch_2_step_5_loss=0.3", | 
|  | 246 | +                    f"{temp_dir}/epoch_0_step_5_loss=-0.3", | 
|  | 247 | +                ], | 
|  | 248 | +            ) | 
|  | 249 | + | 
|  | 250 | +            # with last_n_checkpoints and metric max | 
|  | 251 | +            ckpt_manager = CheckpointManager( | 
|  | 252 | +                temp_dir, | 
|  | 253 | +                keep_last_n_checkpoints=3, | 
|  | 254 | +                best_checkpoint_config=BestCheckpointConfig( | 
|  | 255 | +                    monitored_metric="loss", mode="max" | 
|  | 256 | +                ), | 
|  | 257 | +            ) | 
|  | 258 | +            self.assertEqual( | 
|  | 259 | +                [x.path for x in ckpt_manager._ckpt_paths], | 
|  | 260 | +                [ | 
|  | 261 | +                    f"{temp_dir}/epoch_0_step_5_loss=-0.3", | 
|  | 262 | +                    f"{temp_dir}/epoch_2_step_5_loss=0.3", | 
|  | 263 | +                    f"{temp_dir}/epoch_1_step_2_loss=0.5", | 
|  | 264 | +                ], | 
|  | 265 | +            ) | 
|  | 266 | + | 
|  | 267 | +            # with last_n_checkpoints and non previously tracked metric | 
|  | 268 | +            ckpt_manager = CheckpointManager( | 
|  | 269 | +                temp_dir, | 
|  | 270 | +                keep_last_n_checkpoints=3, | 
|  | 271 | +                best_checkpoint_config=BestCheckpointConfig( | 
|  | 272 | +                    monitored_metric="foo", mode="max" | 
|  | 273 | +                ), | 
|  | 274 | +            ) | 
|  | 275 | +            self.assertEqual(ckpt_manager._ckpt_paths, []) | 
|  | 276 | + | 
|  | 277 | +    @skip_if_not_distributed | 
|  | 278 | +    def test_create_checkpoint_manager_distributed(self) -> None: | 
|  | 279 | +        spawn_multi_process( | 
|  | 280 | +            2, | 
|  | 281 | +            "gloo", | 
|  | 282 | +            self._test_create_checkpoint_manager_distributed, | 
|  | 283 | +        ) | 
|  | 284 | + | 
|  | 285 | +    @staticmethod | 
|  | 286 | +    def _test_create_checkpoint_manager_distributed() -> None: | 
|  | 287 | +        if get_global_rank() == 0: | 
|  | 288 | +            temp_dir = tempfile.mkdtemp() | 
|  | 289 | +            paths = ["epoch_1_step_2", "epoch_0_step_1", "epoch_1_step_1"] | 
|  | 290 | +            for path in paths: | 
|  | 291 | +                os.mkdir(os.path.join(temp_dir, path)) | 
|  | 292 | +        else: | 
|  | 293 | +            temp_dir = "" | 
|  | 294 | + | 
|  | 295 | +        tc = unittest.TestCase() | 
|  | 296 | + | 
|  | 297 | +        # without top k config | 
|  | 298 | +        ckpt_manager = CheckpointManager(temp_dir) | 
|  | 299 | +        tc.assertNotEqual(ckpt_manager.dirpath, "") | 
|  | 300 | +        tc.assertEqual(ckpt_manager._ckpt_paths, []) | 
|  | 301 | + | 
|  | 302 | +        # with top k config | 
|  | 303 | +        ckpt_manager = CheckpointManager(temp_dir, keep_last_n_checkpoints=1) | 
|  | 304 | +        tc.assertNotEqual(ckpt_manager.dirpath, "") | 
|  | 305 | +        tc.assertEqual( | 
|  | 306 | +            [str(x) for x in ckpt_manager._ckpt_paths], | 
|  | 307 | +            [ | 
|  | 308 | +                os.path.join(ckpt_manager.dirpath, path) | 
|  | 309 | +                for path in [ | 
|  | 310 | +                    "epoch_0_step_1", | 
|  | 311 | +                    "epoch_1_step_1", | 
|  | 312 | +                    "epoch_1_step_2", | 
|  | 313 | +                ] | 
|  | 314 | +            ], | 
|  | 315 | +        ) | 
|  | 316 | + | 
|  | 317 | +    def test_prune_surplus_checkpoints(self) -> None: | 
|  | 318 | +        # with checkpoints to delete | 
|  | 319 | +        with tempfile.TemporaryDirectory() as temp_dir: | 
|  | 320 | +            ckpt_manager = CheckpointManager(temp_dir, keep_last_n_checkpoints=1) | 
|  | 321 | +            paths = [ | 
|  | 322 | +                CheckpointPath(temp_dir, 0, 0), | 
|  | 323 | +                CheckpointPath(temp_dir, 0, 1), | 
|  | 324 | +                CheckpointPath(temp_dir, 1, 0), | 
|  | 325 | +            ] | 
|  | 326 | +            for path in paths: | 
|  | 327 | +                os.mkdir(path.path) | 
|  | 328 | + | 
|  | 329 | +            ckpt_manager._ckpt_paths = list(paths) | 
|  | 330 | +            warning_messages = [] | 
|  | 331 | +            expected_warning_msg = ( | 
|  | 332 | +                f"3 checkpoints found in {temp_dir}. ", | 
|  | 333 | +                f"Deleting {2} oldest ", | 
|  | 334 | +                "checkpoints to enforce ``keep_last_n_checkpoints`` argument.", | 
|  | 335 | +            ) | 
|  | 336 | +            with patch( | 
|  | 337 | +                f"{CheckpointManager.__module__}.logging.Logger.warning", | 
|  | 338 | +                warning_messages.append, | 
|  | 339 | +            ): | 
|  | 340 | +                ckpt_manager.prune_surplus_checkpoints() | 
|  | 341 | + | 
|  | 342 | +            self.assertEqual(warning_messages[0], expected_warning_msg) | 
|  | 343 | +            self.assertEqual(ckpt_manager._ckpt_paths, [paths[2]]) | 
|  | 344 | +            self.assertTrue(os.path.exists(paths[2].path)) | 
|  | 345 | +            self.assertFalse(os.path.exists(paths[0].path)) | 
|  | 346 | +            self.assertFalse(os.path.exists(paths[1].path)) | 
|  | 347 | + | 
|  | 348 | +        # without checkpoints to delete | 
|  | 349 | +        with tempfile.TemporaryDirectory() as temp_dir: | 
|  | 350 | +            ckpt_manager = CheckpointManager(temp_dir) | 
|  | 351 | +            paths = [ | 
|  | 352 | +                CheckpointPath(temp_dir, 0, 0), | 
|  | 353 | +                CheckpointPath(temp_dir, 0, 1), | 
|  | 354 | +                CheckpointPath(temp_dir, 1, 0), | 
|  | 355 | +            ] | 
|  | 356 | +            ckpt_manager._ckpt_paths = list(paths) | 
|  | 357 | +            ckpt_manager.prune_surplus_checkpoints() | 
|  | 358 | +            self.assertEqual(ckpt_manager._ckpt_paths, paths) | 
|  | 359 | + | 
|  | 360 | +    def test_generate_checkpoint_path(self) -> None: | 
|  | 361 | +        ckpt_manager = CheckpointManager("foo") | 
|  | 362 | + | 
|  | 363 | +        self.assertEqual( | 
|  | 364 | +            ckpt_manager.generate_checkpoint_path(1, 1).path, | 
|  | 365 | +            "foo/epoch_1_step_1", | 
|  | 366 | +        ) | 
|  | 367 | + | 
|  | 368 | +        self.assertEqual( | 
|  | 369 | +            ckpt_manager.generate_checkpoint_path(1, 3).path, | 
|  | 370 | +            "foo/epoch_1_step_3", | 
|  | 371 | +        ) | 
|  | 372 | + | 
|  | 373 | +        ckpt_manager._best_checkpoint_config = BestCheckpointConfig( | 
|  | 374 | +            monitored_metric="val_loss", mode="min" | 
|  | 375 | +        ) | 
|  | 376 | +        self.assertEqual( | 
|  | 377 | +            ckpt_manager.generate_checkpoint_path( | 
|  | 378 | +                1, 3, MetricData("val_loss", 0.5) | 
|  | 379 | +            ).path, | 
|  | 380 | +            "foo/epoch_1_step_3_val_loss=0.5", | 
|  | 381 | +        ) | 
|  | 382 | + | 
|  | 383 | +        # best checkpoint config, but did not pass metric data - expect path but no metric | 
|  | 384 | +        self.assertEqual( | 
|  | 385 | +            ckpt_manager.generate_checkpoint_path(1, 2).path, | 
|  | 386 | +            "foo/epoch_1_step_2", | 
|  | 387 | +        ) | 
|  | 388 | + | 
|  | 389 | +        # passed metric data is tracking a different metric than best checkpoint config - expect exception | 
|  | 390 | +        with self.assertRaisesRegex( | 
|  | 391 | +            AssertionError, | 
|  | 392 | +            "Attempted to get a checkpoint with metric 'mean', but best checkpoint config is for 'val_loss'", | 
|  | 393 | +        ): | 
|  | 394 | +            ckpt_manager.generate_checkpoint_path(1, 2, MetricData("mean", 3.5)) | 
|  | 395 | + | 
|  | 396 | +        # no best checkpoint config, but passed metric data - expect exception | 
|  | 397 | +        ckpt_manager._best_checkpoint_config = None | 
|  | 398 | +        with self.assertRaisesRegex( | 
|  | 399 | +            AssertionError, | 
|  | 400 | +            "Attempted to get a checkpoint with metric but best checkpoint config is not set", | 
|  | 401 | +        ): | 
|  | 402 | +            ckpt_manager.generate_checkpoint_path(1, 2, MetricData("val_loss", 3.5)) | 
|  | 403 | + | 
|  | 404 | +    def test_append_checkpoint_by_recency(self) -> None: | 
|  | 405 | +        ckpt_manager = CheckpointManager("foo", keep_last_n_checkpoints=2) | 
|  | 406 | +        ckpt_manager._ckpt_paths = [CheckpointPath("foo", 0, 0)] | 
|  | 407 | + | 
|  | 408 | +        # without need to remove old by recency | 
|  | 409 | +        ckpt_manager.append_checkpoint(CheckpointPath("foo", 0, 1)) | 
|  | 410 | +        self.assertEqual( | 
|  | 411 | +            ckpt_manager._ckpt_paths, | 
|  | 412 | +            [CheckpointPath("foo", 0, 0), CheckpointPath("foo", 0, 1)], | 
|  | 413 | +        ) | 
|  | 414 | + | 
|  | 415 | +        # removing old by recency | 
|  | 416 | +        with patch("fsspec.implementations.local.LocalFileSystem.rm") as mock_rm: | 
|  | 417 | +            ckpt_manager.append_checkpoint(CheckpointPath("foo", 0, 2)) | 
|  | 418 | +            self.assertEqual( | 
|  | 419 | +                ckpt_manager._ckpt_paths, | 
|  | 420 | +                [CheckpointPath("foo", 0, 1), CheckpointPath("foo", 0, 2)], | 
|  | 421 | +            ) | 
|  | 422 | +            mock_rm.assert_called_once_with("foo/epoch_0_step_0", recursive=True) | 
|  | 423 | + | 
|  | 424 | +    def test_append_checkpoint_by_metric(self) -> None: | 
|  | 425 | +        ckpt_manager = CheckpointManager( | 
|  | 426 | +            "foo", | 
|  | 427 | +            keep_last_n_checkpoints=5, | 
|  | 428 | +            best_checkpoint_config=BestCheckpointConfig( | 
|  | 429 | +                monitored_metric="val_loss", mode="max" | 
|  | 430 | +            ), | 
|  | 431 | +        ) | 
|  | 432 | +        paths = [ | 
|  | 433 | +            CheckpointPath( | 
|  | 434 | +                "foo", 0, x, metric_data=MetricData(name="val_loss", value=0.01 * x) | 
|  | 435 | +            ) | 
|  | 436 | +            for x in range(1, 7, 1) | 
|  | 437 | +        ] | 
|  | 438 | +        ckpt_manager._ckpt_paths = [paths[1], paths[2], paths[4]] | 
|  | 439 | +        # without need to remove old by min metric, goes beginning | 
|  | 440 | +        ckpt_manager.append_checkpoint(paths[0]) | 
|  | 441 | +        self.assertEqual( | 
|  | 442 | +            ckpt_manager._ckpt_paths, | 
|  | 443 | +            [paths[0], paths[1], paths[2], paths[4]], | 
|  | 444 | +        ) | 
|  | 445 | +        # without need to remove old by min metric, goes end | 
|  | 446 | +        ckpt_manager.append_checkpoint(paths[5]) | 
|  | 447 | +        self.assertEqual( | 
|  | 448 | +            ckpt_manager._ckpt_paths, | 
|  | 449 | +            [paths[0], paths[1], paths[2], paths[4], paths[5]], | 
|  | 450 | +        ) | 
|  | 451 | +        # removing old max metric, goes middle | 
|  | 452 | +        with patch("fsspec.implementations.local.LocalFileSystem.rm") as mock_rm: | 
|  | 453 | +            ckpt_manager.append_checkpoint(paths[3]) | 
|  | 454 | +            self.assertEqual( | 
|  | 455 | +                ckpt_manager._ckpt_paths, | 
|  | 456 | +                [paths[1], paths[2], paths[3], paths[4], paths[5]], | 
|  | 457 | +            ) | 
|  | 458 | +            mock_rm.assert_called_once_with( | 
|  | 459 | +                "foo/epoch_0_step_1_val_loss=0.01", recursive=True | 
|  | 460 | +            ) | 
|  | 461 | + | 
|  | 462 | +        # no metric data - noop | 
|  | 463 | +        ckpt_manager._keep_last_n_checkpoints = None | 
|  | 464 | +        ckpt_manager.append_checkpoint(CheckpointPath("foo", 0, 8)) | 
|  | 465 | +        self.assertEqual( | 
|  | 466 | +            ckpt_manager._ckpt_paths, | 
|  | 467 | +            [paths[1], paths[2], paths[3], paths[4], paths[5]], | 
|  | 468 | +        ) | 
|  | 469 | + | 
|  | 470 | +    def test_should_save_checkpoint(self) -> None: | 
|  | 471 | +        """ | 
|  | 472 | +        Tests basic functionality of should_save_checkpoint | 
|  | 473 | +        """ | 
|  | 474 | +        ckpt_manager = CheckpointManager("foo") | 
|  | 475 | + | 
|  | 476 | +        # test default behavior | 
|  | 477 | +        ckpt = CheckpointPath("foo", 0, 2) | 
|  | 478 | +        self.assertTrue(ckpt_manager.should_save_checkpoint(ckpt)) | 
|  | 479 | + | 
|  | 480 | +        ckpt_manager._ckpt_paths = [CheckpointPath("foo", 0, 1)] | 
|  | 481 | +        self.assertTrue(ckpt_manager.should_save_checkpoint(ckpt)) | 
|  | 482 | +        ckpt_manager._keep_last_n_checkpoints = 1 | 
|  | 483 | +        self.assertTrue(ckpt_manager.should_save_checkpoint(ckpt)) | 
|  | 484 | + | 
|  | 485 | +        ckpt_manager._ckpt_paths = [ | 
|  | 486 | +            CheckpointPath( | 
|  | 487 | +                "foo", 0, 1, metric_data=MetricData(name="val_loss", value=0.01) | 
|  | 488 | +            ), | 
|  | 489 | +        ] | 
|  | 490 | +        ckpt_manager._best_checkpoint_config = BestCheckpointConfig( | 
|  | 491 | +            monitored_metric="val_loss", | 
|  | 492 | +            mode="min", | 
|  | 493 | +        ) | 
|  | 494 | + | 
|  | 495 | +        bigger_metric = CheckpointPath( | 
|  | 496 | +            "foo", 0, 1, metric_data=MetricData(name="val_loss", value=0.02) | 
|  | 497 | +        ) | 
|  | 498 | +        smaller_metric = CheckpointPath( | 
|  | 499 | +            "foo", 0, 1, metric_data=MetricData(name="val_loss", value=0.001) | 
|  | 500 | +        ) | 
|  | 501 | +        ckpt_manager._keep_last_n_checkpoints = None | 
|  | 502 | +        self.assertTrue(ckpt_manager.should_save_checkpoint(bigger_metric)) | 
|  | 503 | +        ckpt_manager._keep_last_n_checkpoints = 1 | 
|  | 504 | +        self.assertFalse(ckpt_manager.should_save_checkpoint(bigger_metric)) | 
|  | 505 | +        self.assertTrue(ckpt_manager.should_save_checkpoint(smaller_metric)) | 
|  | 506 | +        ckpt_manager._keep_last_n_checkpoints = 2 | 
|  | 507 | +        self.assertTrue(ckpt_manager.should_save_checkpoint(smaller_metric)) | 
|  | 508 | +        self.assertTrue(ckpt_manager.should_save_checkpoint(bigger_metric)) | 
|  | 509 | + | 
|  | 510 | +        # Make sure we are actually comparing against more optimal element | 
|  | 511 | +        ckpt_manager._ckpt_paths = [ | 
|  | 512 | +            CheckpointPath( | 
|  | 513 | +                "foo", 0, 1, metric_data=MetricData(name="val_loss", value=0.01) | 
|  | 514 | +            ), | 
|  | 515 | +            CheckpointPath( | 
|  | 516 | +                "foo", 0, 1, metric_data=MetricData(name="val_loss", value=0.05) | 
|  | 517 | +            ), | 
|  | 518 | +        ] | 
|  | 519 | + | 
|  | 520 | +        ckpt_manager._best_checkpoint_config = BestCheckpointConfig( | 
|  | 521 | +            monitored_metric="val_loss", | 
|  | 522 | +            mode="max", | 
|  | 523 | +        ) | 
|  | 524 | +        ckpt_manager._keep_last_n_checkpoints = 2 | 
|  | 525 | +        self.assertTrue(ckpt_manager.should_save_checkpoint(bigger_metric)) | 
|  | 526 | + | 
|  | 527 | +    def test_remove_worst_checkpoint(self) -> None: | 
|  | 528 | +        with tempfile.TemporaryDirectory() as temp_dir: | 
|  | 529 | +            os.mkdir(os.path.join(temp_dir, "epoch_0_step_0")) | 
|  | 530 | +            os.mkdir(os.path.join(temp_dir, "epoch_0_step_1")) | 
|  | 531 | + | 
|  | 532 | +            ckpt_manager = CheckpointManager(temp_dir) | 
|  | 533 | +            ckpt_manager.append_checkpoint(CheckpointPath(temp_dir, 0, 0)) | 
|  | 534 | +            ckpt_manager.append_checkpoint(CheckpointPath(temp_dir, 0, 1)) | 
|  | 535 | + | 
|  | 536 | +            ckpt_manager.remove_checkpoint() | 
|  | 537 | +            self.assertFalse(os.path.exists(os.path.join(temp_dir, "epoch_0_step_0"))) | 
|  | 538 | +            self.assertTrue(os.path.exists(os.path.join(temp_dir, "epoch_0_step_1"))) | 
|  | 539 | +            self.assertEqual(ckpt_manager._ckpt_paths, [CheckpointPath(temp_dir, 0, 1)]) | 
|  | 540 | + | 
|  | 541 | + | 
| 193 | 542 | class CheckpointUtilsTest(unittest.TestCase): | 
| 194 | 543 |     @staticmethod | 
| 195 | 544 |     def _create_snapshot_metadata(output_dir: str) -> None: | 
| @@ -590,3 +939,12 @@ def test_metadata_exists(self) -> None: | 
| 590 | 939 | 
 | 
| 591 | 940 |             os.remove(os.path.join(dirpath, SNAPSHOT_METADATA_FNAME)) | 
| 592 | 941 |             self.assertFalse(_metadata_exists(fs, dirpath, SNAPSHOT_METADATA_FNAME)) | 
|  | 942 | + | 
|  | 943 | + | 
|  | 944 | +class MyValLossUnit(TrainUnit[Batch]): | 
|  | 945 | +    def __init__(self) -> None: | 
|  | 946 | +        super().__init__() | 
|  | 947 | +        self.val_loss = 0.01 | 
|  | 948 | + | 
|  | 949 | +    def train_step(self, state: State, data: Batch) -> None: | 
|  | 950 | +        return None | 
0 commit comments