Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 71 additions & 14 deletions integration/test_ml.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

import pytest

import firebase_admin
from firebase_admin import exceptions
from firebase_admin import ml
from tests import testutils
Expand All @@ -34,6 +35,11 @@
except ImportError:
_TF_ENABLED = False

try:
from google.cloud import automl_v1
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Once this is merged to master, let's also update the .github/workflows/release.yml to install this package, and add the Auto ML resource to the project Kobayashi Maru.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sounds good

_AUTOML_ENABLED = True
except ImportError:
_AUTOML_ENABLED = False

def _random_identifier(prefix):
#pylint: disable=unused-variable
Expand All @@ -42,27 +48,26 @@ def _random_identifier(prefix):


NAME_ONLY_ARGS = {
'display_name': _random_identifier('TestModel123_')
'display_name': _random_identifier('TestModel_')
}
NAME_ONLY_ARGS_UPDATED = {
'display_name': _random_identifier('TestModel123_updated_')
'display_name': _random_identifier('TestModel_updated_')
}
NAME_AND_TAGS_ARGS = {
'display_name': _random_identifier('TestModel123_tags_'),
'display_name': _random_identifier('TestModel_tags_'),
'tags': ['test_tag123']
}
FULL_MODEL_ARGS = {
'display_name': _random_identifier('TestModel123_full_'),
'display_name': _random_identifier('TestModel_full_'),
'tags': ['test_tag567'],
'file_name': 'model1.tflite'
}
INVALID_FULL_MODEL_ARGS = {
'display_name': _random_identifier('TestModel123_invalid_full_'),
'display_name': _random_identifier('TestModel_invalid_full_'),
'tags': ['test_tag890'],
'file_name': 'invalid_model.tflite'
}


@pytest.fixture
def firebase_model(request):
args = request.param
Expand Down Expand Up @@ -101,6 +106,7 @@ def _clean_up_model(model):
try:
# Try to delete the model.
# Some tests delete the model as part of the test.
model.wait_for_unlocked()
ml.delete_model(model.model_id)
except exceptions.NotFoundError:
pass
Expand Down Expand Up @@ -133,17 +139,20 @@ def check_model(model, args):
assert model.etag is not None


def check_model_format(model, has_model_format=False, validation_error=None):
def check_model_format(model, has_model_format=False, validation_error=None, is_automl=False):
if has_model_format:
assert model.validation_error == validation_error
assert model.published is False
assert model.model_format.model_source.gcs_tflite_uri.startswith('gs://')
if validation_error:
assert model.model_format.size_bytes is None
assert model.model_hash is None
if is_automl:
assert model.model_format.model_source.auto_ml_model.startswith('projects/')
else:
assert model.model_format.size_bytes is not None
assert model.model_hash is not None
assert model.model_format.model_source.gcs_tflite_uri.startswith('gs://')
if validation_error:
assert model.model_format.size_bytes is None
assert model.model_hash is None
else:
assert model.model_format.size_bytes is not None
assert model.model_hash is not None
else:
assert model.model_format is None
assert model.validation_error == 'No model file has been uploaded.'
Expand Down Expand Up @@ -290,7 +299,7 @@ def test_delete_model(firebase_model):

# Test tensor flow conversion functions if tensor flow is enabled.
#'pip install tensorflow' in the environment if you want _TF_ENABLED = True
#'pip install tensorflow==2.0.0b' for version 2 etc.
#'pip install tensorflow==2.2.0' for version 2.2.0 etc.


def _clean_up_directory(save_dir):
Expand Down Expand Up @@ -334,6 +343,7 @@ def saved_model_dir(keras_model):
_clean_up_directory(parent)



@pytest.mark.skipif(not _TF_ENABLED, reason='Tensor flow is required for this test.')
def test_from_keras_model(keras_model):
source = ml.TFLiteGCSModelSource.from_keras_model(keras_model, 'model2.tflite')
Expand Down Expand Up @@ -371,3 +381,50 @@ def test_from_saved_model(saved_model_dir):
assert created_model.validation_error is None
finally:
_clean_up_model(created_model)


# Test AutoML functionality if AutoML is enabled.
#'pip install google-cloud-automl' in the environment if you want _AUTOML_ENABLED = True
# You will also need a predefined AutoML model named 'py_sdk_integ_test1' to run the
# successful test. (Test is skipped otherwise)

@pytest.fixture
def automl_model():
assert _AUTOML_ENABLED

# It takes > 20 minutes to train a model, so we expect a predefined AutoMl
# model named 'py_sdk_integ_test1' to exist in the project, or we skip
# the test.
automl_client = automl_v1.AutoMlClient()
project_id = firebase_admin.get_app().project_id
parent = automl_client.location_path(project_id, 'us-central1')
models = automl_client.list_models(parent, filter_="display_name=py_sdk_integ_test1")
# Expecting exactly one. (Ok to use last one if somehow more than 1)
automl_ref = None
for model in models:
automl_ref = model.name

# Skip if no pre-defined model. (It takes min > 20 minutes to train a model)
if automl_ref is None:
pytest.skip("No pre-existing AutoML model found. Skipping test")

source = ml.TFLiteAutoMlSource(automl_ref)
tflite_format = ml.TFLiteFormat(model_source=source)
ml_model = ml.Model(
display_name=_random_identifier('TestModel_automl_'),
tags=['test_automl'],
model_format=tflite_format)
model = ml.create_model(model=ml_model)
yield model
_clean_up_model(model)

@pytest.mark.skipif(not _AUTOML_ENABLED, reason='AutoML is required for this test.')
def test_automl_model(automl_model):
# This test looks for a predefined automl model with display_name = 'py_sdk_integ_test1'
automl_model.wait_for_unlocked()

check_model(automl_model, {
'display_name': automl_model.display_name,
'tags': ['test_automl'],
})
check_model_format(automl_model, has_model_format=True, validation_error=None, is_automl=True)
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
astroid == 2.3.3
pylint == 2.3.1
pytest >= 3.6.0
pytest-cov >= 2.4.0
Expand Down