Skip to content
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
165 changes: 154 additions & 11 deletions firebase_admin/mlkit.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
deleting, publishing and unpublishing Firebase ML Kit models.
"""



import datetime
import numbers
import re
Expand All @@ -30,13 +32,27 @@
from firebase_admin import _utils
from firebase_admin import exceptions

# pylint: disable=import-error,no-name-in-module
try:
from firebase_admin import storage
GCS_ENABLED = True
except ImportError:
GCS_ENABLED = False

# pylint: disable=import-error,no-name-in-module
try:
import tensorflow as tf
TF_ENABLED = True
except ImportError:
TF_ENABLED = False

_MLKIT_ATTRIBUTE = '_mlkit'
_MAX_PAGE_SIZE = 100
_MODEL_ID_PATTERN = re.compile(r'^[A-Za-z0-9_-]{1,60}$')
_DISPLAY_NAME_PATTERN = re.compile(r'^[A-Za-z0-9_-]{1,60}$')
_TAG_PATTERN = re.compile(r'^[A-Za-z0-9_-]{1,60}$')
_GCS_TFLITE_URI_PATTERN = re.compile(r'^gs://[a-z0-9_.-]{3,63}/.+')
_GCS_TFLITE_URI_PATTERN = re.compile(
r'^gs://(?P<bucket_name>[a-z0-9_.-]{3,63})/(?P<blob_name>.+)$')
_RESOURCE_NAME_PATTERN = re.compile(
r'^projects/(?P<project_id>[^/]+)/models/(?P<model_id>[A-Za-z0-9_-]{1,60})$')
_OPERATION_NAME_PATTERN = re.compile(
Expand Down Expand Up @@ -301,16 +317,16 @@ def model_format(self, model_format):
self._model_format = model_format #Can be None
return self

def as_dict(self):
def as_dict(self, for_upload=False):
copy = dict(self._data)
if self._model_format:
copy.update(self._model_format.as_dict())
copy.update(self._model_format.as_dict(for_upload=for_upload))
return copy


class ModelFormat(object):
"""Abstract base class representing a Model Format such as TFLite."""
def as_dict(self):
def as_dict(self, for_upload=False):
raise NotImplementedError


Expand Down Expand Up @@ -364,22 +380,59 @@ def model_source(self, model_source):
def size_bytes(self):
return self._data.get('sizeBytes')

def as_dict(self):
def as_dict(self, for_upload=False):
copy = dict(self._data)
if self._model_source:
copy.update(self._model_source.as_dict())
copy.update(self._model_source.as_dict(for_upload=for_upload))
return {'tfliteModel': copy}


class TFLiteModelSource(object):
"""Abstract base class representing a model source for TFLite format models."""
def as_dict(self):
def as_dict(self, for_upload=False):
raise NotImplementedError


class _CloudStorageClient(object):
"""Cloud Storage helper class"""

GCS_URI = 'gs://{0}/{1}'
BLOB_NAME = 'Firebase/MLKit/Models/{0}'

@staticmethod
def upload(bucket_name, model_file_name, app):
if not GCS_ENABLED:
raise ImportError('Failed to import the Cloud Storage library for Python. Make sure '
'to install the "google-cloud-storage" module.')
bucket = storage.bucket(bucket_name, app=app)
blob_name = _CloudStorageClient.BLOB_NAME.format(model_file_name)
blob = bucket.blob(blob_name)
blob.upload_from_filename(model_file_name)
return _CloudStorageClient.GCS_URI.format(bucket.name, blob_name)

@staticmethod
def sign_uri(gcs_tflite_uri, app):
"""Makes the gcs_tflite_uri readable for GET for 10 minutes."""
if not GCS_ENABLED:
raise ImportError('Failed to import the Cloud Storage library for Python. Make sure '
'to install the "google-cloud-storage" module.')
bucket_name, blob_name = _parse_gcs_tflite_uri(gcs_tflite_uri)
bucket = storage.bucket(bucket_name, app=app)
blob = bucket.blob(blob_name)
return blob.generate_signed_url(
version='v4',
expiration=datetime.timedelta(minutes=10),
method='GET'
)


class TFLiteGCSModelSource(TFLiteModelSource):
"""TFLite model source representing a tflite model file stored in GCS."""
def __init__(self, gcs_tflite_uri):

_STORAGE_CLIENT = _CloudStorageClient()

def __init__(self, gcs_tflite_uri, app=None):
self._app = app
self._gcs_tflite_uri = _validate_gcs_tflite_uri(gcs_tflite_uri)

def __eq__(self, other):
Expand All @@ -391,6 +444,79 @@ def __eq__(self, other):
def __ne__(self, other):
return not self.__eq__(other)

@classmethod
def from_tflite_model_file(cls, model_file_name, bucket_name=None, app=None):
"""Uploads the model file to an existing Google Cloud Services bucket.

