Skip to content
Merged
21 changes: 21 additions & 0 deletions firebase_admin/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,27 @@ def handle_platform_error_from_requests(error, handle_func=None):
return exc if exc else _handle_func_requests(error, message, error_dict)


def handle_operation_error(error):
"""Constructs a ``FirebaseError`` from the given operation error.

Args:
error: An error returned by a long running operation.

Returns:
FirebaseError: A ``FirebaseError`` that can be raised to the user code.
"""
if not isinstance(error, dict):
return exceptions.UnknownError(
message='Unknown error while making a remote service call: {0}'.format(error),
cause=error)

status_code = error.get('code')
message = error.get('message')
error_code = _http_status_to_error_code(status_code)
err_type = _error_code_to_exception_type(error_code)
return err_type(message=message)


def _handle_func_requests(error, message, error_dict):
"""Constructs a ``FirebaseError`` from the given GCP error.

Expand Down
119 changes: 119 additions & 0 deletions firebase_admin/mlkit.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,14 @@
import datetime
import numbers
import re
import time
import requests
import six


from firebase_admin import _http_client
from firebase_admin import _utils
from firebase_admin import exceptions


_MLKIT_ATTRIBUTE = '_mlkit'
Expand All @@ -36,6 +39,8 @@
_GCS_TFLITE_URI_PATTERN = re.compile(r'^gs://[a-z0-9_.-]{3,63}/.+')
_RESOURCE_NAME_PATTERN = re.compile(
r'^projects/(?P<project_id>[^/]+)/models/(?P<model_id>[A-Za-z0-9_-]{1,60})$')
_OPERATION_NAME_PATTERN = re.compile(
r'^operations/project/[^/]+/model/[A-Za-z0-9_-]{1,60}/operation/[^/]+$')


def _get_mlkit_service(app):
Expand All @@ -53,18 +58,60 @@ def _get_mlkit_service(app):
return _utils.get_app_service(app, _MLKIT_ATTRIBUTE, _MLKitService)


def create_model(model, app=None):
"""Creates a model in Firebase ML Kit.

Args:
model: An mlkit.Model to create.
app: A Firebase app instance (or None to use the default app).

Returns:
Model: The model that was created in Firebase ML Kit.
"""
mlkit_service = _get_mlkit_service(app)
return Model.from_dict(mlkit_service.create_model(model))


def get_model(model_id, app=None):
"""Gets a model from Firebase ML Kit.

Args:
model_id: The id of the model to get.
app: A Firebase app instance (or None to use the default app).

Returns:
Model: The requested model.
"""
mlkit_service = _get_mlkit_service(app)
return Model.from_dict(mlkit_service.get_model(model_id))


def list_models(list_filter=None, page_size=None, page_token=None, app=None):
"""Lists models from Firebase ML Kit.

Args:
list_filter: a list filter string such as "tags:'tag_1'". None will return all models.
page_size: A number between 1 and 100 inclusive that specifies the maximum
number of models to return per page. None for default.
page_token: A next page token returned from a previous page of results. None
for first page of results.
app: A Firebase app instance (or None to use the default app).

Returns:
ListModelsPage: A (filtered) list of models.
"""
mlkit_service = _get_mlkit_service(app)
return ListModelsPage(
mlkit_service.list_models, list_filter, page_size, page_token)


def delete_model(model_id, app=None):
"""Deletes a model from Firebase ML Kit.

Args:
model_id: The id of the model you wish to delete.
app: A Firebase app instance (or None to use the default app).
"""
mlkit_service = _get_mlkit_service(app)
mlkit_service.delete_model(model_id)

Expand Down Expand Up @@ -390,11 +437,23 @@ def _validate_and_parse_name(name):
return matcher.group('project_id'), matcher.group('model_id')


def _validate_model(model):
if not isinstance(model, Model):
raise TypeError('Model must be an mlkit.Model.')
if not model.display_name:
raise ValueError('Model must have a display name.')


def _validate_model_id(model_id):
if not _MODEL_ID_PATTERN.match(model_id):
raise ValueError('Model ID format is invalid.')


def _validate_operation_name(op_name):
if not _OPERATION_NAME_PATTERN.match(op_name):
raise ValueError('Operation name format is invalid.')


def _validate_display_name(display_name):
if not _DISPLAY_NAME_PATTERN.match(display_name):
raise ValueError('Display name format is invalid.')
Expand Down Expand Up @@ -448,6 +507,11 @@ class _MLKitService(object):
"""Firebase MLKit service."""

PROJECT_URL = 'https://mlkit.googleapis.com/v1beta1/projects/{0}/'
OPERATION_URL = 'https://mlkit.googleapis.com/v1beta1/'
OPERATION_POLL_DELAY_SECONDS = 30
MAX_POLLING_ATTEMPTS = 10
POLL_EXPONENTIAL_BACKOFF_FACTOR = 2
POLL_BASE_WAIT_TIME_SECONDS = 1

def __init__(self, app):
project_id = app.project_id
Expand All @@ -459,6 +523,61 @@ def __init__(self, app):
self._client = _http_client.JsonHttpClient(
credential=app.credential.get_credential(),
base_url=self._project_url)
self._operation_client = _http_client.JsonHttpClient(
credential=app.credential.get_credential(),
base_url=_MLKitService.OPERATION_URL)

def get_operation(self, op_name):
_validate_operation_name(op_name)
try:
return self._operation_client.body('get', url=op_name)
except requests.exceptions.RequestException as error:
raise _utils.handle_platform_error_from_requests(error)

def handle_operation(self, operation):
"""Handles long running operations.

Args:
operation: The operation to handle.

Returns:
dict: A dictionary of the returned model properties.

Raises:
TypeError: if the operation is not a dictionary.
ValueError: If the operation is malformed.
"""
if not isinstance(operation, dict):
raise TypeError('Operation must be a dictionary.')
op_name = operation.get('name')
_validate_operation_name(op_name)

for current_attempt in range(_MLKitService.MAX_POLLING_ATTEMPTS):
if operation.get('done'):
if operation.get('response'):
return operation.get('response')
elif operation.get('error'):
raise _utils.handle_operation_error(operation.get('error'))
else:
# A 'done' operation must have either a response or an error.
raise ValueError('Operation is malformed.')
else:
# We just got this operation. Wait before getting another
# so we don't exceed the GetOperation maximum request rate.
delay_factor = pow(
_MLKitService.POLL_EXPONENTIAL_BACKOFF_FACTOR, current_attempt)
wait_time_seconds = delay_factor * _MLKitService.POLL_BASE_WAIT_TIME_SECONDS
time.sleep(wait_time_seconds)
operation = self.get_operation(op_name)
raise exceptions.DeadlineExceededError('Polling deadline exceeded.')

def create_model(self, model):
_validate_model(model)
try:
return self.handle_operation(
self._client.body('post', url='models', json=model.as_dict()))
except requests.exceptions.RequestException as error:
raise _utils.handle_platform_error_from_requests(error)

def get_model(self, model_id):
_validate_model_id(model_id)
Expand Down
Loading