diff --git a/airflow/contrib/hooks/__init__.py b/airflow/contrib/hooks/__init__.py
index 182a49f1756d4..4941314790a91 100644
--- a/airflow/contrib/hooks/__init__.py
+++ b/airflow/contrib/hooks/__init__.py
@@ -40,6 +40,7 @@
     'qubole_hook': ['QuboleHook'],
     'gcs_hook': ['GoogleCloudStorageHook'],
     'datastore_hook': ['DatastoreHook'],
+    'gcp_cloudml_hook': ['CloudMLHook'],
     'gcp_dataproc_hook': ['DataProcHook'],
     'gcp_dataflow_hook': ['DataFlowHook'],
     'spark_submit_operator': ['SparkSubmitOperator'],
diff --git a/airflow/contrib/hooks/gcp_cloudml_hook.py b/airflow/contrib/hooks/gcp_cloudml_hook.py
new file mode 100644
index 0000000000000..e722b2acb296d
--- /dev/null
+++ b/airflow/contrib/hooks/gcp_cloudml_hook.py
@@ -0,0 +1,180 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import logging
+import random
+import time
+from airflow import settings
+from airflow.contrib.hooks.gcp_api_base_hook import GoogleCloudBaseHook
+from apiclient.discovery import build
+from apiclient import errors
+from oauth2client.client import GoogleCredentials
+
+logging.getLogger('GoogleCloudML').setLevel(settings.LOGGING_LEVEL)
+
+
+def _poll_with_exponential_delay(request, max_n, is_done_func, is_error_func):
+    """
+    Executes `request` up to `max_n` times until `is_done_func` accepts the
+    response, sleeping 2**i seconds plus up to one second of random jitter
+    after unsuccessful attempt `i`.
+
+    Raises ValueError if `is_error_func` flags a response or if the operation
+    is still not done after `max_n` attempts. HttpErrors other than 429 are
+    re-raised immediately; 429 responses are retried after the same delay.
+    """
+    for i in range(0, max_n):
+        try:
+            response = request.execute()
+            if is_error_func(response):
+                raise ValueError('The response contained an error: {}'.format(response))
+            elif is_done_func(response):
+                logging.info('Operation is done: {}'.format(response))
+                return response
+            else:
+                time.sleep((2**i) + (random.randint(0, 1000) / 1000.0))
+        except errors.HttpError as e:
+            if e.resp.status != 429:
+                logging.info('Something went wrong. Not retrying: {}'.format(e))
+                raise
+            else:
+                time.sleep((2**i) + (random.randint(0, 1000) / 1000.0))
+    raise ValueError(
+        'Operation did not complete after {} attempts.'.format(max_n))
+
+
+class CloudMLHook(GoogleCloudBaseHook):
+    """
+    Hook for managing Google Cloud ML models and their versions.
+    """
+
+    def __init__(self, gcp_conn_id='google_cloud_default', delegate_to=None):
+        super(CloudMLHook, self).__init__(gcp_conn_id, delegate_to)
+        self._cloudml = self.get_conn()
+
+    def get_conn(self):
+        """
+        Returns a Google CloudML service object.
+        """
+        credentials = GoogleCredentials.get_application_default()
+        return build('ml', 'v1', credentials=credentials)
+
+    def create_version(self, project_name, model_name, version_spec):
+        """
+        Creates the Version on Cloud ML.
+
+        Returns the operation if the version was created successfully and raises
+        an error otherwise.
+        """
+        parent_name = 'projects/{}/models/{}'.format(project_name, model_name)
+        create_request = self._cloudml.projects().models().versions().create(
+            parent=parent_name, body=version_spec)
+        response = create_request.execute()
+        get_request = self._cloudml.projects().operations().get(
+            name=response['name'])
+
+        return _poll_with_exponential_delay(
+            request=get_request,
+            max_n=9,
+            is_done_func=lambda resp: resp.get('done', False),
+            is_error_func=lambda resp: resp.get('error', None) is not None)
+
+    def set_default_version(self, project_name, model_name, version_name):
+        """
+        Sets a version to be the default. Blocks until finished.
+        """
+        full_version_name = 'projects/{}/models/{}/versions/{}'.format(
+            project_name, model_name, version_name)
+        request = self._cloudml.projects().models().versions().setDefault(
+            name=full_version_name, body={})
+
+        try:
+            response = request.execute()
+            logging.info('Successfully set version: {} to default'.format(response))
+            return response
+        except errors.HttpError as e:
+            logging.error('Something went wrong: {}'.format(e))
+            raise
+
+    def list_versions(self, project_name, model_name):
+        """
+        Lists all available versions of a model. Blocks until finished.
+        """
+        result = []
+        full_parent_name = 'projects/{}/models/{}'.format(
+            project_name, model_name)
+        request = self._cloudml.projects().models().versions().list(
+            parent=full_parent_name, pageSize=100)
+
+        response = request.execute()
+        next_page_token = response.get('nextPageToken', None)
+        result.extend(response.get('versions', []))
+        while next_page_token is not None:
+            next_request = self._cloudml.projects().models().versions().list(
+                parent=full_parent_name,
+                pageToken=next_page_token,
+                pageSize=100)
+            response = next_request.execute()
+            next_page_token = response.get('nextPageToken', None)
+            result.extend(response.get('versions', []))
+            time.sleep(5)
+        return result
+
+    def delete_version(self, project_name, model_name, version_name):
+        """
+        Deletes the given version of a model. Blocks until finished.
+        """
+        full_name = 'projects/{}/models/{}/versions/{}'.format(
+            project_name, model_name, version_name)
+        delete_request = self._cloudml.projects().models().versions().delete(
+            name=full_name)
+        response = delete_request.execute()
+        get_request = self._cloudml.projects().operations().get(
+            name=response['name'])
+
+        return _poll_with_exponential_delay(
+            request=get_request,
+            max_n=9,
+            is_done_func=lambda resp: resp.get('done', False),
+            is_error_func=lambda resp: resp.get('error', None) is not None)
+
+    def create_model(self, project_name, model):
+        """
+        Create a Model. Blocks until finished.
+        """
+        assert model['name'] is not None and model['name'] != ''
+        project = 'projects/{}'.format(project_name)
+
+        request = self._cloudml.projects().models().create(
+            parent=project, body=model)
+        return request.execute()
+
+    def get_model(self, project_name, model_name):
+        """
+        Gets a Model. Blocks until finished. Returns None if it is not found.
+        """
+        assert model_name is not None and model_name != ''
+        full_model_name = 'projects/{}/models/{}'.format(
+            project_name, model_name)
+        request = self._cloudml.projects().models().get(name=full_model_name)
+        try:
+            return request.execute()
+        except errors.HttpError as e:
+            if e.resp.status == 404:
+                logging.error('Model was not found: {}'.format(e))
+                return None
+            raise
diff --git a/airflow/contrib/operators/cloudml_operator.py b/airflow/contrib/operators/cloudml_operator.py
new file mode 100644
index 0000000000000..b0b6e91a17b83
--- /dev/null
+++ b/airflow/contrib/operators/cloudml_operator.py
@@ -0,0 +1,178 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the 'License'); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an 'AS IS' BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import logging
+from airflow import settings
+from airflow.contrib.hooks.gcp_cloudml_hook import CloudMLHook
+from airflow.operators import BaseOperator
+from airflow.utils.decorators import apply_defaults
+
+logging.getLogger('GoogleCloudML').setLevel(settings.LOGGING_LEVEL)
+
+
+class CloudMLVersionOperator(BaseOperator):
+    """
+    Operator for managing a Google Cloud ML version.
+
+    :param model_name: The name of the Google Cloud ML model that the version
+        belongs to.
+    :type model_name: string
+
+    :param project_name: The Google Cloud project name to which CloudML
+        model belongs.
+    :type project_name: string
+
+    :param version: A dictionary containing the information about the version.
+        If the `operation` is `create`, `version` should contain all the
+        information about this version such as name, and deploymentUrl.
+        If the `operation` is `set_default` or `delete`, the `version`
+        parameter should contain the `name` of the version.
+        If it is None, the only `operation` possible would be `list`.
+    :type version: dict
+
+    :param gcp_conn_id: The connection ID to use when fetching connection info.
+    :type gcp_conn_id: string
+
+    :param operation: The operation to perform. Available operations are:
+        'create': Creates a new version in the model specified by `model_name`,
+            in which case the `version` parameter should contain all the
+            information to create that version
+            (e.g. `name`, `deploymentUrl`).
+        'set_default': Sets the version specified in the `version` parameter
+            as the default version of the model specified by `model_name`.
+            The name of the version should be specified in the `version`
+            parameter.
+
+        'list': Lists all available versions of the model specified
+            by `model_name`.
+
+        'delete': Deletes the version specified in `version` parameter from the
+            model specified by `model_name`.
+            The name of the version should be specified in the `version`
+            parameter.
+    :type operation: string
+
+    :param delegate_to: The account to impersonate, if any.
+        For this to work, the service account making the request must have
+        domain-wide delegation enabled.
+    :type delegate_to: string
+    """
+
+
+    template_fields = [
+        '_model_name',
+        '_version',
+    ]
+
+    @apply_defaults
+    def __init__(self,
+                 model_name,
+                 project_name,
+                 version=None,
+                 gcp_conn_id='google_cloud_default',
+                 operation='create',
+                 delegate_to=None,
+                 *args,
+                 **kwargs):
+
+        super(CloudMLVersionOperator, self).__init__(*args, **kwargs)
+        self._model_name = model_name
+        self._version = version
+        self._gcp_conn_id = gcp_conn_id
+        self._delegate_to = delegate_to
+        self._project_name = project_name
+        self._operation = operation
+
+    def execute(self, context):
+        hook = CloudMLHook(
+            gcp_conn_id=self._gcp_conn_id, delegate_to=self._delegate_to)
+
+        if self._operation == 'create':
+            assert self._version is not None
+            return hook.create_version(self._project_name, self._model_name,
+                                       self._version)
+        elif self._operation == 'set_default':
+            return hook.set_default_version(
+                self._project_name, self._model_name,
+                self._version['name'])
+        elif self._operation == 'list':
+            return hook.list_versions(self._project_name, self._model_name)
+        elif self._operation == 'delete':
+            return hook.delete_version(self._project_name, self._model_name,
+                                       self._version['name'])
+        else:
+            raise ValueError('Unknown operation: {}'.format(self._operation))
+
+
+class CloudMLModelOperator(BaseOperator):
+    """
+    Operator for managing a Google Cloud ML model.
+
+    :param model: A dictionary containing the information about the model.
+        If the `operation` is `create`, then the `model` parameter should
+        contain all the information about this model such as `name`.
+
+        If the `operation` is `get`, the `model` parameter
+        should contain the `name` of the model.
+    :type model: dict
+
+    :param project_name: The Google Cloud project name to which CloudML
+        model belongs.
+    :type project_name: string
+
+    :param gcp_conn_id: The connection ID to use when fetching connection info.
+    :type gcp_conn_id: string
+
+    :param operation: The operation to perform. Available operations are:
+        'create': Creates a new model as provided by the `model` parameter.
+        'get': Gets a particular model where the name is specified in `model`.
+
+    :param delegate_to: The account to impersonate, if any.
+        For this to work, the service account making the request must have
+        domain-wide delegation enabled.
+    :type delegate_to: string
+    """
+
+    template_fields = [
+        '_model',
+    ]
+
+    @apply_defaults
+    def __init__(self,
+                 model,
+                 project_name,
+                 gcp_conn_id='google_cloud_default',
+                 operation='create',
+                 delegate_to=None,
+                 *args,
+                 **kwargs):
+        super(CloudMLModelOperator, self).__init__(*args, **kwargs)
+        self._model = model
+        self._operation = operation
+        self._gcp_conn_id = gcp_conn_id
+        self._delegate_to = delegate_to
+        self._project_name = project_name
+
+    def execute(self, context):
+        hook = CloudMLHook(
+            gcp_conn_id=self._gcp_conn_id, delegate_to=self._delegate_to)
+        if self._operation == 'create':
+            return hook.create_model(self._project_name, self._model)
+        elif self._operation == 'get':
+            return hook.get_model(self._project_name, self._model['name'])
+        else:
+            raise ValueError('Unknown operation: {}'.format(self._operation))
diff --git a/airflow/utils/db.py b/airflow/utils/db.py
index 54254f61ddbdb..04b151232f9de 100644
--- a/airflow/utils/db.py
+++ b/airflow/utils/db.py
@@ -128,6 +128,10 @@ def initdb():
             conn_id='presto_default', conn_type='presto',
             host='localhost',
             schema='hive', port=3400))
