Review notes (text before the first "diff --git" line is not part of the patch):
- Line structure and indentation were restored; hunk headers were recomputed.
- session.refresh(...) is kept after every commit whose schema object is read
  after its session has been closed. A commit expires the loaded attributes and
  a detached instance cannot reload them, so reading `.id` / `.name` would fail.
- The "index" blob hashes are the ones from the originally submitted patch.

diff --git a/src/zenml/zen_stores/sql_zen_store.py b/src/zenml/zen_stores/sql_zen_store.py
index 963da9ea76a..daeae60a8a2 100644
--- a/src/zenml/zen_stores/sql_zen_store.py
+++ b/src/zenml/zen_stores/sql_zen_store.py
@@ -20,6 +20,7 @@
 import os
 import re
 import sys
+import time
 from datetime import datetime, timezone
 from functools import lru_cache
 from pathlib import Path
@@ -2687,47 +2688,51 @@ def delete_artifact(self, artifact_id: UUID) -> None:
     # -------------------- Artifact Versions --------------------
 
     def _get_or_create_artifact_for_name(
-        self, session: Session, name: str, has_custom_name: bool
+        self, name: str, has_custom_name: bool
     ) -> ArtifactSchema:
         """Get or create an artifact with a specific name.
 
         Args:
-            session: DB session.
             name: The artifact name.
             has_custom_name: Whether the artifact has a custom name.
 
         Returns:
             Schema of the artifact.
         """
-        artifact_query = select(ArtifactSchema).where(
-            ArtifactSchema.name == name
-        )
-        artifact = session.exec(artifact_query).first()
+        with Session(self.engine) as session:
+            artifact_query = select(ArtifactSchema).where(
+                ArtifactSchema.name == name
+            )
+            artifact = session.exec(artifact_query).first()
 
-        if artifact is None:
-            try:
-                with session.begin_nested():
-                    artifact_request = ArtifactRequest(
-                        name=name, has_custom_name=has_custom_name
-                    )
-                    artifact = ArtifactSchema.from_request(artifact_request)
-                    session.add(artifact)
-                    session.commit()
-                session.refresh(artifact)
-            except IntegrityError:
-                # We failed to create the artifact due to the unique constraint
-                # for artifact names -> The artifact was already created, we can
-                # just fetch it from the DB now
-                artifact = session.exec(artifact_query).one()
-
-        if artifact.has_custom_name is False and has_custom_name:
-            # If a new version with custom name was created for an artifact
-            # that previously had no custom name, we update it.
-            artifact.has_custom_name = True
-            session.commit()
-            session.refresh(artifact)
+            if artifact is None:
+                try:
+                    with session.begin_nested():
+                        artifact_request = ArtifactRequest(
+                            name=name, has_custom_name=has_custom_name
+                        )
+                        artifact = ArtifactSchema.from_request(
+                            artifact_request
+                        )
+                        session.add(artifact)
+                        session.commit()
+                    session.refresh(artifact)
+                except IntegrityError:
+                    # We failed to create the artifact due to the unique constraint
+                    # for artifact names -> The artifact was already created, we can
+                    # just fetch it from the DB now
+                    artifact = session.exec(artifact_query).one()
+
+            if artifact.has_custom_name is False and has_custom_name:
+                # If a new version with custom name was created for an artifact
+                # that previously had no custom name, we update it.
+                artifact.has_custom_name = True
+                session.commit()
+                # The commit expires all attributes; reload them so the
+                # schema remains usable once this session is closed.
+                session.refresh(artifact)
 
