Skip to content

Commit

Permalink
Add checking for existing publisher documentation in validator.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 294421295
  • Loading branch information
TensorFlow Hub Authors authored and vbardiovskyg committed Feb 11, 2020
1 parent d1611a3 commit 16ae221
Show file tree
Hide file tree
Showing 3 changed files with 74 additions and 15 deletions.
34 changes: 28 additions & 6 deletions tfhub_dev/tools/validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,10 @@ def get_contents(self, filename):
with tf.io.gfile.GFile(filename, "r") as f:
return f.read()

def file_exists(self, filename):
"""Returns whether file exists."""
return tf.io.gfile.exists(filename)

def recursive_list_dir(self, root_dir):
"""Yields all files of a root directory tree."""
for dirname, _, filenames in tf.io.gfile.walk(root_dir):
Expand Down Expand Up @@ -110,6 +114,10 @@ def type_name(self):
def handle(self):
return "%s/%s/%s" % (self._publisher, self._model_name, self._model_version)

@property
def publisher(self):
return self._publisher

def get_expected_location(self, root_dir):
"""Returns the expected path of a documentation file."""
del root_dir
Expand Down Expand Up @@ -183,8 +191,9 @@ def get_required_metadata(self):
class DocumentationParser(object):
"""Class used for parsing model documentation strings."""

def __init__(self, documentation_dir):
def __init__(self, documentation_dir, filesystem):
self._documentation_dir = documentation_dir
self._filesystem = filesystem
self._parsed_metadata = dict()
self._parsed_description = ""

Expand Down Expand Up @@ -224,6 +233,19 @@ def consume_first_line(self):
"is \"%s\"." % (MODEL_HANDLE_PATTERN, PUBLISHER_HANDLE_PATTERN,
COLLECTION_HANDLE_PATTERN, first_line))

def assert_publisher_page_exists(self):
"""Assert that publisher page exists for the publisher of this model."""
# Use a publisher policy to get the expected documentation page path.
publisher_policy = PublisherParsingPolicy(self._parsing_policy.publisher,
self._parsing_policy.publisher,
None)
expected_publisher_doc_location = publisher_policy.get_expected_location(
self._documentation_dir)
if not self._filesystem.file_exists(expected_publisher_doc_location):
self.raise_error(
"Publisher documentation does not exist. It should be added to %s." %
expected_publisher_doc_location)

def assert_correct_location(self):
"""Assert that documentation file is submitted to a correct location."""
expected_file_path = self._parsing_policy.get_expected_location(
Expand Down Expand Up @@ -348,9 +370,9 @@ def smoke_test_asset(self):
"README.md. Underlying reason for failure: %s." %
(asset_path, reason))

def validate(self, file_path, documentation_content, do_smoke_test):
def validate(self, file_path, do_smoke_test):
"""Validate one documentation markdown file."""
self._raw_content = documentation_content
self._raw_content = self._filesystem.get_contents(file_path)
self._lines = self._raw_content.split("\n")
self._file_path = file_path
self.consume_first_line()
Expand All @@ -359,6 +381,7 @@ def validate(self, file_path, documentation_content, do_smoke_test):
self.consume_metadata()
self.assert_correct_metadata()
self.assert_allowed_license()
self.assert_publisher_page_exists()
if do_smoke_test:
self.smoke_test_asset()

