Skip to content

Commit

Permalink
Refactor artifact version creation retries
Browse files: browse the repository at this point in the history
  • Loading branch information
schustmi committed Oct 21, 2024
1 parent 7536422 commit f809f02
Show file tree
Hide file tree
Showing 2 changed files with 97 additions and 81 deletions.
174 changes: 97 additions & 77 deletions src/zenml/zen_stores/sql_zen_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import os
import re
import sys
import time
from datetime import datetime, timezone
from functools import lru_cache
from pathlib import Path
Expand Down Expand Up @@ -2687,47 +2688,48 @@ def delete_artifact(self, artifact_id: UUID) -> None:
# -------------------- Artifact Versions --------------------

def _get_or_create_artifact_for_name(
self, session: Session, name: str, has_custom_name: bool
self, name: str, has_custom_name: bool
) -> ArtifactSchema:
"""Get or create an artifact with a specific name.
Args:
session: DB session.
name: The artifact name.
has_custom_name: Whether the artifact has a custom name.
Returns:
Schema of the artifact.
"""
artifact_query = select(ArtifactSchema).where(
ArtifactSchema.name == name
)
artifact = session.exec(artifact_query).first()
with Session(self.engine) as session:
artifact_query = select(ArtifactSchema).where(
ArtifactSchema.name == name
)
artifact = session.exec(artifact_query).first()

if artifact is None:
try:
with session.begin_nested():
artifact_request = ArtifactRequest(
name=name, has_custom_name=has_custom_name
)
artifact = ArtifactSchema.from_request(artifact_request)
session.add(artifact)
session.commit()
session.refresh(artifact)
except IntegrityError:
# We failed to create the artifact due to the unique constraint
# for artifact names -> The artifact was already created, we can
# just fetch it from the DB now
artifact = session.exec(artifact_query).one()

if artifact.has_custom_name is False and has_custom_name:
# If a new version with custom name was created for an artifact
# that previously had no custom name, we update it.
artifact.has_custom_name = True
session.commit()
session.refresh(artifact)
if artifact is None:
try:
with session.begin_nested():
artifact_request = ArtifactRequest(
name=name, has_custom_name=has_custom_name
)
artifact = ArtifactSchema.from_request(
artifact_request
)
session.add(artifact)
session.commit()
session.refresh(artifact)
except IntegrityError:
# We failed to create the artifact due to the unique constraint
# for artifact names -> The artifact was already created, we can
# just fetch it from the DB now
artifact = session.exec(artifact_query).one()

if artifact.has_custom_name is False and has_custom_name:
# If a new version with custom name was created for an artifact
# that previously had no custom name, we update it.
artifact.has_custom_name = True
session.commit()

return artifact
return artifact

def _get_next_numeric_version_for_artifact(
self, session: Session, artifact_id: UUID
Expand Down Expand Up @@ -2767,70 +2769,84 @@ def create_artifact_version(
Returns:
The created artifact version.
"""
with Session(self.engine) as session:
if artifact_name := artifact_version.artifact_name:
artifact_schema = self._get_or_create_artifact_for_name(
session=session,
name=artifact_name,
has_custom_name=artifact_version.has_custom_name,
)
artifact_version.artifact_id = artifact_schema.id
if artifact_name := artifact_version.artifact_name:
artifact_schema = self._get_or_create_artifact_for_name(
name=artifact_name,
has_custom_name=artifact_version.has_custom_name,
)
artifact_version.artifact_id = artifact_schema.id

assert artifact_version.artifact_id
assert artifact_version.artifact_id

if artifact_version.version is None:
# No explicit version in the request -> We will try to
# auto-increment the numeric version of the artifact version
remaining_tries = MAX_RETRIES_FOR_VERSIONED_ENTITY_CREATION
while remaining_tries > 0:
remaining_tries -= 1
try:
with session.begin_nested():
artifact_version.version = str(
self._get_next_numeric_version_for_artifact(
session=session,
artifact_id=artifact_version.artifact_id,
)
if artifact_version.version is None:
# No explicit version in the request -> We will try to
# auto-increment the numeric version of the artifact version
remaining_tries = MAX_RETRIES_FOR_VERSIONED_ENTITY_CREATION
while remaining_tries > 0:
remaining_tries -= 1
try:
with Session(self.engine) as session:
artifact_version.version = str(
self._get_next_numeric_version_for_artifact(
session=session,
artifact_id=artifact_version.artifact_id,
)
)

artifact_version_schema = (
ArtifactVersionSchema.from_request(
artifact_version
)
)
session.add(artifact_version_schema)
session.commit()
except IntegrityError:
if remaining_tries == 0:
raise EntityExistsError(
f"Failed to create version for artifact "
f"{artifact_schema.name}. This is most likely "
"caused by multiple parallel requests that try "
"to create versions for this artifact in the "
"database."
artifact_version_schema = (
ArtifactVersionSchema.from_request(
artifact_version
)
)
session.add(artifact_version_schema)
session.commit()
except IntegrityError:
if remaining_tries == 0:
raise EntityExistsError(
f"Failed to create version for artifact "
f"{artifact_schema.name}. This is most likely "
"caused by multiple parallel requests that try "
"to create versions for this artifact in the "
"database."
)
else:
break
session.refresh(artifact_version_schema)
else:
# An explicit version was specified for the artifact version.
# We don't do any incrementing and fail immediately if the
# version already exists.
# Exponential backoff to account for heavy
# parallelization
sleep_duration = 0.05 * 1.5 ** (
MAX_RETRIES_FOR_VERSIONED_ENTITY_CREATION
- remaining_tries
)
logger.debug(
"Failed to create artifact version %s "
"(version %s) due to an integrity error. "
"Retrying in %f seconds.",
artifact_schema.name,
artifact_version.version,
sleep_duration,
)
time.sleep(sleep_duration)
else:
break
else:
# An explicit version was specified for the artifact version.
# We don't do any incrementing and fail immediately if the
# version already exists.
with Session(self.engine) as session:
try:
artifact_version_schema = (
ArtifactVersionSchema.from_request(artifact_version)
)
session.add(artifact_version_schema)
session.commit()
session.refresh(artifact_version_schema)
except IntegrityError:
raise EntityExistsError(
f"Unable to create artifact version "
f"{artifact_schema.name}({artifact_version.version}: "
"An artifact with the same name and version already "
"exists."
f"{artifact_schema.name} (version "
f"{artifact_version.version}): An artifact with the "
"same name and version already exists."
)

with Session(self.engine) as session:
# Save visualizations of the artifact
if artifact_version.visualizations:
for vis in artifact_version.visualizations:
Expand Down Expand Up @@ -2863,7 +2879,11 @@ def create_artifact_version(
session.add(run_metadata_schema)

session.commit()
session.refresh(artifact_version_schema)
artifact_version_schema = session.exec(
select(ArtifactVersionSchema).where(
ArtifactVersionSchema.id == artifact_version_schema.id
)
).one()

return artifact_version_schema.to_model(
include_metadata=True, include_resources=True
Expand Down
4 changes: 0 additions & 4 deletions tests/integration/functional/artifacts/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -365,10 +365,6 @@ def test_parallel_artifact_creation(clean_client: Client):
name="meaning_of_life", size=min(1000, process_count * 10)
)
assert len(avs) == process_count
print(
{str(i) for i in range(1, process_count + 1)}
- {av.version for av in avs}
)
assert {av.version for av in avs} == {
str(i) for i in range(1, process_count + 1)
}
Expand Down

0 comments on commit f809f02

Please sign in to comment.