From 943047581034d0974ac8f668ae7fb49843030e5c Mon Sep 17 00:00:00 2001
From: Andrew Scribner
Date: Mon, 4 Apr 2022 09:19:36 -0400
Subject: [PATCH] feat: check and auto-create S3 artifact bucket if missing
 (#34)

* feat: check and auto-create S3 artifact bucket if missing

Adds a feature that:
* Checks related S3 storage for the bucket named by the
  default_artifact_root config option
* If create_default_artifact_root_if_missing==True and the default bucket
  does not exist, attempts to create the bucket

Supporting this is an S3 bucket wrapper that helps with checking bucket
existence and creating new buckets, plus unit and integration tests.

Also included here is some refactoring to make the main function of the
Operator read a bit more clearly, packaging logic into some verbosely
named helpers.

Closes #23

* feat: add more unit tests for bucket creation

Also refactors/fixes bucket creation logic

* fix: typos in services/s3.py, related tests

* fix: linting/fmt

* refactor: move test_s3.py to unit test folder

* feat: add integration test for automatic bucket creation

* fix: errors in test_default_bucket_created

* fix: linting

* fix: integration test for github actions

Using the AWS client from some virtualized environments prevents the client
from inferring the region, resulting in an error. This change sets the
region explicitly.

* fix: formatting

* fix: formatting

* Revert "fix: formatting"

This reverts commit fbee1a8963621152e305b5625be40ff3dd17c2e7.

* fix: formatting

* fix: formatting. again...
---
 charms/mlflow-server/config.yaml              |   6 +
 charms/mlflow-server/requirements.txt         |   1 +
 charms/mlflow-server/src/charm.py             | 159 +++++++++-----
 charms/mlflow-server/src/services/s3.py       |  83 ++++++++
 charms/mlflow-server/test-requirements.txt    |   2 +
 .../tests/integration/test_charm.py           |  88 +++++++-
 .../mlflow-server/tests/unit/test_operator.py | 205 +++++++++++++++---
 charms/mlflow-server/tests/unit/test_s3.py    | 120 ++++++++++
 8 files changed, 581 insertions(+), 83 deletions(-)
 create mode 100644 charms/mlflow-server/src/services/s3.py
 create mode 100644 charms/mlflow-server/tests/unit/test_s3.py
diff --git a/charms/mlflow-server/config.yaml b/charms/mlflow-server/config.yaml
index 3f604f09..600faac0 100644
--- a/charms/mlflow-server/config.yaml
+++ b/charms/mlflow-server/config.yaml
@@ -2,6 +2,12 @@
 # See LICENSE file for licensing details.
 #
 options:
+  create_default_artifact_root_if_missing:
+    description: |
+      If True, the charm will try to create the default_artifact_root bucket in S3 if it does
+      not exist. If False and the bucket does not exist, the charm will enter Blocked status.
+    type: boolean
+    default: true
   default_artifact_root:
     description: |
       The name of the default bucket mlflow uses for artifacts, if not specified by the workflow
diff --git a/charms/mlflow-server/requirements.txt b/charms/mlflow-server/requirements.txt
index cfa82a75..e9c406bd 100644
--- a/charms/mlflow-server/requirements.txt
+++ b/charms/mlflow-server/requirements.txt
@@ -1,3 +1,4 @@
+boto3
 ops==1.2.0
 oci-image==1.0.0
 ops-lib-mysql
diff --git a/charms/mlflow-server/src/charm.py b/charms/mlflow-server/src/charm.py
index 586f01b8..f57b4ca4 100755
--- a/charms/mlflow-server/src/charm.py
+++ b/charms/mlflow-server/src/charm.py
@@ -9,7 +9,6 @@
 import json
 import logging
-import re
 from base64 import b64encode
 
 from oci_image import OCIImageResource, OCIImageResourceError
@@ -28,7 +27,7 @@
     get_interfaces,
 )
 
-DB_NAME = "mlflow"
+from services.s3 import S3BucketWrapper, validate_s3_bucket_name
 
 
 class Operator(CharmBase):
@@ -42,6 +41,7 @@ def __init__(self, *args):
         self.image = OCIImageResource(self, "oci-image")
         self.log = logging.getLogger(__name__)
+        self.charm_name = self.model.app.name
 
         for event in [
             self.on.install,
@@ -94,52 +94,27 @@ def main(self, event):
         Runs at install, update, config change and relation change.
         """
         try:
+            self.model.unit.status = MaintenanceStatus("Validating inputs and computing pod spec")
+
             self._check_leader()
-            default_artifact_root = validate_s3_bucket_name(self.config["default_artifact_root"])
             interfaces = self._get_interfaces()
             image_details = self._check_image_details()
-        except CheckFailedError as check_failed:
-            self.model.unit.status = check_failed.status
-            self.model.unit.message = check_failed.msg
-            return
 
-        self._configure_mesh(interfaces)
-        config = self.model.config
-        charm_name = self.model.app.name
+            mysql = self._configure_mysql()
+            obj_storage = _get_obj_storage(interfaces)
+            secrets = self._define_secrets(obj_storage=obj_storage, mysql=mysql)
 
-        mysql = self.model.relations["db"]
-        if len(mysql) > 1:
-            self.model.unit.status = BlockedStatus("Too many mysql relations")
-            return
-
-        try:
-            mysql = mysql[0]
-            unit = list(mysql.units)[0]
-            mysql = mysql.data[unit]
-            mysql["database"]
-        except (IndexError, KeyError):
-            self.model.unit.status = WaitingStatus("Waiting for mysql relation data")
-            return
+            default_artifact_root = self._validate_default_s3_bucket(obj_storage)
 
-        if not ((obj_storage := interfaces["object-storage"]) and obj_storage.get_data()):
-            self.model.unit.status = WaitingStatus("Waiting for object-storage relation data")
+            self._configure_mesh(interfaces)
+        except CheckFailedError as check_failed:
+            self.model.unit.status = check_failed.status
+            self.model.unit.message = check_failed.msg
             return
 
         self.model.unit.status = MaintenanceStatus("Setting pod spec")
-        obj_storage = list(obj_storage.get_data().values())[0]
-        secrets = [
-            {
-                "name": f"{charm_name}-minio-secret",
-                "data": _minio_credentials_dict(obj_storage=obj_storage),
-            },
-            {
-                "name": f"{charm_name}-seldon-init-container-s3-credentials",
-                "data": _seldon_credentials_dict(obj_storage=obj_storage),
-            },
-            {"name": f"{charm_name}-db-secret", "data": _db_secret_dict(mysql=mysql)},
-        ]
-
+        config = self.model.config
         self.model.pod.set_spec(
             {
                 "version": 3,
@@ -157,8 +132,8 @@ def main(self, event):
                             f"s3://{default_artifact_root}/",
                         ],
                         "envConfig": {
-                            "db-secret": {"secret": {"name": f"{charm_name}-db-secret"}},
-                            "aws-secret": {"secret": {"name": f"{charm_name}-minio-secret"}},
+                            "db-secret": {"secret": {"name": f"{self.charm_name}-db-secret"}},
+                            "aws-secret": {"secret": {"name": f"{self.charm_name}-minio-secret"}},
                             "AWS_DEFAULT_REGION": "us-east-1",
                             "MLFLOW_S3_ENDPOINT_URL": "http://{service}.{namespace}:{port}".format(
                                 **obj_storage
@@ -236,6 +211,22 @@ def _configure_mesh(self, interfaces):
             }
         )
 
+    def _configure_mysql(
+        self,
+    ):
+        mysql = self.model.relations["db"]
+        if len(mysql) > 1:
+            raise CheckFailedError("Too many mysql relations", BlockedStatus)
+
+        try:
+            mysql = mysql[0]
+            unit = list(mysql.units)[0]
+            mysql = mysql.data[unit]
+            mysql["database"]  # Raises KeyError if the relation data is incomplete
+            return mysql
+        except (IndexError, KeyError):
+            raise CheckFailedError("Waiting for mysql relation data", WaitingStatus)
+
     def _check_leader(self):
         if not self.unit.is_leader():
             # We can't do anything useful when not the leader, so do nothing.
@@ -257,21 +248,59 @@ def _check_image_details(self):
             raise CheckFailedError(f"{e.status.message}", e.status_type)
         return image_details
 
-
-def validate_s3_bucket_name(name):
-    """Validates the name as a valid S3 bucket name, raising a CheckFailedError if invalid."""
-    # regex from https://stackoverflow.com/a/50484916/5394584
-    if re.match(
-        r"(?=^.{3,63}$)(?!^(\d+\.)+\d+$)(^(([a-z0-9]|[a-z0-9][a-z0-9\-]*[a-z0-9])\.)*([a-z0-9]|[a-z0-9][a-z0-9\-]*[a-z0-9])$)",
-        name,
-    ):
-        return name
-    else:
-        msg = (
-            f"Invalid value for config default_artifact_root '{name}'"
-            f" - value must be a valid S3 bucket name"
-        )
-        raise CheckFailedError(msg, BlockedStatus)
+    def _validate_default_s3_bucket(self, obj_storage):
+        """Validates the default S3 store, ensuring bucket is accessible and creating if needed."""
+        # Validate the bucket name
+        bucket_name = self.config["default_artifact_root"]
+        if not validate_s3_bucket_name(bucket_name):
+            msg = (
+                f"Invalid value for config default_artifact_root '{bucket_name}'"
+                f" - value must be a valid S3 bucket name"
+            )
+            raise CheckFailedError(msg, BlockedStatus)
+
+        # Ensure the bucket exists, creating it if missing and
+        # create_default_artifact_root_if_missing==True
+        s3_wrapper = S3BucketWrapper(
+            access_key=obj_storage["access-key"],
+            secret_access_key=obj_storage["secret-key"],
+            s3_service=obj_storage["service"],
+            s3_port=obj_storage["port"],
+        )
+
+        if s3_wrapper.check_if_bucket_accessible(bucket_name):
+            return bucket_name
+        else:
+            if self.config["create_default_artifact_root_if_missing"]:
+                try:
+                    s3_wrapper.create_bucket(bucket_name)
+                    return bucket_name
+                except Exception as e:
+                    raise CheckFailedError(
+                        "Error with default S3 artifact store - bucket not accessible or "
+                        f"cannot be created. Caught error: '{str(e)}'",
+                        BlockedStatus,
+                    )
+            else:
+                raise CheckFailedError(
+                    "Error with default S3 artifact store - bucket not accessible or does not"
+                    " exist. Set create_default_artifact_root_if_missing=True to automatically"
+                    " create a missing default bucket",
+                    BlockedStatus,
+                )
+
+    def _define_secrets(self, obj_storage, mysql):
+        """Returns needed secrets in pod_spec.kubernetesResources.secrets format."""
+        return [
+            {
+                "name": f"{self.charm_name}-minio-secret",
+                "data": _minio_credentials_dict(obj_storage=obj_storage),
+            },
+            {
+                "name": f"{self.charm_name}-seldon-init-container-s3-credentials",
+                "data": _seldon_credentials_dict(obj_storage=obj_storage),
+            },
+            {"name": f"{self.charm_name}-db-secret", "data": _db_secret_dict(mysql=mysql)},
+        ]
 
 
 class CheckFailedError(Exception):
@@ -294,7 +323,7 @@ def _b64_encode_dict(d):
 def _minio_credentials_dict(obj_storage):
     """Returns a dict of minio credentials with the values base64 encoded."""
     minio_credentials = {
-        "AWS_ENDPOINT_URL": f"http://{obj_storage['service']}.{obj_storage['namespace']}:{obj_storage['port']}",
+        "AWS_ENDPOINT_URL": f"http://{obj_storage['service']}:{obj_storage['port']}",
         "AWS_ACCESS_KEY_ID": obj_storage["access-key"],
         "AWS_SECRET_ACCESS_KEY": obj_storage["secret-key"],
         "USE_SSL": str(obj_storage["secure"]).lower(),
@@ -325,5 +354,25 @@ def _db_secret_dict(mysql):
     return _b64_encode_dict(db_secret)
 
 
+def _get_obj_storage(interfaces):
+    """Unpacks and returns the object-storage relation data.
+
+    Raises CheckFailedError if an anticipated error occurs.
+    """
+    if not ((obj_storage := interfaces["object-storage"]) and obj_storage.get_data()):
+        raise CheckFailedError("Waiting for object-storage relation data", WaitingStatus)
+
+    try:
+        obj_storage = list(obj_storage.get_data().values())[0]
+    except Exception as e:
+        raise CheckFailedError(
+            f"Unexpected error unpacking object storage data - data format not "
+            f"as expected. Caught exception: '{str(e)}'",
+            BlockedStatus,
+        )
+
+    return obj_storage
+
+
 if __name__ == "__main__":
     main(Operator)
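Reviewer's note: the body of CheckFailedError is unchanged by this patch and is not shown in the hunks above. As a reading aid, here is a minimal, self-contained sketch of the raise-and-catch pattern the refactor leans on, reconstructed from how msg, status_type, and status are used in this diff; it is an illustration under those assumptions, not the repository's actual class definition:

    from ops.model import ActiveStatus, BlockedStatus, WaitingStatus


    class CheckFailedError(Exception):
        """Raised when a pre-flight check fails; carries the status the unit should take."""

        def __init__(self, msg: str, status_type: type):
            super().__init__(msg)
            self.msg = msg
            self.status_type = status_type  # the status class, e.g. BlockedStatus
            self.status = status_type(msg)  # an instance, e.g. BlockedStatus("...")


    def preflight_status(db_relations: list):
        """Mirrors main()'s flow: helpers raise, one except clause maps to a unit status."""
        try:
            if len(db_relations) > 1:
                raise CheckFailedError("Too many mysql relations", BlockedStatus)
            if not db_relations:
                raise CheckFailedError("Waiting for mysql relation data", WaitingStatus)
        except CheckFailedError as check_failed:
            return check_failed.status  # main() assigns this to self.model.unit.status
        return ActiveStatus()


    assert isinstance(preflight_status([]), WaitingStatus)
    assert isinstance(preflight_status(["db1", "db2"]), BlockedStatus)

The payoff of the pattern is visible in main() above: each helper can bail out with the precise status it wants, and a single except clause translates that into unit state.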
diff --git a/charms/mlflow-server/src/services/s3.py b/charms/mlflow-server/src/services/s3.py
new file mode 100644
index 00000000..9b835378
--- /dev/null
+++ b/charms/mlflow-server/src/services/s3.py
@@ -0,0 +1,83 @@
+"""Wrapper for basic accessing and validating of S3 Buckets."""
+
+import re
+from typing import Optional, Union
+
+import boto3
+import botocore.client
+import botocore.exceptions
+
+
+class S3BucketWrapper:
+    """Wrapper for basic accessing and validating of S3 Buckets."""
+
+    def __init__(
+        self, access_key: str, secret_access_key: str, s3_service: str, s3_port: Union[str, int]
+    ):
+        self.access_key: str = access_key
+        self.secret_access_key: str = secret_access_key
+        self.s3_service: str = s3_service
+        self.s3_port: str = str(s3_port)
+
+        self._client: Optional[botocore.client.BaseClient] = None
+
+    def check_if_bucket_accessible(self, bucket_name):
+        """Checks if a bucket exists and is accessible, returning True if both are satisfied.
+
+        Will return False if we encounter a botocore.exceptions.ClientError, which could be
+        due to the bucket not existing, the client session not having permission to access the
+        bucket, or some other error with the client.
+        """
+        try:
+            self.client.head_bucket(Bucket=bucket_name)
+            return True
+        except botocore.exceptions.ClientError:
+            return False
+
+    def create_bucket_if_missing(self, bucket_name):
+        """Creates the bucket bucket_name if it does not exist, raising an error if it cannot.
+
+        This method tries to access the bucket, assuming that if it is inaccessible it does
+        not exist (a required assumption, as inaccessible buckets look the same as buckets
+        that do not exist). If inaccessible, we try create_bucket and do not catch any
+        exceptions that result from the call.
+        """
+        if self.check_if_bucket_accessible(bucket_name=bucket_name):
+            return
+
+        self.create_bucket(bucket_name=bucket_name)
+
+    def create_bucket(self, bucket_name):
+        """Create a bucket via the client."""
+        self.client.create_bucket(Bucket=bucket_name)
+
+    @property
+    def client(self) -> botocore.client.BaseClient:
+        """Returns an open boto3 client, creating and caching one if needed."""
+        if self._client:
+            return self._client
+        else:
+            self._client = boto3.client(
+                "s3",
+                endpoint_url=self.s3_url,
+                aws_access_key_id=self.access_key,
+                aws_secret_access_key=self.secret_access_key,
+            )
+            return self._client
+
+    @property
+    def s3_url(self):
+        """Returns the S3 url."""
+        return f"http://{self.s3_service}:{self.s3_port}"
+
+
+def validate_s3_bucket_name(name):
+    """Returns True if name is a valid S3 bucket name, else False."""
+    # regex from https://stackoverflow.com/a/50484916/5394584
+    if re.match(
+        r"(?=^.{3,63}$)(?!^(\d+\.)+\d+$)(^(([a-z0-9]|[a-z0-9][a-z0-9\-]*[a-z0-9])\.)*([a-z0-9]|[a-z0-9][a-z0-9\-]*[a-z0-9])$)",
+        name,
+    ):
+        return True
+    else:
+        return False
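For reviewers, a minimal sketch of driving the new wrapper directly. The credentials mirror the MinIO test config introduced later in this patch; the service address is a placeholder assumption, not a value from the repository:

    from services.s3 import S3BucketWrapper, validate_s3_bucket_name

    bucket = "mlflow"
    assert validate_s3_bucket_name(bucket)  # passes the bucket-name regex above

    wrapper = S3BucketWrapper(
        access_key="minio",            # credentials mirror OBJ_STORAGE_CONFIG below
        secret_access_key="minio123",
        s3_service="minio",            # placeholder: a resolvable object-store address
        s3_port=9000,                  # Union[str, int]; coerced to str internally
    )

    # check_if_bucket_accessible() returns False for both "missing" and "no permission",
    # so creation can still fail for an existing-but-inaccessible bucket; the charm
    # surfaces that failure as a CheckFailedError/BlockedStatus.
    if not wrapper.check_if_bucket_accessible(bucket):
        wrapper.create_bucket(bucket)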
diff --git a/charms/mlflow-server/test-requirements.txt b/charms/mlflow-server/test-requirements.txt
index be9f272b..cd0fd5c8 100644
--- a/charms/mlflow-server/test-requirements.txt
+++ b/charms/mlflow-server/test-requirements.txt
@@ -1,3 +1,5 @@
 black
 flake8
 pytest
+pytest-mock
+pytest-lazy-fixture
diff --git a/charms/mlflow-server/tests/integration/test_charm.py b/charms/mlflow-server/tests/integration/test_charm.py
index 1350b431..c93e174b 100644
--- a/charms/mlflow-server/tests/integration/test_charm.py
+++ b/charms/mlflow-server/tests/integration/test_charm.py
@@ -3,6 +3,8 @@
 
 import logging
 from pathlib import Path
+from random import choices
+from string import ascii_lowercase
 from time import sleep
 
 import pytest
@@ -23,20 +25,25 @@
 METADATA = yaml.safe_load(Path("./metadata.yaml").read_text())
 CHARM_NAME = METADATA["name"]
 
+OBJ_STORAGE_NAME = "minio"
+OBJ_STORAGE_CONFIG = {
+    "access-key": "minio",
+    "secret-key": "minio123",
+    "port": "9000",
+}
+
 
 @pytest.mark.abort_on_fail
 async def test_build_and_deploy(ops_test: OpsTest):
     db = "mlflow-db"
-    obj_storage = "minio"
     await ops_test.model.deploy("charmed-osm-mariadb-k8s", application_name=db)
-    await ops_test.model.deploy(obj_storage)
+    await ops_test.model.deploy(OBJ_STORAGE_NAME, config=OBJ_STORAGE_CONFIG)
 
     my_charm = await ops_test.build_charm(".")
     image_path = METADATA["resources"]["oci-image"]["upstream-source"]
     resources = {"oci-image": image_path}
    await ops_test.model.deploy(my_charm, resources=resources)
-    await ops_test.model.add_relation(CHARM_NAME, obj_storage)
+    await ops_test.model.add_relation(CHARM_NAME, OBJ_STORAGE_NAME)
     await ops_test.model.add_relation(CHARM_NAME, db)
 
     await ops_test.model.wait_for_idle(status="active")
@@ -46,6 +53,81 @@ async def test_successful_deploy(ops_test: OpsTest):
     assert ops_test.model.applications[CHARM_NAME].units[0].workload_status == "active"
 
 
+async def test_default_bucket_created(ops_test: OpsTest):
+    """Tests whether the default bucket is automatically created by the charm.
+
+    Note: We do not have test coverage asserting that the bucket is not created when
+    create_default_artifact_root_if_missing==False.
+    """
+    config = await ops_test.model.applications[CHARM_NAME].get_config()
+    default_bucket_name = config["default_artifact_root"]["value"]
+
+    ret_code, stdout, stderr, kubectl_cmd = await does_minio_bucket_exist(
+        default_bucket_name, ops_test
+    )
+    assert ret_code == 0, (
+        f"Unable to find bucket named {default_bucket_name}, got "
+        f"stdout=\n'{stdout}'\nstderr=\n'{stderr}'\nUsed command: {kubectl_cmd}"
+    )
+
+
+async def does_minio_bucket_exist(bucket_name, ops_test: OpsTest):
+    """Connects to the minio server and checks whether the named bucket exists.
+
+    Returns:
+        Tuple of the return code, stdout, stderr, and the kubectl command used
+    """
+    access_key = OBJ_STORAGE_CONFIG["access-key"]
+    secret_key = OBJ_STORAGE_CONFIG["secret-key"]
+    port = OBJ_STORAGE_CONFIG["port"]
+    obj_storage_name = OBJ_STORAGE_NAME
+    model_name = ops_test.model_name
+    log.info(f"ops_test.model_name = {ops_test.model_name}")
+
+    obj_storage_url = f"http://{obj_storage_name}.{model_name}.svc.cluster.local:{port}"
+
+    # Region is not used and doesn't matter, but must be set to run in github actions as explained
+    # in: https://florian.ec/blog/github-actions-awscli-errors/
+    aws_cmd = (
+        f"aws --endpoint-url {obj_storage_url} --region us-east-1 s3api head-bucket"
+        f" --bucket={bucket_name}"
+    )
+
+    # Add a random suffix to the pod name to avoid collisions
+    this_pod_name = f"{CHARM_NAME}-minio-bucket-test-{generate_random_string()}"
+
+    kubectl_cmd = (
+        "microk8s",
+        "kubectl",
+        "run",
+        "--rm",
+        "-i",
+        "--restart=Never",
+        f"--namespace={ops_test.model_name}",
+        this_pod_name,
+        f"--env=AWS_ACCESS_KEY_ID={access_key}",
+        f"--env=AWS_SECRET_ACCESS_KEY={secret_key}",
+        "--image=amazon/aws-cli",
+        "--command",
+        "--",
+        "sh",
+        "-c",
+        aws_cmd,
+    )
+
+    (
+        ret_code,
+        stdout,
+        stderr,
+    ) = await ops_test.run(*kubectl_cmd)
+    return ret_code, stdout, stderr, " ".join(kubectl_cmd)
+
+
+def generate_random_string(length: int = 4):
+    """Returns a random string of lowercase alphabetic characters of the given length."""
+    return "".join(choices(ascii_lowercase, k=length))
+
+
 @pytest.mark.abort_on_fail
 async def test_deploy_with_ingress(ops_test: OpsTest):
     istio_pilot = "istio-pilot"
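The docstring of test_default_bucket_created above calls out a coverage gap. A hedged sketch of that missing negative test, reusing this file's helpers (untested; assumes the suite can toggle charm config and re-settle the model, and the bucket name is hypothetical):

    async def test_default_bucket_not_created_when_disabled(ops_test: OpsTest):
        app = ops_test.model.applications[CHARM_NAME]
        await app.set_config(
            {
                "create_default_artifact_root_if_missing": "false",
                "default_artifact_root": "bucket-that-should-not-exist",  # hypothetical
            }
        )
        # With creation disabled and the bucket missing, the charm should block...
        await ops_test.model.wait_for_idle(raise_on_blocked=False)
        assert app.units[0].workload_status == "blocked"

        # ...and the bucket should still be absent in MinIO.
        ret_code, *_ = await does_minio_bucket_exist("bucket-that-should-not-exist", ops_test)
        assert ret_code != 0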
diff --git a/charms/mlflow-server/tests/unit/test_operator.py b/charms/mlflow-server/tests/unit/test_operator.py
index 7ba9ccf5..be77c542 100644
--- a/charms/mlflow-server/tests/unit/test_operator.py
+++ b/charms/mlflow-server/tests/unit/test_operator.py
@@ -3,14 +3,14 @@
 
 import json
 from base64 import b64decode
-from contextlib import nullcontext as does_not_raise
+from unittest.mock import MagicMock
 
 import pytest
 import yaml
 from ops.model import ActiveStatus, BlockedStatus, WaitingStatus
 from ops.testing import Harness
 
-from charm import CheckFailedError, Operator, validate_s3_bucket_name
+from charm import CheckFailedError, Operator
 
 
 @pytest.fixture
@@ -47,29 +47,179 @@ def test_main_no_relation(harness):
     assert harness.charm.model.unit.status == WaitingStatus("Waiting for mysql relation data")
 
 
-@pytest.mark.parametrize(
-    "name,context_raised",
-    [
-        # Note, this is a non-exhaustive list
-        ("some-valid-name", does_not_raise()),
-        ("0123456789", does_not_raise()),
-        ("01", pytest.raises(CheckFailedError)),  # name too short
-        ("x" * 64, pytest.raises(CheckFailedError)),  # name too long
-        ("some_invalid_name", pytest.raises(CheckFailedError)),  # name has '_'
-        ("some;invalid;name" * 64, pytest.raises(CheckFailedError)),  # name has special characters
-        ("Some-Invalid-Name", pytest.raises(CheckFailedError)),  # name has capitals
-    ],
-)
-def test_validate_s3_bucket_name(name, context_raised):
-    with context_raised as err:
-        assert name == validate_s3_bucket_name(name)
-        if isinstance(err, Exception):
-            error_message = "Invalid value for config default_artifact_root"
-            assert error_message in str(err)
-            assert err.status_type == BlockedStatus
-
-
-def test_install_with_all_inputs(harness):
+def test_validate_default_s3_bucket__bucket_name_invalid(harness, mocker):
+    mocked_validate_s3_bucket_name = mocker.patch("charm.validate_s3_bucket_name")
+    mocked_validate_s3_bucket_name.return_value = False
+    obj_storage = {}
+    harness.begin()
+    with pytest.raises(CheckFailedError) as raised:
+        harness.charm._validate_default_s3_bucket(obj_storage=obj_storage)
+
+    assert raised.value.status_type == BlockedStatus
+
+
+@pytest.fixture()
+def mocked_S3BucketWrapper(mocker):  # noqa: N802
+    mocked_s3_bucket_wrapper_class = mocker.patch("charm.S3BucketWrapper")
+    mocked_s3bucketwrapper_instance = MagicMock()
+    mocked_s3_bucket_wrapper_class.return_value = mocked_s3bucketwrapper_instance
+    return mocked_s3_bucket_wrapper_class, mocked_s3bucketwrapper_instance
+
+
+@pytest.fixture()
+def bucket_name_valid():
+    return "some-valid-bucket-name"
+
+
+@pytest.fixture()
+def sample_object_storage():
+    return {
+        "access-key": "access-key-value",
+        "secret-key": "secret-key-value",
+        "service": "service-value",
+        "port": "port-value",
+    }
+
+
+def test_validate_default_s3_bucket__bucket_is_accessible(
+    harness, mocked_S3BucketWrapper, bucket_name_valid, sample_object_storage  # noqa: N803
+):
+    bucket_name = bucket_name_valid
+    obj_storage = sample_object_storage
+
+    # Mocking and setup
+    mocked_s3bucketwrapper_class, mocked_s3bucketwrapper_instance = mocked_S3BucketWrapper
+    mocked_s3bucketwrapper_instance.check_if_bucket_accessible.return_value = True
+
+    harness.update_config(
+        {
+            "default_artifact_root": bucket_name,
+        }
+    )
+    harness.begin()
+
+    # Run the code
+    returned_bucket_name = harness.charm._validate_default_s3_bucket(obj_storage=obj_storage)
+    mocked_s3bucketwrapper_instance.check_if_bucket_accessible.assert_called_with(bucket_name)
+
+    # Check that everything worked as expected
+    assert returned_bucket_name == bucket_name
+
+
+def test_validate_default_s3_bucket__missing__do_not_create_if_missing(
+    harness, mocked_S3BucketWrapper, bucket_name_valid, sample_object_storage  # noqa: N803
+):
+    bucket_name = bucket_name_valid
+    obj_storage = sample_object_storage
+
+    # Mocking and setup
+    mocked_s3bucketwrapper_class, mocked_s3bucketwrapper_instance = mocked_S3BucketWrapper
+    mocked_s3bucketwrapper_instance.check_if_bucket_accessible.return_value = False
+
+    harness.update_config(
+        {
+            "create_default_artifact_root_if_missing": False,
+            "default_artifact_root": bucket_name,
+        }
+    )
+    harness.begin()
+
+    # Run the code
+    with pytest.raises(CheckFailedError) as raised:
+        harness.charm._validate_default_s3_bucket(obj_storage=obj_storage)
+
+    # Check that everything worked as expected
+    mocked_s3bucketwrapper_class.assert_called_with(
+        access_key=obj_storage["access-key"],
+        secret_access_key=obj_storage["secret-key"],
+        s3_service=obj_storage["service"],
+        s3_port=obj_storage["port"],
+    )
+
+    mocked_s3bucketwrapper_instance.check_if_bucket_accessible.assert_called_with(bucket_name)
+
+    assert raised.value.status_type == BlockedStatus
+    assert (
+        "Set create_default_artifact_root_if_missing=True to automatically create"
+        in raised.value.msg
+    )
+
+
+def test_validate_default_s3_bucket__missing__fail_to_create_if_missing(
+    harness, mocked_S3BucketWrapper, bucket_name_valid, sample_object_storage  # noqa: N803
+):
+    bucket_name = bucket_name_valid
+    obj_storage = sample_object_storage
+
+    # Mocking and setup
+    mocked_s3bucketwrapper_class, mocked_s3bucketwrapper_instance = mocked_S3BucketWrapper
+    mocked_s3bucketwrapper_instance.check_if_bucket_accessible.return_value = False
+    mocked_s3bucketwrapper_instance.create_bucket.side_effect = Exception("something went wrong")
+
+    harness.update_config(
+        {
+            "create_default_artifact_root_if_missing": True,
+            "default_artifact_root": bucket_name,
+        }
+    )
+    harness.begin()
+
+    # Run the code
+    with pytest.raises(CheckFailedError) as raised:
+        harness.charm._validate_default_s3_bucket(obj_storage=obj_storage)
+
+    # Check that everything worked as expected
+    mocked_s3bucketwrapper_class.assert_called_with(
+        access_key=obj_storage["access-key"],
+        secret_access_key=obj_storage["secret-key"],
+        s3_service=obj_storage["service"],
+        s3_port=obj_storage["port"],
+    )
+
+    mocked_s3bucketwrapper_instance.check_if_bucket_accessible.assert_called_with(bucket_name)
+    mocked_s3bucketwrapper_instance.create_bucket.assert_called_with(bucket_name)
+
+    assert raised.value.status_type == BlockedStatus
+    assert "bucket not accessible or cannot be created" in raised.value.msg
+
+
+def test_validate_default_s3_bucket__missing__create_if_missing(
+    harness, mocked_S3BucketWrapper, bucket_name_valid, sample_object_storage  # noqa: N803
+):
+    bucket_name = bucket_name_valid
+    obj_storage = sample_object_storage
+
+    # Mocking and setup
+    mocked_s3bucketwrapper_class, mocked_s3bucketwrapper_instance = mocked_S3BucketWrapper
+    mocked_s3bucketwrapper_instance.check_if_bucket_accessible.return_value = False
+    mocked_s3bucketwrapper_instance.create_bucket.return_value = bucket_name
+
+    harness.update_config(
+        {
+            "create_default_artifact_root_if_missing": True,
+            "default_artifact_root": bucket_name,
+        }
+    )
+    harness.begin()
+
+    # Run the code
+    returned = harness.charm._validate_default_s3_bucket(obj_storage=obj_storage)
+
+    # Check that everything worked as expected
+    mocked_s3bucketwrapper_class.assert_called_with(
+        access_key=obj_storage["access-key"],
+        secret_access_key=obj_storage["secret-key"],
+        s3_service=obj_storage["service"],
+        s3_port=obj_storage["port"],
+    )
+
+    mocked_s3bucketwrapper_instance.check_if_bucket_accessible.assert_called_with(bucket_name)
+    mocked_s3bucketwrapper_instance.create_bucket.assert_called_with(bucket_name)
+
+    assert returned == bucket_name
+
+
+def test_install_with_all_inputs(harness, mocker):
     harness.set_leader(True)
     harness.add_oci_resource(
         "oci-image",
@@ -119,6 +269,11 @@ def test_install_with_all_inputs(harness):
         ingress_rel_id, f"{ingress_relation_name}-subscriber", relation_version_data
     )
 
+    # Mock away _validate_default_s3_bucket to avoid using boto3/creating clients
+    mocked_validate_default_s3_bucket = mocker.patch("charm.Operator._validate_default_s3_bucket")
+    bucket_name = harness._backend.config_get()["default_artifact_root"]
+    mocked_validate_default_s3_bucket.return_value = bucket_name
+
     # pod defaults relations setup
     pod_defaults_rel_name = "pod-defaults"
    pod_defaults_rel_id = harness.add_relation(
diff --git a/charms/mlflow-server/tests/unit/test_s3.py b/charms/mlflow-server/tests/unit/test_s3.py
new file mode 100644
index 00000000..2e9fa00a
--- /dev/null
+++ b/charms/mlflow-server/tests/unit/test_s3.py
@@ -0,0 +1,120 @@
+from contextlib import nullcontext as does_not_raise
+
+import botocore.exceptions
+import pytest
+from pytest_lazyfixture import lazy_fixture
+
+from services.s3 import S3BucketWrapper, validate_s3_bucket_name
+
+
+@pytest.mark.parametrize(
+    "name,returned",
+    [
+        # Note, this is a non-exhaustive list
+        ("some-valid-name", True),
+        ("0123456789", True),
+        ("01", False),  # name too short
+        ("x" * 64, False),  # name too long
+        ("some_invalid_name", False),  # name has '_'
+        ("some;invalid;name" * 64, False),  # name has special characters
+        ("Some-Invalid-Name", False),  # name has capitals
+    ],
+)
+def test_validate_s3_bucket_name(name, returned):
+    assert returned == validate_s3_bucket_name(name)
+
+
+# autouse to prevent any test from calling out to an external service
+@pytest.fixture(autouse=True)
+def mocked_boto3_client(mocker):
+    boto3_client_instance = mocker.MagicMock()
+    boto3_client_class = mocker.patch("boto3.client")
+    boto3_client_class.return_value = boto3_client_instance
+    yield boto3_client_instance
+
+
+@pytest.fixture(scope="function")
+def client_bucket_accessible(mocked_boto3_client):
+    mocked_boto3_client.head_bucket.return_value = True
+    yield mocked_boto3_client
+
+
+@pytest.fixture(scope="function")
+def client_accessible_emitting_ClientError(mocked_boto3_client):  # noqa: N802
+    mocked_boto3_client.head_bucket.side_effect = botocore.exceptions.ClientError({}, "test")
+    yield mocked_boto3_client
+
+
+@pytest.fixture(scope="function")
+def client_accessible_emitting_unknown_exception(mocked_boto3_client):
+    mocked_boto3_client.head_bucket.side_effect = Exception("some unexpected error")
+    yield mocked_boto3_client
+
+
+@pytest.fixture(scope="function")
+def s3_wrapper_empty():
+    wrapper = S3BucketWrapper(
+        access_key="",
+        secret_access_key="",
+        s3_service="",
+        s3_port="",
+    )
+    return wrapper
+
+
+@pytest.mark.parametrize(
+    "expected_returned,mocked_client,context_raised",
+    [
+        (True, lazy_fixture("client_bucket_accessible"), does_not_raise()),
+        (
+            False,
+            lazy_fixture("client_accessible_emitting_ClientError"),
+            does_not_raise(),
+        ),  # A handled error, returning False
+        (
+            None,
+            lazy_fixture("client_accessible_emitting_unknown_exception"),
+            pytest.raises(Exception),
+        ),
+    ],
+)
+def test_check_if_bucket_accessible(
+    expected_returned, mocked_client, context_raised, s3_wrapper_empty
+):
+    with context_raised:
+        s3_wrapper_empty._client = mocked_client
+
+        bucket_name = "some_bucket"
+        returned = s3_wrapper_empty.check_if_bucket_accessible(bucket_name)
+        assert returned == expected_returned
+
+        s3_wrapper_empty.client.head_bucket.assert_called_with(Bucket=bucket_name)
+
+
+@pytest.mark.parametrize(
+    "is_bucket_accessible",
+    [
+        True,
+        False,
+    ],
+)
+def test_create_bucket_if_missing(
+    is_bucket_accessible, mocked_boto3_client, mocker, s3_wrapper_empty
+):
+    mocked_check_if_bucket_accessible = mocker.patch(
+        "services.s3.S3BucketWrapper.check_if_bucket_accessible"
+    )
+    mocked_check_if_bucket_accessible.return_value = is_bucket_accessible
+
+    bucket_name = "some_bucket"
+    s3_wrapper_empty.create_bucket_if_missing(bucket_name)
+
+    mocked_check_if_bucket_accessible.assert_called_with(bucket_name=bucket_name)
+
+    if is_bucket_accessible:
+        # Bucket already existed, so we do not create
+        mocked_boto3_client.create_bucket.assert_not_called()
+    else:
+        # Bucket not accessible, so we try to create
+        mocked_boto3_client.create_bucket.assert_called_once_with(Bucket=bucket_name)
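A closing note on the new test dependencies: pytest-lazy-fixture is pulled in because @pytest.mark.parametrize cannot take fixtures directly; a bare fixture name would be passed through as a literal string. A minimal illustration of the mechanism, with hypothetical names:

    import pytest
    from pytest_lazyfixture import lazy_fixture


    @pytest.fixture
    def ready_client():
        return object()  # stand-in for a configured mock client


    @pytest.mark.parametrize(
        "client",
        [lazy_fixture("ready_client")],  # resolved to the fixture's value at test time
    )
    def test_parametrized_with_fixture(client):
        assert client is not None  # without lazy_fixture, client == "ready_client"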