-        return artifact
+            return artifact
 
     def _get_next_numeric_version_for_artifact(
         self, session: Session, artifact_id: UUID
@@ -2767,70 +2772,90 @@ def create_artifact_version(
         Returns:
             The created artifact version.
         """
-        with Session(self.engine) as session:
-            if artifact_name := artifact_version.artifact_name:
-                artifact_schema = self._get_or_create_artifact_for_name(
-                    session=session,
-                    name=artifact_name,
-                    has_custom_name=artifact_version.has_custom_name,
-                )
-                artifact_version.artifact_id = artifact_schema.id
+        if artifact_name := artifact_version.artifact_name:
+            artifact_schema = self._get_or_create_artifact_for_name(
+                name=artifact_name,
+                has_custom_name=artifact_version.has_custom_name,
+            )
+            artifact_version.artifact_id = artifact_schema.id
 
-            assert artifact_version.artifact_id
+        assert artifact_version.artifact_id
 
-            if artifact_version.version is None:
-                # No explicit version in the request -> We will try to
-                # auto-increment the numeric version of the artifact version
-                remaining_tries = MAX_RETRIES_FOR_VERSIONED_ENTITY_CREATION
-                while remaining_tries > 0:
-                    remaining_tries -= 1
-                    try:
-                        with session.begin_nested():
-                            artifact_version.version = str(
-                                self._get_next_numeric_version_for_artifact(
-                                    session=session,
-                                    artifact_id=artifact_version.artifact_id,
-                                )
+        if artifact_version.version is None:
+            # No explicit version in the request -> We will try to
+            # auto-increment the numeric version of the artifact version
+            remaining_tries = MAX_RETRIES_FOR_VERSIONED_ENTITY_CREATION
+            while remaining_tries > 0:
+                remaining_tries -= 1
+                try:
+                    with Session(self.engine) as session:
+                        artifact_version.version = str(
+                            self._get_next_numeric_version_for_artifact(
+                                session=session,
+                                artifact_id=artifact_version.artifact_id,
                             )
+                        )
 
-                            artifact_version_schema = (
-                                ArtifactVersionSchema.from_request(
-                                    artifact_version
-                                )
-                            )
-                            session.add(artifact_version_schema)
-                            session.commit()
-                    except IntegrityError:
-                        if remaining_tries == 0:
-                            raise EntityExistsError(
-                                f"Failed to create version for artifact "
-                                f"{artifact_schema.name}. This is most likely "
-                                "caused by multiple parallel requests that try "
-                                "to create versions for this artifact in the "
-                                "database."
+                        artifact_version_schema = (
+                            ArtifactVersionSchema.from_request(
+                                artifact_version
                             )
+                        )
+                        session.add(artifact_version_schema)
+                        session.commit()
+                        # Reload the attributes expired by the commit; the
+                        # schema is used after this session is closed.
+                        session.refresh(artifact_version_schema)
+                except IntegrityError:
+                    if remaining_tries == 0:
+                        raise EntityExistsError(
+                            f"Failed to create version for artifact "
+                            f"{artifact_schema.name}. This is most likely "
+                            "caused by multiple parallel requests that try "
+                            "to create versions for this artifact in the "
+                            "database."
+                        )
                     else:
-                        break
-                session.refresh(artifact_version_schema)
-            else:
-                # An explicit version was specified for the artifact version.
-                # We don't do any incrementing and fail immediately if the
-                # version already exists.
+                        # Exponential backoff to account for heavy
+                        # parallelization
+                        sleep_duration = 0.05 * 1.5 ** (
+                            MAX_RETRIES_FOR_VERSIONED_ENTITY_CREATION
+                            - remaining_tries
+                        )
+                        logger.debug(
+                            "Failed to create artifact version %s "
+                            "(version %s) due to an integrity error. "
+                            "Retrying in %f seconds.",
+                            artifact_schema.name,
+                            artifact_version.version,
+                            sleep_duration,
+                        )
+                        time.sleep(sleep_duration)
+                else:
+                    break
+        else:
+            # An explicit version was specified for the artifact version.
+            # We don't do any incrementing and fail immediately if the
+            # version already exists.
+            with Session(self.engine) as session:
                 try:
                     artifact_version_schema = (
                         ArtifactVersionSchema.from_request(artifact_version)
                     )
                     session.add(artifact_version_schema)
                     session.commit()
+                    # Reload the attributes expired by the commit; the
+                    # schema is used after this session is closed.
                     session.refresh(artifact_version_schema)
                 except IntegrityError:
                     raise EntityExistsError(
                         f"Unable to create artifact version "
-                        f"{artifact_schema.name}({artifact_version.version}: "
-                        "An artifact with the same name and version already "
-                        "exists."
+                        f"{artifact_schema.name} (version "
+                        f"{artifact_version.version}): An artifact with the "
+                        "same name and version already exists."
                     )
 
+        with Session(self.engine) as session:
             # Save visualizations of the artifact
             if artifact_version.visualizations:
                 for vis in artifact_version.visualizations:
@@ -2863,7 +2888,11 @@ def create_artifact_version(
                     session.add(run_metadata_schema)
 
             session.commit()
-            session.refresh(artifact_version_schema)
+            artifact_version_schema = session.exec(
+                select(ArtifactVersionSchema).where(
+                    ArtifactVersionSchema.id == artifact_version_schema.id
+                )
+            ).one()
 
             return artifact_version_schema.to_model(
                 include_metadata=True, include_resources=True
diff --git a/tests/integration/functional/artifacts/test_utils.py b/tests/integration/functional/artifacts/test_utils.py
index 29015e338b5..43f159d1989 100644
--- a/tests/integration/functional/artifacts/test_utils.py
+++ b/tests/integration/functional/artifacts/test_utils.py
@@ -365,10 +365,6 @@ def test_parallel_artifact_creation(clean_client: Client):
         name="meaning_of_life", size=min(1000, process_count * 10)
     )
     assert len(avs) == process_count
-    print(
-        {str(i) for i in range(1, process_count + 1)}
-        - {av.version for av in avs}
-    )
     assert {av.version for av in avs} == {
         str(i) for i in range(1, process_count + 1)
     }