Args:
model_file_name: The name of the model file.
bucket_name: The name of an existing bucket. None to use the default bucket configured
in the app.
app: A Firebase app instance (or None to use the default app).

Returns:
TFLiteGCSModelSource: The source created from the model_file

Raises:
ImportError: If the Cloud Storage Library has not been installed.
"""
gcs_uri = TFLiteGCSModelSource._STORAGE_CLIENT.upload(bucket_name, model_file_name, app)
return TFLiteGCSModelSource(gcs_tflite_uri=gcs_uri, app=app)

@classmethod
def from_saved_model(cls, saved_model_dir, bucket_name=None, app=None):
"""Creates a Tensor Flow Lite model from the saved model, and uploads the model to GCS.

Args:
saved_model_dir: The saved model directory.
bucket_name: Optional. The name of the bucket to store the uploaded tflite file.
(or None to use the default bucket)
app: Optional. A Firebase app instance (or None to use the default app)

Returns:
TFLiteGCSModelSource: The source created from the saved_model_dir

Raises:
ImportError: If the Tensor Flow or Cloud Storage Libraries have not been installed.
"""
if not TF_ENABLED:
raise ImportError('Failed to import the tensorflow library for Python. Make sure '
'to install the tensorflow module.')
#TODO(ifielker): Do we need to worry about tf version?
converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
tflite_model = converter.convert()
open("firebase_mlkit_model.tflite", "wb").write(tflite_model)
return TFLiteGCSModelSource.from_tflite_model_file(
"firebase_mlkit_model.tflite", bucket_name, app)

@classmethod
def from_keras_model(cls, keras_model, bucket_name=None, app=None):
"""Creates a Tensor Flow Lite model from the keras model, and uploads the model to GCS.

Args:
keras_model: A tf.keras model.
bucket_name: Optional. The name of the bucket to store the uploaded tflite file.
(or None to use the default bucket)
app: Optional. A Firebase app instance (or None to use the default app)

Returns:
TFLiteGCSModelSource: The source created from the keras_model