Expand All @@ -374,10 +397,9 @@ def validate_documentation_files(documentation_dir,
if files_to_validate and file_path[len(documentation_dir) +
1:] not in files_to_validate:
continue
file_content = filesystem.get_contents(file_path)
logging.info("Validating %s.", file_path)
documentation_parser = DocumentationParser(documentation_dir)
documentation_parser.validate(file_path, file_content, do_smoke_test)
documentation_parser = DocumentationParser(documentation_dir, filesystem)
documentation_parser.validate(file_path, do_smoke_test)
validated += 1
logging.info("Found %d matching files - all validated successfully.",
validated)
Expand Down
55 changes: 46 additions & 9 deletions tfhub_dev/tools/validator_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,10 @@ def __init__(self):
def get_contents(self, filename):
return self._files[filename]

def file_exists(self, filename):
"""Returns whether file exists."""
return filename in self._files

def set_contents(self, filename, contents):
self._files[filename] = contents

Expand All @@ -52,10 +56,22 @@ def recursive_list_dir(self, root_dir):
## Overview
"""

MINIMAL_MARKDOWN_WITH_UNKNOWN_PUBLISHER = """# Module publisher-without-page/text-embedding-model/1
Simple description spanning
multiple lines.
<!-- asset-path: /path/to/model -->
<!-- module-type: text-embedding -->
<!-- fine-tunable:true -->
<!-- format: saved_model_2 -->
## Overview
"""

MINIMAL_MARKDOWN_WITH_ALLOWED_LICENSE = """# Module google/model/1
Simple description.
<!-- asset-path: %s -->
<!-- asset-path: /path/to/model -->
<!-- module-type: text-embedding -->
<!-- fine-tunable: true -->
<!-- format: saved_model_2 -->
Expand Down Expand Up @@ -123,7 +139,7 @@ def recursive_list_dir(self, root_dir):
## Overview
"""

MINIMAL_PUBLISHER_MARKDOWN = """# Publisher some-publisher
MINIMAL_PUBLISHER_MARKDOWN = """# Publisher %s
Simple description spanning one line.
[![Icon URL]](https://path/to/icon.png)
Expand All @@ -148,6 +164,10 @@ def tearDown(self):
super(tf.test.TestCase, self).tearDown()
shutil.rmtree(self.tmp_dir)

def set_up_publisher_page(self, filesystem, publisher):
filesystem.set_contents("root/%s/%s.md" % (publisher, publisher),
MINIMAL_PUBLISHER_MARKDOWN % publisher)

def save_dummy_model(self, path):

class MultiplyTimesTwoModel(tf.train.Checkpoint):
Expand Down Expand Up @@ -176,13 +196,15 @@ def test_minimal_markdown_parsed(self):
filesystem = MockFilesystem()
filesystem.set_contents("root/google/models/text-embedding-model/1.md",
self.minimal_markdown)
self.set_up_publisher_page(filesystem, "google")
validator.validate_documentation_files(
documentation_dir="root", filesystem=filesystem)

def test_minimal_markdown_parsed_with_selected_files(self):
filesystem = MockFilesystem()
filesystem.set_contents("root/google/models/text-embedding-model/1.md",
self.minimal_markdown)
self.set_up_publisher_page(filesystem, "google")
num_validated = validator.validate_documentation_files(
documentation_dir="root",
files_to_validate=["google/models/text-embedding-model/1.md"],
Expand All @@ -194,13 +216,13 @@ def test_minimal_collection_markdown_parsed(self):
filesystem.set_contents(
"root/google/collections/text-embedding-collection/1.md",
MINIMAL_COLLECTION_MARKDOWN)
self.set_up_publisher_page(filesystem, "google")
validator.validate_documentation_files(
documentation_dir="root", filesystem=filesystem)

def test_minimal_publisher_markdown_parsed(self):
filesystem = MockFilesystem()
filesystem.set_contents("root/some-publisher/some-publisher.md",
MINIMAL_PUBLISHER_MARKDOWN)
self.set_up_publisher_page(filesystem, "some-publisher")
validator.validate_documentation_files(
documentation_dir="root", filesystem=filesystem)

Expand All @@ -221,6 +243,16 @@ def test_minimal_markdown_not_in_publisher_dir(self):
validator.validate_documentation_files(
documentation_dir="root", filesystem=filesystem)

def test_fails_if_publisher_page_does_not_exist(self):
filesystem = MockFilesystem()
filesystem.set_contents(
"root/publisher-without-page/models/text-embedding-model/1.md",
MINIMAL_MARKDOWN_WITH_UNKNOWN_PUBLISHER)
with self.assertRaisesRegexp(validator.MarkdownDocumentationError,
".*Publisher documentation does not.*"):
validator.validate_documentation_files(
documentation_dir="root", filesystem=filesystem)

def test_minimal_markdown_does_not_end_with_md_fails(self):
filesystem = MockFilesystem()
filesystem.set_contents("root/google/models/wrong-extension/1.mdz",
Expand All @@ -233,16 +265,15 @@ def test_minimal_markdown_does_not_end_with_md_fails(self):
def test_publisher_markdown_at_incorrect_location_fails(self):
filesystem = MockFilesystem()
filesystem.set_contents("root/google/publisher.md",
MINIMAL_PUBLISHER_MARKDOWN)
MINIMAL_PUBLISHER_MARKDOWN % "some-publisher")
with self.assertRaisesRegexp(validator.MarkdownDocumentationError,
r".*some-publisher\.md.*"):
validator.validate_documentation_files(
documentation_dir="root", filesystem=filesystem)

def test_publisher_markdown_at_correct_location(self):
filesystem = MockFilesystem()
filesystem.set_contents("root/some-publisher/some-publisher.md",
MINIMAL_PUBLISHER_MARKDOWN)
self.set_up_publisher_page(filesystem, "some-publisher")
validator.validate_documentation_files(
documentation_dir="root", filesystem=filesystem)

Expand Down Expand Up @@ -283,10 +314,13 @@ def test_markdown_with_unexpected_lines(self):
documentation_dir="root", filesystem=filesystem)

def test_minimal_markdown_parsed_full(self):
documentation_parser = validator.DocumentationParser("root")
filesystem = MockFilesystem()
filesystem.set_contents("root/google/models/text-embedding-model/1.md",
self.minimal_markdown)
self.set_up_publisher_page(filesystem, "google")
documentation_parser = validator.DocumentationParser("root", filesystem)
documentation_parser.validate(
file_path="root/google/models/text-embedding-model/1.md",
documentation_content=self.minimal_markdown,
do_smoke_test=True)
self.assertEqual("Simple description spanning multiple lines.",
documentation_parser.parsed_description)
Expand All @@ -302,6 +336,7 @@ def test_bad_model_does_not_pass_smoke_test(self):
filesystem = MockFilesystem()
filesystem.set_contents("root/google/models/text-embedding-model/1.md",
self.minimal_markdown_with_bad_model)
self.set_up_publisher_page(filesystem, "google")
with self.assertRaisesRegexp(validator.MarkdownDocumentationError,
".*failed to parse.*"):
validator.validate_documentation_files(
Expand All @@ -313,13 +348,15 @@ def test_markdown_with_allowed_license(self):
filesystem = MockFilesystem()
filesystem.set_contents("root/google/models/model/1.md",
MINIMAL_MARKDOWN_WITH_ALLOWED_LICENSE)
self.set_up_publisher_page(filesystem, "google")
validator.validate_documentation_files(
documentation_dir="root", filesystem=filesystem)

def test_markdown_with_unknown_license(self):
filesystem = MockFilesystem()
filesystem.set_contents("root/google/models/model/1.md",
MINIMAL_MARKDOWN_WITH_UNKNOWN_LICENSE)
self.set_up_publisher_page(filesystem, "google")
with self.assertRaisesRegexp(validator.MarkdownDocumentationError,
".*specify a license id from list.*"):
validator.validate_documentation_files(
Expand Down

0 comments on commit 16ae221

Please sign in to comment.