From e80b6a07ea925a92e7950eef0975113f0a6aa90e Mon Sep 17 00:00:00 2001
From: Noctua
Date: Thu, 17 Oct 2024 20:01:54 +0000
Subject: [PATCH] chore: update charm libraries

---
 .../v0/kubernetes_compute_resources_patch.py  | 206 ++++++++++++++++--
 .../observability_libs/v1/cert_handler.py     |  70 +++++-
 .../tempo_coordinator_k8s/v0/charm_tracing.py |   5 +-
 .../tempo_coordinator_k8s/v0/tracing.py       |   9 +-
 .../v3/tls_certificates.py                    | 122 +++++++----
 5 files changed, 347 insertions(+), 65 deletions(-)

diff --git a/lib/charms/observability_libs/v0/kubernetes_compute_resources_patch.py b/lib/charms/observability_libs/v0/kubernetes_compute_resources_patch.py
index 2ab8a22c..34dd0264 100644
--- a/lib/charms/observability_libs/v0/kubernetes_compute_resources_patch.py
+++ b/lib/charms/observability_libs/v0/kubernetes_compute_resources_patch.py
@@ -4,7 +4,7 @@
 """# KubernetesComputeResourcesPatch Library.

 This library is designed to enable developers to more simply patch the Kubernetes compute resource
-limits and requests created by Juju during the deployment of a sidecar charm.
+limits and requests created by Juju during the deployment of a charm.

 When initialised, this library binds a handler to the parent charm's `config-changed` event.
 The config-changed event is used because it is guaranteed to fire on startup, on upgrade and on
@@ -76,6 +76,17 @@ def _resource_spec_from_config(self) -> ResourceRequirements:
         return ResourceRequirements(limits=spec, requests=spec)
 ```

+If you wish to pull the state of the resources patch operation and set the charm unit status
+based on that result, you can do so using the `get_status()` function.
+```python
+class SomeCharm(CharmBase):
+    def __init__(self, *args):
+        #...
+        self.framework.observe(self.on.collect_unit_status, self._on_collect_unit_status)
+        #...
+    def _on_collect_unit_status(self, event: CollectStatusEvent):
+        event.add_status(self.resources_patch.get_status())
+```

 Additionally, you may wish to use mocks in your charm's unit testing to ensure that the library
 does not try to make any API calls, or open any files during testing that are unlikely to be
@@ -83,12 +94,14 @@ def _resource_spec_from_config(self) -> ResourceRequirements:
 ```python
 # ...
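# The patches below stub out every call the library would make to the Kubernetes
# API (for example, `get_status` is stubbed to return ActiveStatus), so the unit
# tests can run without a real cluster: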
+from ops import ActiveStatus @patch.multiple( "charm.KubernetesComputeResourcesPatch", _namespace="test-namespace", _is_patched=lambda *a, **kw: True, is_ready=lambda *a, **kw: True, + get_status=lambda _: ActiveStatus(), ) @patch("lightkube.core.client.GenericSyncClient") def setUp(self, *unused): @@ -105,8 +118,9 @@ def setUp(self, *unused): import logging from decimal import Decimal from math import ceil, floor -from typing import Callable, Dict, List, Optional, Union +from typing import Any, Callable, Dict, List, Optional, Tuple, Union +import tenacity from lightkube import ApiError, Client # pyright: ignore from lightkube.core import exceptions from lightkube.models.apps_v1 import StatefulSetSpec @@ -120,8 +134,10 @@ def setUp(self, *unused): from lightkube.resources.core_v1 import Pod from lightkube.types import PatchType from lightkube.utils.quantity import equals_canonically, parse_quantity +from ops import ActiveStatus, BlockedStatus, WaitingStatus from ops.charm import CharmBase from ops.framework import BoundEvent, EventBase, EventSource, Object, ObjectEvents +from ops.model import StatusBase logger = logging.getLogger(__name__) @@ -133,14 +149,16 @@ def setUp(self, *unused): # Increment this PATCH version before using `charmcraft publish-lib` or reset # to 0 if you are raising the major API version -LIBPATCH = 7 +LIBPATCH = 8 _Decimal = Union[Decimal, float, str, int] # types that are potentially convertible to Decimal def adjust_resource_requirements( - limits: Optional[dict], requests: Optional[dict], adhere_to_requests: bool = True + limits: Optional[Dict[Any, Any]], + requests: Optional[Dict[Any, Any]], + adhere_to_requests: bool = True, ) -> ResourceRequirements: """Adjust resource limits so that `limits` and `requests` are consistent with each other. @@ -289,6 +307,18 @@ def sanitize_resource_spec_dict(spec: Optional[dict]) -> Optional[dict]: return d +def _retry_on_condition(exception): + """Retry if the exception is an ApiError with a status code != 403. + + Returns: a boolean value to indicate whether to retry or not. + """ + if isinstance(exception, ApiError) and str(exception.status.code) != "403": + return True + if isinstance(exception, exceptions.ConfigError) or isinstance(exception, ValueError): + return True + return False + + class K8sResourcePatchFailedEvent(EventBase): """Emitted when patching fails.""" @@ -385,27 +415,132 @@ def get_actual(self, pod_name: str) -> Optional[ResourceRequirements]: ) return podspec.resources + def is_failed( + self, resource_reqs_func: Callable[[], ResourceRequirements] + ) -> Tuple[bool, str]: + """Returns a tuple indicating whether a patch operation has failed along with a failure message. + + Implementation is based on dry running the patch operation to catch if there would be failures (e.g: Wrong spec and Auth errors). + """ + try: + resource_reqs = resource_reqs_func() + limits = resource_reqs.limits + requests = resource_reqs.requests + except ValueError as e: + msg = f"Failed obtaining resource limit spec: {e}" + logger.error(msg) + return True, msg + + # Dry run does not catch negative values for resource requests and limits. 
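+        # For example, a spec like {"cpu": "-1"} would pass the server-side dry run,
+        # hence the explicit client-side validation below.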
+        if not is_valid_spec(limits) or not is_valid_spec(requests):
+            msg = f"Invalid resource requirements specs: {limits}, {requests}"
+            logger.error(msg)
+            return True, msg
+
+        resource_reqs = ResourceRequirements(
+            limits=sanitize_resource_spec_dict(limits),  # type: ignore[arg-type]
+            requests=sanitize_resource_spec_dict(requests),  # type: ignore[arg-type]
+        )
+
+        try:
+            self.apply(resource_reqs, dry_run=True)
+        except ApiError as e:
+            if e.status.code == 403:
+                msg = f"Kubernetes resources patch failed: `juju trust` this application. {e}"
+            else:
+                msg = f"Kubernetes resources patch failed: {e}"
+            return True, msg
+        except ValueError as e:
+            msg = f"Kubernetes resources patch failed: {e}"
+            return True, msg
+
+        return False, ""
+
+    def is_in_progress(self) -> bool:
+        """Returns a boolean to indicate whether a patch operation is in progress.
+
+        Implementation follows a similar approach to `kubectl rollout status statefulset` to track the progress of a rollout.
+        Reference: https://github.com/kubernetes/kubectl/blob/kubernetes-1.31.0/pkg/polymorphichelpers/rollout_status.go
+        """
+        try:
+            sts = self.client.get(
+                StatefulSet, name=self.statefulset_name, namespace=self.namespace
+            )
+        except (ValueError, ApiError) as e:
+            # Assumption: if there was a persistent issue, it'd have been caught in `is_failed`.
+            # Wait until the next run to try again.
+            logger.error(f"Failed to fetch statefulset from K8s api: {e}")
+            return False
+
+        if sts.status is None or sts.spec is None:
+            logger.debug("status/spec are not yet available")
+            return False
+        if sts.status.observedGeneration == 0 or (
+            sts.metadata
+            and sts.status.observedGeneration
+            and sts.metadata.generation
+            and sts.metadata.generation > sts.status.observedGeneration
+        ):
+            logger.debug("waiting for statefulset spec update to be observed...")
+            return True
+        if (
+            sts.spec.replicas is not None
+            and sts.status.readyReplicas is not None
+            and sts.status.readyReplicas < sts.spec.replicas
+        ):
+            logger.debug(
+                f"Waiting for {sts.spec.replicas - sts.status.readyReplicas} pods to be ready..."
+            )
+            return True
+
+        if (
+            sts.spec.updateStrategy
+            and sts.spec.updateStrategy.type == "RollingUpdate"
+            and sts.spec.updateStrategy.rollingUpdate is not None
+        ):
+            if (
+                sts.spec.replicas is not None
+                and sts.spec.updateStrategy.rollingUpdate.partition is not None
+            ):
+                if sts.status.updatedReplicas and sts.status.updatedReplicas < (
+                    sts.spec.replicas - sts.spec.updateStrategy.rollingUpdate.partition
+                ):
+                    logger.debug(
+                        f"Waiting for partitioned roll out to finish: {sts.status.updatedReplicas} out of {sts.spec.replicas - sts.spec.updateStrategy.rollingUpdate.partition} new pods have been updated..."
+                    )
+                    return True
+            logger.debug(
+                f"partitioned roll out complete: {sts.status.updatedReplicas} new pods have been updated..."
+            )
+            return False
+
+        if sts.status.updateRevision != sts.status.currentRevision:
+            logger.debug(
+                f"waiting for statefulset rolling update to complete {sts.status.updatedReplicas} pods at revision {sts.status.updateRevision}..."
+            )
+            return True
+
+        logger.debug(
+            f"statefulset rolling update complete pods at revision {sts.status.currentRevision}"
+        )
+        return False
+
     def is_ready(self, pod_name, resource_reqs: ResourceRequirements):
         """Reports if the resource patch has been applied and is in effect.

         Returns:
             bool: A boolean indicating if the resource patch has been applied and is in effect.
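
        Example (illustrative)::

            if patcher.is_ready(pod_name, resource_reqs):
                logger.debug("resource patch applied and in effect")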
""" - logger.info( - "reqs=%s, templated=%s, actual=%s", - resource_reqs, - self.get_templated(), - self.get_actual(pod_name), - ) return self.is_patched(resource_reqs) and equals_canonically( # pyright: ignore resource_reqs, self.get_actual(pod_name) # pyright: ignore ) - def apply(self, resource_reqs: ResourceRequirements) -> None: + def apply(self, resource_reqs: ResourceRequirements, dry_run=False) -> None: """Patch the Kubernetes resources created by Juju to limit cpu or mem.""" # Need to ignore invalid input, otherwise the StatefulSet gives "FailedCreate" and the # charm would be stuck in unknown/lost. - if self.is_patched(resource_reqs): + if not dry_run and self.is_patched(resource_reqs): + logger.debug(f"Resource requests are already patched: {resource_reqs}") return self.client.patch( @@ -415,6 +550,7 @@ def apply(self, resource_reqs: ResourceRequirements) -> None: namespace=self.namespace, patch_type=PatchType.APPLY, field_manager=self.__class__.__name__, + dry_run=dry_run, ) @@ -422,6 +558,9 @@ class KubernetesComputeResourcesPatch(Object): """A utility for patching the Kubernetes compute resources set up by Juju.""" on = K8sResourcePatchEvents() # pyright: ignore + PATCH_RETRY_STOP = tenacity.stop_after_delay(20) + PATCH_RETRY_WAIT = tenacity.wait_fixed(5) + PATCH_RETRY_IF = tenacity.retry_if_exception(_retry_on_condition) def __init__( self, @@ -468,7 +607,11 @@ def _on_config_changed(self, _): self._patch() def _patch(self) -> None: - """Patch the Kubernetes resources created by Juju to limit cpu or mem.""" + """Patch the Kubernetes resources created by Juju to limit cpu or mem. + + This method will keep on retrying to patch the kubernetes resource for a default duration of 20 seconds + if the patching failure is due to a recoverable error (e.g: Network Latency). + """ try: resource_reqs = self.resource_reqs_func() limits = resource_reqs.limits @@ -492,7 +635,18 @@ def _patch(self) -> None: ) try: - self.patcher.apply(resource_reqs) + for attempt in tenacity.Retrying( + retry=self.PATCH_RETRY_IF, + stop=self.PATCH_RETRY_STOP, + wait=self.PATCH_RETRY_WAIT, + # if you don't succeed raise the last caught exception when you're done + reraise=True, + ): + with attempt: + logger.debug( + f"attempt #{attempt.retry_state.attempt_number} to patch resource limits" + ) + self.patcher.apply(resource_reqs) except exceptions.ConfigError as e: msg = f"Error creating k8s client: {e}" @@ -503,6 +657,7 @@ def _patch(self) -> None: except ApiError as e: if e.status.code == 403: msg = f"Kubernetes resources patch failed: `juju trust` this application. {e}" + else: msg = f"Kubernetes resources patch failed: {e}" @@ -554,6 +709,29 @@ def is_ready(self) -> bool: self.on.patch_failed.emit(message=msg) return False + def get_status(self) -> StatusBase: + """Return the status of patching the resource limits in a `StatusBase` format. + + Returns: + StatusBase: There is a 1:1 mapping between the state of the patching operation and a `StatusBase` value that the charm can be set to. + Possible values are: + - ActiveStatus: The patch was applied successfully. + - BlockedStatus: The patch failed and requires a human intervention. + - WaitingStatus: The patch is still in progress. 
+ + Example: + - ActiveStatus("Patch applied successfully") + - BlockedStatus("Failed due to missing permissions") + - WaitingStatus("Patch is in progress") + """ + failed, msg = self.patcher.is_failed(self.resource_reqs_func) + if failed: + return BlockedStatus(msg) + if self.patcher.is_in_progress(): + return WaitingStatus("waiting for resources patch to apply") + # patch successful or nothing has been patched yet + return ActiveStatus() + @property def _app(self) -> str: """Name of the current Juju application. diff --git a/lib/charms/observability_libs/v1/cert_handler.py b/lib/charms/observability_libs/v1/cert_handler.py index 3b87ad46..26be8793 100644 --- a/lib/charms/observability_libs/v1/cert_handler.py +++ b/lib/charms/observability_libs/v1/cert_handler.py @@ -26,12 +26,13 @@ self.framework.observe(self.cert_handler.on.cert_changed, self._on_server_cert_changed) container.push(keypath, self.cert_handler.private_key) -container.push(certpath, self.cert_handler.servert_cert) +container.push(certpath, self.cert_handler.server_cert) ``` Since this library uses [Juju Secrets](https://juju.is/docs/juju/secret) it requires Juju >= 3.0.3. """ import abc +import hashlib import ipaddress import json import socket @@ -59,7 +60,7 @@ import logging from ops.charm import CharmBase -from ops.framework import EventBase, EventSource, Object, ObjectEvents +from ops.framework import BoundEvent, EventBase, EventSource, Object, ObjectEvents, StoredState from ops.jujuversion import JujuVersion from ops.model import Relation, Secret, SecretNotFoundError @@ -67,7 +68,7 @@ LIBID = "b5cd5cd580f3428fa5f59a8876dcbe6a" LIBAPI = 1 -LIBPATCH = 11 +LIBPATCH = 14 VAULT_SECRET_LABEL = "cert-handler-private-vault" @@ -273,6 +274,7 @@ class CertHandler(Object): """A wrapper for the requirer side of the TLS Certificates charm library.""" on = CertHandlerEvents() # pyright: ignore + _stored = StoredState() def __init__( self, @@ -283,6 +285,7 @@ def __init__( peer_relation_name: str = "peers", cert_subject: Optional[str] = None, sans: Optional[List[str]] = None, + refresh_events: Optional[List[BoundEvent]] = None, ): """CertHandler is used to wrap TLS Certificates management operations for charms. @@ -299,8 +302,14 @@ def __init__( Must match metadata.yaml. cert_subject: Custom subject. Name collisions are under the caller's responsibility. sans: DNS names. If none are given, use FQDN. + refresh_events: [DEPRECATED]. """ super().__init__(charm, key) + # use StoredState to store the hash of the CSR + # to potentially trigger a CSR renewal + self._stored.set_default( + csr_hash=None, + ) self.charm = charm # We need to sanitize the unit name, otherwise route53 complains: @@ -309,8 +318,9 @@ def __init__( # Use fqdn only if no SANs were given, and drop empty/duplicate SANs sans = list(set(filter(None, (sans or [socket.getfqdn()])))) - self.sans_ip = list(filter(is_ip_address, sans)) - self.sans_dns = list(filterfalse(is_ip_address, sans)) + # sort SANS lists to avoid unnecessary csr renewals during reconciliation + self.sans_ip = sorted(filter(is_ip_address, sans)) + self.sans_dns = sorted(filterfalse(is_ip_address, sans)) if self._check_juju_supports_secrets(): vault_backend = _SecretVaultBackend(charm, secret_label=VAULT_SECRET_LABEL) @@ -355,6 +365,17 @@ def __init__( self._on_upgrade_charm, ) + if refresh_events: + logger.warn( + "DEPRECATION WARNING. `refresh_events` is now deprecated. CertHandler will automatically refresh the CSR when necessary." 
+ ) + + self._reconcile() + + def _reconcile(self): + """Run all logic that is independent of what event we're processing.""" + self._refresh_csr_if_needed() + def _on_upgrade_charm(self, _): has_privkey = self.vault.get_value("private-key") @@ -368,6 +389,11 @@ def _on_upgrade_charm(self, _): # this will call `self.private_key` which will generate a new privkey. self._generate_csr(renew=True) + def _refresh_csr_if_needed(self): + """Refresh the current CSR with a new one if there are any SANs changes.""" + if self._stored.csr_hash is not None and self._stored.csr_hash != self._csr_hash: + self._generate_csr(renew=True) + def _migrate_vault(self): peer_backend = _RelationVaultBackend(self.charm, relation_name="peers") @@ -419,6 +445,24 @@ def enabled(self) -> bool: return True + @property + def _csr_hash(self) -> str: + """A hash of the config that constructs the CSR. + + Only include here the config options that, should they change, should trigger a renewal of + the CSR. + """ + + def _stable_hash(data): + return hashlib.sha256(str(data).encode()).hexdigest() + + return _stable_hash( + ( + tuple(self.sans_dns), + tuple(self.sans_ip), + ) + ) + @property def available(self) -> bool: """Return True if all certs are available in relation data; False otherwise.""" @@ -484,6 +528,8 @@ def _generate_csr( ) self.certificates.request_certificate_creation(certificate_signing_request=csr) + self._stored.csr_hash = self._csr_hash + if clear_cert: self.vault.clear() @@ -548,9 +594,19 @@ def server_cert(self) -> Optional[str]: @property def chain(self) -> Optional[str]: - """Return the ca chain bundled as a single PEM string.""" + """Return the entire chain bundled as a single PEM string. This includes, if available, the certificate, intermediate CAs, and the root CA. + + If the server certificate is not set in the chain by the provider, we'll add it + to the top of the chain so that it could be used by a server. + """ cert = self.get_cert() - return cert.chain_as_pem() if cert else None + if not cert: + return None + chain = cert.chain_as_pem() + if cert.certificate not in chain: + # add server cert to chain + chain = cert.certificate + "\n\n" + chain + return chain def _on_certificate_expiring( self, event: Union[CertificateExpiringEvent, CertificateInvalidatedEvent] diff --git a/lib/charms/tempo_coordinator_k8s/v0/charm_tracing.py b/lib/charms/tempo_coordinator_k8s/v0/charm_tracing.py index 1e7ff840..3aea50f0 100644 --- a/lib/charms/tempo_coordinator_k8s/v0/charm_tracing.py +++ b/lib/charms/tempo_coordinator_k8s/v0/charm_tracing.py @@ -69,6 +69,9 @@ def my_tracing_endpoint(self) -> Optional[str]: - every event as a span (including custom events) - every charm method call (except dunders) as a span +We recommend that you scale up your tracing provider and relate it to an ingress so that your tracing requests +go through the ingress and get load balanced across all units. Otherwise, if the provider's leader goes down, your tracing goes down. 
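+
+For example, an endpoint property can simply return whatever URL the provider publishes
+over relation data; when the provider is behind an ingress, that will be the ingress URL
+(illustrative sketch, assuming a `self.tracing` TracingEndpointRequirer wrapper):
+
+```python
+@property
+def my_tracing_endpoint(self) -> Optional[str]:
+    # the provider publishes the (possibly ingressed) URL over relation data
+    if self.tracing.is_ready():
+        return self.tracing.get_endpoint("otlp_http")
+    return None
+```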
+ ## TLS support If your charm integrates with a TLS provider which is also trusted by the tracing provider (the Tempo charm), @@ -269,7 +272,7 @@ def _remove_stale_otel_sdk_packages(): # Increment this PATCH version before using `charmcraft publish-lib` or reset # to 0 if you are raising the major API version -LIBPATCH = 2 +LIBPATCH = 3 PYDEPS = ["opentelemetry-exporter-otlp-proto-http==1.21.0"] diff --git a/lib/charms/tempo_coordinator_k8s/v0/tracing.py b/lib/charms/tempo_coordinator_k8s/v0/tracing.py index 1f92867f..2035dffd 100644 --- a/lib/charms/tempo_coordinator_k8s/v0/tracing.py +++ b/lib/charms/tempo_coordinator_k8s/v0/tracing.py @@ -34,7 +34,7 @@ def __init__(self, *args): `TracingEndpointRequirer.request_protocols(*protocol:str, relation:Optional[Relation])` method. Using this method also allows you to use per-relation protocols. -Units of provider charms obtain the tempo endpoint to which they will push their traces by calling +Units of requirer charms obtain the tempo endpoint to which they will push their traces by calling `TracingEndpointRequirer.get_endpoint(protocol: str)`, where `protocol` is, for example: - `otlp_grpc` - `otlp_http` @@ -44,7 +44,10 @@ def __init__(self, *args): If the `protocol` is not in the list of protocols that the charm requested at endpoint set-up time, the library will raise an error. -## Requirer Library Usage +We recommend that you scale up your tracing provider and relate it to an ingress so that your tracing requests +go through the ingress and get load balanced across all units. Otherwise, if the provider's leader goes down, your tracing goes down. + +## Provider Library Usage The `TracingEndpointProvider` object may be used by charms to manage relations with their trace sources. For this purposes a Tempo-like charm needs to do two things @@ -107,7 +110,7 @@ def __init__(self, *args): # Increment this PATCH version before using `charmcraft publish-lib` or reset # to 0 if you are raising the major API version -LIBPATCH = 2 +LIBPATCH = 3 PYDEPS = ["pydantic"] diff --git a/lib/charms/tls_certificates_interface/v3/tls_certificates.py b/lib/charms/tls_certificates_interface/v3/tls_certificates.py index 33f34b62..6794c7af 100644 --- a/lib/charms/tls_certificates_interface/v3/tls_certificates.py +++ b/lib/charms/tls_certificates_interface/v3/tls_certificates.py @@ -277,13 +277,13 @@ def _on_all_certificates_invalidated(self, event: AllCertificatesInvalidatedEven """ # noqa: D405, D410, D411, D214, D416 import copy +import ipaddress import json import logging import uuid from contextlib import suppress from dataclasses import dataclass from datetime import datetime, timedelta, timezone -from ipaddress import IPv4Address from typing import List, Literal, Optional, Union from cryptography import x509 @@ -305,6 +305,7 @@ def _on_all_certificates_invalidated(self, event: AllCertificatesInvalidatedEven ModelError, Relation, RelationDataContent, + Secret, SecretNotFoundError, Unit, ) @@ -317,7 +318,7 @@ def _on_all_certificates_invalidated(self, event: AllCertificatesInvalidatedEven # Increment this PATCH version before using `charmcraft publish-lib` or reset # to 0 if you are raising the major API version -LIBPATCH = 15 +LIBPATCH = 21 PYDEPS = ["cryptography", "jsonschema"] @@ -735,16 +736,16 @@ def calculate_expiry_notification_time( """ if provider_recommended_notification_time is not None: provider_recommended_notification_time = abs(provider_recommended_notification_time) - provider_recommendation_time_delta = ( - expiry_time - 
timedelta(hours=provider_recommended_notification_time) + provider_recommendation_time_delta = expiry_time - timedelta( + hours=provider_recommended_notification_time ) if validity_start_time < provider_recommendation_time_delta: return provider_recommendation_time_delta if requirer_recommended_notification_time is not None: requirer_recommended_notification_time = abs(requirer_recommended_notification_time) - requirer_recommendation_time_delta = ( - expiry_time - timedelta(hours=requirer_recommended_notification_time) + requirer_recommendation_time_delta = expiry_time - timedelta( + hours=requirer_recommended_notification_time ) if validity_start_time < requirer_recommendation_time_delta: return requirer_recommendation_time_delta @@ -1077,7 +1078,7 @@ def generate_csr( # noqa: C901 if sans_oid: _sans.extend([x509.RegisteredID(x509.ObjectIdentifier(san)) for san in sans_oid]) if sans_ip: - _sans.extend([x509.IPAddress(IPv4Address(san)) for san in sans_ip]) + _sans.extend([x509.IPAddress(ipaddress.ip_address(san)) for san in sans_ip]) if sans: _sans.extend([x509.DNSName(san) for san in sans]) if sans_dns: @@ -1109,25 +1110,16 @@ def csr_matches_certificate(csr: str, cert: str) -> bool: Returns: bool: True/False depending on whether the CSR matches the certificate. """ - try: - csr_object = x509.load_pem_x509_csr(csr.encode("utf-8")) - cert_object = x509.load_pem_x509_certificate(cert.encode("utf-8")) - - if csr_object.public_key().public_bytes( - encoding=serialization.Encoding.PEM, - format=serialization.PublicFormat.SubjectPublicKeyInfo, - ) != cert_object.public_key().public_bytes( - encoding=serialization.Encoding.PEM, - format=serialization.PublicFormat.SubjectPublicKeyInfo, - ): - return False - if ( - csr_object.public_key().public_numbers().n # type: ignore[union-attr] - != cert_object.public_key().public_numbers().n # type: ignore[union-attr] - ): - return False - except ValueError: - logger.warning("Could not load certificate or CSR.") + csr_object = x509.load_pem_x509_csr(csr.encode("utf-8")) + cert_object = x509.load_pem_x509_certificate(cert.encode("utf-8")) + + if csr_object.public_key().public_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PublicFormat.SubjectPublicKeyInfo, + ) != cert_object.public_key().public_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PublicFormat.SubjectPublicKeyInfo, + ): return False return True @@ -1457,18 +1449,31 @@ def _revoke_certificates_for_which_no_csr_exists(self, relation_id: int) -> None Returns: None """ - provider_certificates = self.get_provider_certificates(relation_id) - requirer_csrs = self.get_requirer_csrs(relation_id) + provider_certificates = self.get_unsolicited_certificates(relation_id=relation_id) + for provider_certificate in provider_certificates: + self.on.certificate_revocation_request.emit( + certificate=provider_certificate.certificate, + certificate_signing_request=provider_certificate.csr, + ca=provider_certificate.ca, + chain=provider_certificate.chain, + ) + self.remove_certificate(certificate=provider_certificate.certificate) + + def get_unsolicited_certificates( + self, relation_id: Optional[int] = None + ) -> List[ProviderCertificate]: + """Return provider certificates for which no certificate requests exists. + + Those certificates should be revoked. 
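+
+        Example (illustrative)::
+
+            for cert in self.get_unsolicited_certificates(relation_id=relation.id):
+                logger.warning("unsolicited certificate for CSR: %s", cert.csr)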
+ """ + unsolicited_certificates: List[ProviderCertificate] = [] + provider_certificates = self.get_provider_certificates(relation_id=relation_id) + requirer_csrs = self.get_requirer_csrs(relation_id=relation_id) list_of_csrs = [csr.csr for csr in requirer_csrs] for certificate in provider_certificates: if certificate.csr not in list_of_csrs: - self.on.certificate_revocation_request.emit( - certificate=certificate.certificate, - certificate_signing_request=certificate.csr, - ca=certificate.ca, - chain=certificate.chain, - ) - self.remove_certificate(certificate=certificate.certificate) + unsolicited_certificates.append(certificate) + return unsolicited_certificates def get_outstanding_certificate_requests( self, relation_id: Optional[int] = None @@ -1886,8 +1891,7 @@ def _on_relation_changed(self, event: RelationChangedEvent) -> None: "Removing secret with label %s", f"{LIBID}-{csr_in_sha256_hex}", ) - secret = self.model.get_secret( - label=f"{LIBID}-{csr_in_sha256_hex}") + secret = self.model.get_secret(label=f"{LIBID}-{csr_in_sha256_hex}") secret.remove_all_revisions() self.on.certificate_invalidated.emit( reason="revoked", @@ -1898,10 +1902,20 @@ def _on_relation_changed(self, event: RelationChangedEvent) -> None: ) else: try: + secret = self.model.get_secret(label=f"{LIBID}-{csr_in_sha256_hex}") logger.debug( "Setting secret with label %s", f"{LIBID}-{csr_in_sha256_hex}" ) - secret = self.model.get_secret(label=f"{LIBID}-{csr_in_sha256_hex}") + # Juju < 3.6 will create a new revision even if the content is the same + if ( + secret.get_content(refresh=True).get("certificate", "") + == certificate.certificate + ): + logger.debug( + "Secret %s with correct certificate already exists", + f"{LIBID}-{csr_in_sha256_hex}", + ) + return secret.set_content( {"certificate": certificate.certificate, "csr": certificate.csr} ) @@ -1975,17 +1989,26 @@ def _on_secret_expired(self, event: SecretExpiredEvent) -> None: Args: event (SecretExpiredEvent): Juju event """ - if not event.secret.label or not event.secret.label.startswith(f"{LIBID}-"): + csr = self._get_csr_from_secret(event.secret) + if not csr: + logger.error("Failed to get CSR from secret %s", event.secret.label) return - csr = event.secret.get_content()["csr"] provider_certificate = self._find_certificate_in_relation_data(csr) if not provider_certificate: # A secret expired but we did not find matching certificate. Cleaning up + logger.warning( + "Failed to find matching certificate for csr, cleaning up secret %s", + event.secret.label, + ) event.secret.remove_all_revisions() return if not provider_certificate.expiry_time: # A secret expired but matching certificate is invalid. Cleaning up + logger.warning( + "Certificate matching csr is invalid, cleaning up secret %s", + event.secret.label, + ) event.secret.remove_all_revisions() return @@ -2017,3 +2040,22 @@ def _find_certificate_in_relation_data(self, csr: str) -> Optional[ProviderCerti continue return provider_certificate return None + + def _get_csr_from_secret(self, secret: Secret) -> str | None: + """Extract the CSR from the secret label or content. 
+
+        This function is a workaround to maintain backwards compatibility
+        and fix the issue reported in
+        https://github.com/canonical/tls-certificates-interface/issues/228
+        """
+        try:
+            content = secret.get_content(refresh=True)
+        except SecretNotFoundError:
+            return None
+        if not (csr := content.get("csr", None)):
+            # In versions <14 of the lib we were storing the CSR in the label of the secret.
+            # The CSR is now stored in the content of the secret, which was a breaking change.
+            # Here we get the CSR if the secret was created by an app using libpatch 14 or lower.
+            if secret.label and secret.label.startswith(f"{LIBID}-"):
+                csr = secret.label[len(f"{LIBID}-") :]
+        return csr
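+
+    # For example, a secret created by libpatch <= 14 is labelled f"{LIBID}-<csr>" and
+    # has no "csr" key in its content; the fallback above recovers the CSR by stripping
+    # the f"{LIBID}-" prefix from the label.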