Skip to content
Merged
21 changes: 21 additions & 0 deletions firebase_admin/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,27 @@ def handle_platform_error_from_requests(error, handle_func=None):
return exc if exc else _handle_func_requests(error, message, error_dict)


def handle_operation_error(error):
"""Constructs a ``FirebaseError`` from the given operation error.

Args:
error: An error returned by a long running operation.

Returns:
FirebaseError: A ``FirebaseError`` that can be raised to the user code.
"""
if not isinstance(error, dict):
return exceptions.UnknownError(
message='Unknown error while making a remote service call: {0}'.format(error),
cause=error)

status_code = error.get('code')
message = error.get('message')
error_code = _http_status_to_error_code(status_code)
err_type = _error_code_to_exception_type(error_code)
return err_type(message=message)


def _handle_func_requests(error, message, error_dict):
"""Constructs a ``FirebaseError`` from the given GCP error.

Expand Down
62 changes: 62 additions & 0 deletions firebase_admin/mlkit.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,11 @@
import datetime
import numbers
import re
import time
import requests
import six


from firebase_admin import _http_client
from firebase_admin import _utils

Expand All @@ -36,6 +38,8 @@
_GCS_TFLITE_URI_PATTERN = re.compile(r'^gs://[a-z0-9_.-]{3,63}/.+')
_RESOURCE_NAME_PATTERN = re.compile(
r'^projects/(?P<project_id>[^/]+)/models/(?P<model_id>[A-Za-z0-9_-]{1,60})$')
_OPERATION_NAME_PATTERN = re.compile(
r'^operations/project/[^/]+/model/[A-Za-z0-9_-]{1,60}/operation/[^/]+$')


def _get_mlkit_service(app):
Expand All @@ -53,6 +57,11 @@ def _get_mlkit_service(app):
return _utils.get_app_service(app, _MLKIT_ATTRIBUTE, _MLKitService)


def create_model(model, app=None):
mlkit_service = _get_mlkit_service(app)
return Model.from_dict(mlkit_service.create_model(model))


def get_model(model_id, app=None):
mlkit_service = _get_mlkit_service(app)
return Model.from_dict(mlkit_service.get_model(model_id))
Expand Down Expand Up @@ -390,11 +399,23 @@ def _validate_and_parse_name(name):
return matcher.group('project_id'), matcher.group('model_id')


def _validate_model(model):
if not isinstance(model, Model):
raise TypeError('Model must be an mlkit.Model.')
if not model.display_name:
raise ValueError('Model must have a display name.')


def _validate_model_id(model_id):
if not _MODEL_ID_PATTERN.match(model_id):
raise ValueError('Model ID format is invalid.')


def _validate_operation_name(op_name):
if not _OPERATION_NAME_PATTERN.match(op_name):
raise ValueError('Operation name format is invalid.')


def _validate_display_name(display_name):
if not _DISPLAY_NAME_PATTERN.match(display_name):
raise ValueError('Display name format is invalid.')
Expand Down Expand Up @@ -448,6 +469,8 @@ class _MLKitService(object):
"""Firebase MLKit service."""

PROJECT_URL = 'https://mlkit.googleapis.com/v1beta1/projects/{0}/'
OPERATION_URL = 'https://mlkit.googleapis.com/v1beta1/'
OPERATION_POLL_DELAY_SECONDS = 30

def __init__(self, app):
project_id = app.project_id
Expand All @@ -459,6 +482,45 @@ def __init__(self, app):
self._client = _http_client.JsonHttpClient(
credential=app.credential.get_credential(),
base_url=self._project_url)
self._operation_client = _http_client.JsonHttpClient(
credential=app.credential.get_credential(),
base_url=_MLKitService.OPERATION_URL)

def get_operation(self, op_name):
_validate_operation_name(op_name)
try:
return self._operation_client.body('get', url=op_name)
except requests.exceptions.RequestException as error:
raise _utils.handle_platform_error_from_requests(error)

def handle_operation(self, operation):
if not isinstance(operation, dict):
raise TypeError('Operation must be a dictionary.')
op_name = operation.get('name')
_validate_operation_name(op_name)

while True:
if operation.get('done'):
if operation.get('response'):
return operation.get('response')
elif operation.get('error'):
raise _utils.handle_operation_error(operation.get('error'))
else:
# A 'done' operation must have either a response or an error.
raise ValueError('Operation is malformed.')
else:
# We just got this operation wait 30s before getting another
# so we don't exceed the GetOperation maximum request rate.
time.sleep(_MLKitService.OPERATION_POLL_DELAY_SECONDS)
operation = self.get_operation(op_name)

def create_model(self, model):
_validate_model(model)
try:
return self.handle_operation(
self._client.body('post', url='models', json=model.as_dict()))
except requests.exceptions.RequestException as error:
raise _utils.handle_platform_error_from_requests(error)

def get_model(self, model_id):
_validate_model_id(model_id)
Expand Down
Loading