Skip to content
This repository was archived by the owner on Sep 27, 2024. It is now read-only.

Commit 313b268

Browse files
authored
Make tests pytest and unittest compatible (#254)
1 parent f24c2c7 commit 313b268

File tree

2 files changed

+12
-0
lines changed

2 files changed

+12
-0
lines changed

model_card_toolkit/core_test.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import os
1717
from unittest import mock
1818

19+
from absl import flags
1920
from absl.testing import absltest
2021
from absl.testing import parameterized
2122

@@ -289,3 +290,7 @@ def test_export_format_before_scaffold_assets(self):
289290

290291
if __name__ == '__main__':
291292
absltest.main()
293+
else:
294+
# Manually pass and parse flags to prevent UnparsedFlagAccessError when using
295+
# pytest or unittest as a runner.
296+
flags.FLAGS(['--test_tmpdir'])

model_card_toolkit/utils/testdata/tfxtest.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import os
1717
from typing import Any, Callable, List, Optional
1818

19+
from absl import flags
1920
import apache_beam as beam
2021
from model_card_toolkit.utils.tfx_util import _TFX_METRICS_TYPE
2122
from model_card_toolkit.utils.tfx_util import _TFX_STATS_TYPE
@@ -174,3 +175,9 @@ def _write(dataset_name: str, features: List[str], split_name: str):
174175

175176
if store:
176177
self._put_artifact(store, _TFX_STATS_TYPE, tfdv_path)
178+
179+
180+
if not __name__ == '__main__':
181+
# Manually pass and parse flags to prevent UnparsedFlagAccessError when using
182+
# pytest or unittest as a runner.
183+
flags.FLAGS(['--test_tmpdir'])

0 commit comments

Comments
 (0)