Raises:
ImportError: If the Tensor Flow or Cloud Storage Libraries have not been installed.
"""
if not TF_ENABLED:
raise ImportError('Failed to import the tensorflow library for Python. Make sure '
'to install the tensorflow module.')
#TODO(ifielker): Do we need to worry about tf version?
keras_file = "keras_model.h5"
tf.keras.models.save_model(keras_model, keras_file)
converter = tf.lite.TFLiteConverter.from_keras_model_file(keras_file)
tflite_model = converter.convert()
open("firebase_mlkit_model.tflite", "wb").write(tflite_model)
return TFLiteGCSModelSource.from_tflite_model_file(
"firebase_mlkit_model.tflite", bucket_name, app)

@property
def gcs_tflite_uri(self):
return self._gcs_tflite_uri
Expand All @@ -399,9 +525,17 @@ def gcs_tflite_uri(self):
def gcs_tflite_uri(self, gcs_tflite_uri):
self._gcs_tflite_uri = _validate_gcs_tflite_uri(gcs_tflite_uri)

def as_dict(self):
def _get_signed_gcs_tflite_uri(self):
"""Signs the GCS uri, so the model file can be uploaded to Firebase ML Kit and verified."""
return TFLiteGCSModelSource._STORAGE_CLIENT.sign_uri(self._gcs_tflite_uri, self._app)

def as_dict(self, for_upload=False):
if for_upload:
return {"gcsTfliteUri": self._get_signed_gcs_tflite_uri()}

return {"gcsTfliteUri": self._gcs_tflite_uri}


#TODO(ifielker): implement from_saved_model etc.


Expand Down Expand Up @@ -553,6 +687,15 @@ def _validate_gcs_tflite_uri(uri):
return uri


def _parse_gcs_tflite_uri(uri):
# GCS Bucket naming rules are complex. The regex is not comprehensive.
# See https://cloud.google.com/storage/docs/naming for full details.
matcher = _GCS_TFLITE_URI_PATTERN.match(uri)
if not matcher:
raise ValueError('GCS TFLite URI format is invalid.')
return matcher.group('bucket_name'), matcher.group('blob_name')


def _validate_model_format(model_format):
if not isinstance(model_format, ModelFormat):
raise TypeError('Model format must be a ModelFormat object.')
Expand Down Expand Up @@ -671,13 +814,13 @@ def create_model(self, model):
_validate_model(model)
try:
return self.handle_operation(
self._client.body('post', url='models', json=model.as_dict()))
self._client.body('post', url='models', json=model.as_dict(for_upload=True)))
except requests.exceptions.RequestException as error:
raise _utils.handle_platform_error_from_requests(error)

def update_model(self, model, update_mask=None):
_validate_model(model, update_mask)
data = {'model': model.as_dict()}
data = {'model': model.as_dict(for_upload=True)}
if update_mask is not None:
data['updateMask'] = update_mask
try:
Expand Down
35 changes: 35 additions & 0 deletions tests/test_mlkit.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,10 @@
}
TFLITE_FORMAT = mlkit.TFLiteFormat.from_dict(TFLITE_FORMAT_JSON)

GCS_TFLITE_SIGNED_URI = 'gs://test_bucket/test_blob?signing_information'
GCS_TFLITE_SIGNED_URI_JSON = {'gcsTfliteUri': GCS_TFLITE_URI}
GCS_TFLITE_SIGNED_MODEL_SOURCE = mlkit.TFLiteGCSModelSource(GCS_TFLITE_SIGNED_URI)

GCS_TFLITE_URI_2 = 'gs://my_bucket/mymodel2.tflite'
GCS_TFLITE_URI_JSON_2 = {'gcsTfliteUri': GCS_TFLITE_URI_2}
GCS_TFLITE_MODEL_SOURCE_2 = mlkit.TFLiteGCSModelSource(GCS_TFLITE_URI_2)
Expand Down Expand Up @@ -325,6 +329,17 @@ def instrument_mlkit_service(status=200, payload=None, operations=False, app=Non
session_url, adapter(payload, status, recorder))
return recorder

class _TestStorageClient(object):
@staticmethod
def upload(bucket_name, model_file_name, app):
del app # unused variable
blob_name = mlkit._CloudStorageClient.BLOB_NAME.format(model_file_name)
return mlkit._CloudStorageClient.GCS_URI.format(bucket_name, blob_name)

@staticmethod
def sign_uri(gcs_tflite_uri, app):
del gcs_tflite_uri, app # unused variables
return GCS_TFLITE_SIGNED_URI

class TestModel(object):
"""Tests mlkit.Model class."""
Expand All @@ -333,6 +348,7 @@ def setup_class(cls):
cred = testutils.MockCredential()
firebase_admin.initialize_app(cred, {'projectId': PROJECT_ID})
mlkit._MLKitService.POLL_BASE_WAIT_TIME_SECONDS = 0.1 # shorter for test
mlkit.TFLiteGCSModelSource._STORAGE_CLIENT = _TestStorageClient()

@classmethod
def teardown_class(cls):
Expand Down Expand Up @@ -404,6 +420,13 @@ def test_model_format_source_creation(self):
}
}

def test_source_creation_from_tflite_file(self):
model_source = mlkit.TFLiteGCSModelSource.from_tflite_model_file(
"my_model.tflite", "my_bucket")
assert model_source.as_dict() == {
'gcsTfliteUri': 'gs://my_bucket/Firebase/MLKit/Models/my_model.tflite'
}

def test_model_source_setters(self):
model_source = mlkit.TFLiteGCSModelSource(GCS_TFLITE_URI)
model_source.gcs_tflite_uri = GCS_TFLITE_URI_2
Expand All @@ -420,6 +443,17 @@ def test_model_format_setters(self):
}
}

def test_model_as_dict_for_upload(self):
model_source = mlkit.TFLiteGCSModelSource(gcs_tflite_uri=GCS_TFLITE_URI)
model_format = mlkit.TFLiteFormat(model_source=model_source)
model = mlkit.Model(display_name=DISPLAY_NAME_1, model_format=model_format)
assert model.as_dict(for_upload=True) == {
'displayName': DISPLAY_NAME_1,
'tfliteModel': {
'gcsTfliteUri': GCS_TFLITE_SIGNED_URI
}
}

@pytest.mark.parametrize('display_name, exc_type', [
('', ValueError),
('&_*#@:/?', ValueError),
Expand Down Expand Up @@ -803,6 +837,7 @@ def test_rpc_error(self, publish_function):
)
assert len(create_recorder) == 1


class TestGetModel(object):
"""Tests mlkit.get_model."""
@classmethod
Expand Down