-
Couldn't load subscription status.
- Fork 343
Implementation of Model, ModelFormat, TFLiteModelSource and subclasses #335
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
573e7cb
e67bceb
1f018fe
dfe0a37
7704c44
8381ac5
a2e7544
cadd6c6
b02ea22
a0a2411
fc63db8
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -18,6 +18,8 @@ | |
| deleting, publishing and unpublishing Firebase ML Kit models. | ||
| """ | ||
|
|
||
| import datetime | ||
| import numbers | ||
| import re | ||
| import requests | ||
| import six | ||
|
|
@@ -63,9 +65,25 @@ def delete_model(model_id, app=None): | |
|
|
||
| class Model(object): | ||
| """A Firebase ML Kit Model object.""" | ||
| def __init__(self, data): | ||
| def __init__(self, data=None, display_name=None, tags=None, model_format=None): | ||
hiranya911 marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| """Created from a data dictionary.""" | ||
| self._data = data | ||
| if data is not None and isinstance(data, dict): | ||
| self._data = data | ||
| else: | ||
| self._data = {} | ||
| if display_name is not None: | ||
| _validate_display_name(display_name) | ||
|
||
| self._data['displayName'] = display_name | ||
| if tags is not None: | ||
| _validate_tags(tags) | ||
| self._data['tags'] = tags | ||
| if model_format is not None: | ||
| _validate_model_format(model_format) | ||
| if isinstance(model_format, TFLiteFormat): | ||
| self._data['tfliteModel'] = model_format.get_json() | ||
| else: | ||
| raise TypeError('Unsupported model format type.') | ||
|
|
||
|
|
||
| def __eq__(self, other): | ||
| if isinstance(other, self.__class__): | ||
|
|
@@ -77,15 +95,181 @@ def __ne__(self, other): | |
| return not self.__eq__(other) | ||
|
|
||
| @property | ||
| def name(self): | ||
| return self._data['name'] | ||
| def model_id(self): | ||
| if not self._data.get('name'): | ||
| return None | ||
| _, model_id = _validate_and_parse_name(self._data.get('name')) | ||
| return model_id | ||
|
|
||
| @property | ||
| def display_name(self): | ||
| return self._data['displayName'] | ||
| return self._data.get('displayName') | ||
|
|
||
| @display_name.setter | ||
| def display_name(self, display_name): | ||
| _validate_display_name(display_name) | ||
| self._data['displayName'] = display_name | ||
| return self | ||
|
|
||
| @property | ||
| def create_time(self): | ||
| if self._data.get('createTime') and \ | ||
hiranya911 marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| self._data.get('createTime').get('seconds') and \ | ||
| isinstance(self._data.get('createTime').get('seconds'), numbers.Number): | ||
| return datetime.datetime.fromtimestamp( | ||
| float(self._data.get('createTime').get('seconds'))) | ||
| return None | ||
|
|
||
| @property | ||
| def update_time(self): | ||
| if self._data.get('updateTime') and \ | ||
| self._data.get('updateTime').get('seconds') and \ | ||
| isinstance(self._data.get('updateTime').get('seconds'), numbers.Number): | ||
| return datetime.datetime.fromtimestamp( | ||
| float(self._data.get('updateTime').get('seconds'))) | ||
| return None | ||
|
|
||
| @property | ||
| def validation_error(self): | ||
| return self._data.get('state') and \ | ||
| self._data.get('state').get('validationError') and \ | ||
| self._data.get('state').get('validationError').get('message') | ||
hiranya911 marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
|
|
||
| @property | ||
| def published(self): | ||
| return bool(self._data.get('state') and | ||
hiranya911 marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| self._data.get('state').get('published')) | ||
|
|
||
| @property | ||
| def etag(self): | ||
| return self._data.get('etag') | ||
|
|
||
| @property | ||
| def model_hash(self): | ||
| return self._data.get('modelHash') | ||
|
|
||
| @property | ||
| def tags(self): | ||
| return self._data.get('tags') | ||
|
|
||
| @tags.setter | ||
| def tags(self, tags): | ||
| _validate_tags(tags) | ||
| self._data['tags'] = tags | ||
| return self | ||
|
|
||
| @property | ||
| def locked(self): | ||
| return bool(self._data.get('activeOperations') and | ||
| len(self._data.get('activeOperations')) > 0) | ||
|
|
||
| @property | ||
| def model_format(self): | ||
| if self._data.get('tfliteModel'): | ||
| return TFLiteFormat(self._data.get('tfliteModel')) | ||
| return None | ||
|
|
||
| @model_format.setter | ||
| def model_format(self, model_format): | ||
| if not isinstance(model_format, TFLiteFormat): | ||
| raise TypeError('Unsupported model format type.') | ||
| self._data['tfliteModel'] = model_format.get_json() | ||
| return self | ||
|
|
||
| def get_json(self): | ||
hiranya911 marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| return self._data | ||
hiranya911 marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
|
|
||
|
|
||
| class ModelFormat(object): | ||
| """Abstract base class representing a Model Format such as TFLite.""" | ||
| def get_json(self): | ||
hiranya911 marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| raise NotImplementedError | ||
|
|
||
|
|
||
| class TFLiteFormat(ModelFormat): | ||
| """Model format representing a TFLite model.""" | ||
hiranya911 marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| def __init__(self, data=None, model_source=None): | ||
hiranya911 marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| if (data is not None) and isinstance(data, dict): | ||
hiranya911 marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| self._data = data | ||
| else: | ||
| self._data = {} | ||
| if model_source is not None: | ||
| # Check for correct base type | ||
| if not isinstance(model_source, TFLiteModelSource): | ||
| raise TypeError('Model source must be a ModelSource object.') | ||
| # Set based on specific sub type | ||
hiranya911 marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| if isinstance(model_source, TFLiteGCSModelSource): | ||
| self._data['gcsTfliteUri'] = model_source.get_json() | ||
| else: | ||
| raise TypeError('Unsupported model source type.') | ||
|
|
||
hiranya911 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
|
||
| def __eq__(self, other): | ||
| if isinstance(other, self.__class__): | ||
| return self._data == other._data # pylint: disable=protected-access | ||
| else: | ||
| return False | ||
|
|
||
| def __ne__(self, other): | ||
| return not self.__eq__(other) | ||
|
|
||
| @property | ||
| def model_source(self): | ||
| if self._data.get('gcsTfliteUri'): | ||
| return TFLiteGCSModelSource(self._data.get('gcsTfliteUri')) | ||
| return None | ||
|
|
||
| @model_source.setter | ||
| def model_source(self, model_source): | ||
| if model_source is not None: | ||
| if isinstance(model_source, TFLiteGCSModelSource): | ||
| self._data['gcsTfliteUri'] = model_source.get_json() | ||
| else: | ||
| raise TypeError('Unsupported model source type.') | ||
|
|
||
|
|
||
| @property | ||
| def size_bytes(self): | ||
| return self._data.get('sizeBytes') | ||
|
|
||
| def get_json(self): | ||
| return self._data | ||
|
|
||
|
|
||
| class TFLiteModelSource(object): | ||
| """Abstract base class representing a model source for TFLite format models.""" | ||
| def get_json(self): | ||
| raise NotImplementedError | ||
|
|
||
|
|
||
| class TFLiteGCSModelSource(TFLiteModelSource): | ||
| """TFLite model source representing a tflite model file stored in GCS.""" | ||
| def __init__(self, gcs_tflite_uri): | ||
| _validate_gcs_tflite_uri(gcs_tflite_uri) | ||
| self._gcs_tflite_uri = gcs_tflite_uri | ||
|
|
||
| #TODO(ifielker): define the rest of the Model properties etc | ||
| def __eq__(self, other): | ||
| if isinstance(other, self.__class__): | ||
| return self._gcs_tflite_uri == other._gcs_tflite_uri # pylint: disable=protected-access | ||
| else: | ||
| return False | ||
|
|
||
| def __ne__(self, other): | ||
| return not self.__eq__(other) | ||
|
|
||
| @property | ||
| def gcs_tflite_uri(self): | ||
| return self._gcs_tflite_uri | ||
|
|
||
| @gcs_tflite_uri.setter | ||
| def gcs_tflite_uri(self, gcs_tflite_uri): | ||
| _validate_gcs_tflite_uri(gcs_tflite_uri) | ||
| self._gcs_tflite_uri = gcs_tflite_uri | ||
|
|
||
| def get_json(self): | ||
| return self._gcs_tflite_uri | ||
hiranya911 marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
|
|
||
| #TODO(ifielker): implement from_saved_model etc. | ||
|
|
||
| class ListModelsPage(object): | ||
| """Represents a page of models in a firebase project. | ||
|
|
@@ -179,13 +363,55 @@ def __iter__(self): | |
| return self | ||
|
|
||
|
|
||
| def _validate_and_parse_name(name): | ||
| # The resource name is added automatically from API call responses. | ||
| # The only way it could be invalid is if someone tries to | ||
| # create a model from a dictionary manually and does it incorrectly. | ||
| if not isinstance(name, six.string_types): | ||
| raise TypeError('Model resource name must be a string.') | ||
| matcher = re.match( | ||
| r'^projects/(?P<project_id>[^/]+)/models/(?P<model_id>[A-Za-z0-9_-]{1,60})$', | ||
| name) | ||
| if not matcher: | ||
| raise ValueError('Model resource name format is invalid.') | ||
| return matcher.group('project_id'), matcher.group('model_id') | ||
|
|
||
|
|
||
| def _validate_model_id(model_id): | ||
| if not isinstance(model_id, six.string_types): | ||
| raise TypeError('Model ID must be a string.') | ||
| if not re.match(r'^[A-Za-z0-9_-]{1,60}$', model_id): | ||
| raise ValueError('Model ID format is invalid.') | ||
|
|
||
|
|
||
| def _validate_display_name(display_name): | ||
| if not isinstance(display_name, six.string_types): | ||
hiranya911 marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| raise TypeError('Display name must be a string.') | ||
| if not re.match(r'^[A-Za-z0-9_-]{1,60}$', display_name): | ||
hiranya911 marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| raise ValueError('Display name format is invalid.') | ||
|
|
||
|
|
||
| def _validate_tags(tags): | ||
| if not isinstance(tags, list) or not \ | ||
hiranya911 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| all(isinstance(tag, six.string_types) for tag in tags): | ||
| raise TypeError('Tags must be a list of strings.') | ||
| if not all(re.match(r'^[A-Za-z0-9_-]{1,60}$', tag) for tag in tags): | ||
| raise ValueError('Tag format is invalid.') | ||
|
|
||
|
|
||
| def _validate_gcs_tflite_uri(uri): | ||
| if not isinstance(uri, six.string_types): | ||
hiranya911 marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| raise TypeError('Gcs TFLite URI must be a string.') | ||
| # GCS Bucket naming rules are complex. The regex is not comprehensive. | ||
| # See https://cloud.google.com/storage/docs/naming for full details. | ||
| if not re.match(r'^gs://[a-z0-9_.-]{3,63}/.+', uri): | ||
| raise ValueError('GCS TFLite URI format is invalid.') | ||
|
|
||
| def _validate_model_format(model_format): | ||
| if model_format is not None: | ||
hiranya911 marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| if not isinstance(model_format, ModelFormat): | ||
| raise TypeError('Model format must be a ModelFormat object.') | ||
|
|
||
| def _validate_list_filter(list_filter): | ||
| if list_filter is not None: | ||
| if not isinstance(list_filter, six.string_types): | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.