Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ before_install:
- nvm install 8 && npm install -g firebase-tools
script:
- pytest
- firebase emulators:exec --only database --project fake-project-id 'pytest integration/test_db.py'
cache:
pip: true
npm: true
Expand Down
73 changes: 71 additions & 2 deletions firebase_admin/mlkit.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,77 @@
deleting, publishing and unpublishing Firebase ML Kit models.
"""

import re
import requests
import six

from firebase_admin import _http_client
from firebase_admin import _utils


_MLKIT_ATTRIBUTE = '_mlkit'


def _get_mlkit_service(app):
""" Returns an _MLKitService instance for an App.
Args:
app: A Firebase App instance (or None to use the default App).
Returns:
_MLKitService: An _MLKitService for the specified App instance.
Raises:
ValueError: If the app argument is invalid.
"""
return _utils.get_app_service(app, _MLKIT_ATTRIBUTE, _MLKitService)


def get_model(model_id, app=None):
mlkit_service = _get_mlkit_service(app)
return Model(mlkit_service.get_model(model_id))


class Model(object):
"""A Firebase ML Kit Model object."""
def __init__(self, data):
"""Created from a data dictionary."""
self._data = data

def __eq__(self, other):
if isinstance(other, self.__class__):
return self._data == other._data # pylint: disable=protected-access
else:
return False

def __ne__(self, other):
return not self.__eq__(other)

#TODO(ifielker): define the Model properties etc


class _MLKitService(object):
"""Firebase MLKit service."""

BASE_URL = 'https://mlkit.googleapis.com'
PROJECT_URL = 'https://mlkit.googleapis.com/projects/{0}/'
PROJECT_URL = 'https://mlkit.googleapis.com/v1beta1/projects/{0}/'

def __init__(self, app):
project_id = app.project_id
if not project_id:
raise ValueError(
'Project ID is required to access MLKit service. Either set the '
'projectId option, or use service account credentials.')
self._project_url = _MLKitService.PROJECT_URL.format(project_id)
self._client = _http_client.JsonHttpClient(credential=app.credential.get_credential())
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also pass base_url=self._project_url to the constructor

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done


def get_model(self, model_id):
if not isinstance(model_id, six.string_types):
raise TypeError('Model ID must be a string.')
if not re.match(r'^[A-Za-z0-9_-]{1,60}$', model_id):
raise ValueError('Model ID format is invalid.')
try:
return self._client.body(
'get',
url=self._project_url + 'models/{0}'.format(model_id))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Once you pass base_url the constructor, you don't have to do this concat here. Just pass the path modles/... portion as the url.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

except requests.exceptions.RequestException as error:
raise _utils.handle_platform_error_from_requests(error)
134 changes: 134 additions & 0 deletions tests/test_mlkit.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
# Copyright 2019 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Test cases for the firebase_admin.mlkit module."""

import json
import pytest

import firebase_admin
from firebase_admin import exceptions
from firebase_admin import mlkit
from tests import testutils

BASE_URL = 'https://mlkit.googleapis.com/v1beta1/'

PROJECT_ID = 'myProject1'
MODEL_ID_1 = 'modelId1'
MODEL_NAME_1 = 'projects/{0}/models/{1}'.format(PROJECT_ID, MODEL_ID_1)
DISPLAY_NAME_1 = 'displayName1'
MODEL_JSON_1 = {
'name': MODEL_NAME_1,
'displayName': DISPLAY_NAME_1
}
MODEL_1 = mlkit.Model(MODEL_JSON_1)
_DEFAULT_RESPONSE = json.dumps(MODEL_JSON_1)

ERROR_CODE = 404
ERROR_MSG = 'The resource was not found'
ERROR_STATUS = 'NOT_FOUND'
ERROR_JSON = {
'error': {
'code': ERROR_CODE,
'message': ERROR_MSG,
'status': ERROR_STATUS
}
}
_ERROR_RESPONSE = json.dumps(ERROR_JSON)


class TestGetModel(object):
"""Tests mlkit.get_model."""
@classmethod
def setup_class(cls):
cred = testutils.MockCredential()
firebase_admin.initialize_app(cred, {'projectId': PROJECT_ID})

@classmethod
def teardown_class(cls):
testutils.cleanup_apps()

@staticmethod
def check_error(err, err_type, msg):
assert isinstance(err, err_type)
assert str(err) == msg

@staticmethod
def check_firebase_error(err, code, status, msg):
assert isinstance(err, exceptions.FirebaseError)
assert err.code == code
assert err.http_response is not None
assert err.http_response.status_code == status
assert str(err) == msg

def _get_url(self, project_id, model_id):
return BASE_URL + 'projects/{0}/models/{1}'.format(project_id, model_id)

def _instrument_mlkit_service(self, app=None, status=200, payload=_DEFAULT_RESPONSE):
if not app:
app = firebase_admin.get_app()
mlkit_service = mlkit._get_mlkit_service(app)
recorder = []
mlkit_service._client.session.mount(
'https://mlkit.googleapis.com',
testutils.MockAdapter(payload, status, recorder)
)
return mlkit_service, recorder

def test_get_model(self):
_, recorder = self._instrument_mlkit_service()
model = mlkit.get_model(MODEL_ID_1)
assert len(recorder) == 1
assert recorder[0].method == 'GET'
assert recorder[0].url == self._get_url(PROJECT_ID, MODEL_ID_1)
assert model == MODEL_1
assert model._data['name'] == MODEL_NAME_1
assert model._data['displayName'] == DISPLAY_NAME_1

def test_get_model_validation_errors(self):
#Empty model-id
with pytest.raises(ValueError) as err:
mlkit.get_model('')
self.check_error(err.value, ValueError, 'Model ID format is invalid.')

#None model-id
with pytest.raises(TypeError) as err:
mlkit.get_model(None)
self.check_error(err.value, TypeError, 'Model ID must be a string.')

#Wrong type
with pytest.raises(TypeError) as err:
mlkit.get_model(12345)
self.check_error(err.value, TypeError, 'Model ID must be a string.')

#Invalid characters
with pytest.raises(ValueError) as err:
mlkit.get_model('&_*#@:/?')
self.check_error(err.value, ValueError, 'Model ID format is invalid.')

def test_get_model_error(self):
_, recorder = self._instrument_mlkit_service(status=404, payload=_ERROR_RESPONSE)
with pytest.raises(exceptions.NotFoundError) as err:
mlkit.get_model(MODEL_ID_1)
self.check_firebase_error(err.value, ERROR_STATUS, ERROR_CODE, ERROR_MSG)
assert len(recorder) == 1
assert recorder[0].method == 'GET'
assert recorder[0].url == self._get_url(PROJECT_ID, MODEL_ID_1)

def test_no_project_id(self):
def evaluate():
app = firebase_admin.initialize_app(testutils.MockCredential(), name='no_project_id')
with pytest.raises(ValueError):
mlkit.get_model(MODEL_ID_1, app)
testutils.run_without_project_id(evaluate)