+    merge_conn(
+        models.Connection(
+            conn_id='google_cloud_default', conn_type='google_cloud_platform',
+            schema='default',))
     merge_conn(
         models.Connection(
             conn_id='hive_cli_default', conn_type='hive_cli',
diff --git a/tests/contrib/hooks/test_gcp_cloudml_hook.py b/tests/contrib/hooks/test_gcp_cloudml_hook.py
new file mode 100644
index 0000000000000..aa50e69f14aef
--- /dev/null
+++ b/tests/contrib/hooks/test_gcp_cloudml_hook.py
@@ -0,0 +1,255 @@
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+import mock
+import unittest
+try:  # python 2
+    from urlparse import urlparse, parse_qsl
+except ImportError:  # python 3
+    from urllib.parse import urlparse, parse_qsl
+
+from airflow.contrib.hooks import gcp_cloudml_hook as hook
+from apiclient.discovery import build
+from apiclient.http import HttpMockSequence
+from oauth2client.contrib.gce import HttpAccessTokenRefreshError
+
+cml_available = True
+try:
+    hook.CloudMLHook().get_conn()
+except HttpAccessTokenRefreshError:
+    cml_available = False
+
+
+class _TestCloudMLHook(object):
+
+    def __init__(self, test_cls, responses, expected_requests):
+        """
+        Init method.
+
+        Usage example:
+        with _TestCloudMLHook(self, responses, expected_requests) as hook:
+            self.run_my_test(hook)
+
+        Args:
+          test_cls: The caller's instance used for test communication.
+          responses: A list of (dict_response, response_content) tuples.
+          expected_requests: A list of (uri, http_method, body) tuples.
+        """
+
+        self._test_cls = test_cls
+        self._responses = responses
+        self._expected_requests = [
+            self._normalize_requests_for_comparison(x[0], x[1], x[2]) for x in expected_requests]
+        self._actual_requests = []
+
+    def _normalize_requests_for_comparison(self, uri, http_method, body):
+        parts = urlparse(uri)
+        return (parts._replace(query=set(parse_qsl(parts.query))), http_method, body)
+
+    def __enter__(self):
+        http = HttpMockSequence(self._responses)
+        native_request_method = http.request
+
+        # Collecting requests to validate at __exit__.
+        def _request_wrapper(*args, **kwargs):
+            self._actual_requests.append(args + (kwargs['body'],))
+            return native_request_method(*args, **kwargs)
+
+        http.request = _request_wrapper
+        service_mock = build('ml', 'v1', http=http)
+        with mock.patch.object(
+                hook.CloudMLHook, 'get_conn', return_value=service_mock):
+            return hook.CloudMLHook()
+
+    def __exit__(self, *args):
+        # Propagating exceptions here since assert will silence them.
+        if any(args):
+            return None
+        self._test_cls.assertEqual(
+            [self._normalize_requests_for_comparison(x[0], x[1], x[2]) for x in self._actual_requests], self._expected_requests)
+
+
+class TestCloudMLHook(unittest.TestCase):
+
+    def setUp(self):
+        pass
+
+    _SKIP_IF = unittest.skipIf(not cml_available,
+                               'CloudML is not available to run tests')
+    _SERVICE_URI_PREFIX = 'https://ml.googleapis.com/v1/'
+
+    @_SKIP_IF
+    def test_create_version(self):
+        project = 'test-project'
+        model_name = 'test-model'
+        version = 'test-version'
+        operation_name = 'projects/{}/operations/test-operation'.format(
+            project)
+
+        response_body = {'name': operation_name, 'done': True}
+        succeeded_response = ({'status': '200'}, json.dumps(response_body))
+
+        expected_requests = [
+            ('{}projects/{}/models/{}/versions?alt=json'.format(
+                self._SERVICE_URI_PREFIX, project, model_name), 'POST',
+             '"{}"'.format(version)),
+            ('{}{}?alt=json'.format(self._SERVICE_URI_PREFIX, operation_name),
+             'GET', None),
+        ]
+
+        with _TestCloudMLHook(
+                self,
+                responses=[succeeded_response] * 2,
+                expected_requests=expected_requests) as cml_hook:
+            create_version_response = cml_hook.create_version(
+                project_name=project, model_name=model_name, version_spec=version)
+            self.assertEqual(create_version_response, response_body)
+
+    @_SKIP_IF
+    def test_set_default_version(self):
+        project = 'test-project'
+        model_name = 'test-model'
+        version = 'test-version'
+        operation_name = 'projects/{}/operations/test-operation'.format(
+            project)
+
+        response_body = {'name': operation_name, 'done': True}
+        succeeded_response = ({'status': '200'}, json.dumps(response_body))
+
+        expected_requests = [
+            ('{}projects/{}/models/{}/versions/{}:setDefault?alt=json'.format(
+                self._SERVICE_URI_PREFIX, project, model_name, version), 'POST',
+             '{}'),
+        ]
+
+        with _TestCloudMLHook(
+                self,
+                responses=[succeeded_response],
+                expected_requests=expected_requests) as cml_hook:
+            set_default_version_response = cml_hook.set_default_version(
+                project_name=project, model_name=model_name, version_name=version)
+            self.assertEqual(set_default_version_response, response_body)
+
+    @_SKIP_IF
+    def test_list_versions(self):
+        project = 'test-project'
+        model_name = 'test-model'
+        operation_name = 'projects/{}/operations/test-operation'.format(
+            project)
+
+        # This test returns the versions one at a time.
+        versions = ['ver_{}'.format(ix) for ix in range(3)]
+
+        response_bodies = [{'name': operation_name, 'nextPageToken': ix, 'versions': [
+            ver]} for ix, ver in enumerate(versions)]
+        response_bodies[-1].pop('nextPageToken')
+        responses = [({'status': '200'}, json.dumps(body))
+                     for body in response_bodies]
+
+        expected_requests = [
+            ('{}projects/{}/models/{}/versions?alt=json&pageSize=100'.format(
+                self._SERVICE_URI_PREFIX, project, model_name), 'GET',
+             None),
+        ] + [
+            ('{}projects/{}/models/{}/versions?alt=json&pageToken={}&pageSize=100'.format(
+                self._SERVICE_URI_PREFIX, project, model_name, ix), 'GET',
+             None) for ix in range(len(versions) - 1)
+        ]
+
+        with _TestCloudMLHook(
+                self,
+                responses=responses,
+                expected_requests=expected_requests) as cml_hook:
+            list_versions_response = cml_hook.list_versions(
+                project_name=project, model_name=model_name)
+            self.assertEqual(list_versions_response, versions)
+
+    @_SKIP_IF
+    def test_delete_version(self):
+        project = 'test-project'
+        model_name = 'test-model'
+        version = 'test-version'
+        operation_name = 'projects/{}/operations/test-operation'.format(
+            project)
+
+        not_done_response_body = {'name': operation_name, 'done': False}
+        done_response_body = {'name': operation_name, 'done': True}
+        not_done_response = (
+            {'status': '200'}, json.dumps(not_done_response_body))
+        succeeded_response = (
+            {'status': '200'}, json.dumps(done_response_body))
+
+        expected_requests = [
+            ('{}projects/{}/models/{}/versions/{}?alt=json'.format(
+                self._SERVICE_URI_PREFIX, project, model_name, version), 'DELETE',
+             None),
+            ('{}{}?alt=json'.format(self._SERVICE_URI_PREFIX, operation_name),
+             'GET', None),
+        ]
+
+        with _TestCloudMLHook(
+                self,
+                responses=[not_done_response, succeeded_response],
+                expected_requests=expected_requests) as cml_hook:
+            delete_version_response = cml_hook.delete_version(
+                project_name=project, model_name=model_name, version_name=version)
+            self.assertEqual(delete_version_response, done_response_body)
+
+    @_SKIP_IF
+    def test_create_model(self):
+        project = 'test-project'
+        model_name = 'test-model'
+        model = {
+            'name': model_name,
+        }
+        response_body = {}
+        succeeded_response = ({'status': '200'}, json.dumps(response_body))
+
+        expected_requests = [
+            ('{}projects/{}/models?alt=json'.format(
+                self._SERVICE_URI_PREFIX, project), 'POST',
+             json.dumps(model)),
+        ]
+
+        with _TestCloudMLHook(
+                self,
+                responses=[succeeded_response],
+                expected_requests=expected_requests) as cml_hook:
+            create_model_response = cml_hook.create_model(
+                project_name=project, model=model)
+            self.assertEqual(create_model_response, response_body)
+
+    @_SKIP_IF
+    def test_get_model(self):
+        project = 'test-project'
+        model_name = 'test-model'
+        response_body = {'model': model_name}
+        succeeded_response = ({'status': '200'}, json.dumps(response_body))
+
+        expected_requests = [
+            ('{}projects/{}/models/{}?alt=json'.format(
+                self._SERVICE_URI_PREFIX, project, model_name), 'GET',
+             None),
+        ]
+
+        with _TestCloudMLHook(
+                self,
+                responses=[succeeded_response],
+                expected_requests=expected_requests) as cml_hook:
+            get_model_response = cml_hook.get_model(
+                project_name=project, model_name=model_name)
+            self.assertEqual(get_model_response, response_body)
+
+
+if __name__ == '__main__':
+    unittest.main()