Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions superset/extensions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import json
import logging
import os
from typing import Any, Callable, Optional

Expand Down Expand Up @@ -45,11 +46,17 @@ def get_sqla_class() -> Any:
from superset.extensions.stats_logger import BaseStatsLoggerManager
from superset.security.manager import SupersetSecurityManager
from superset.utils.cache_manager import CacheManager
from superset.utils.database import apply_mariadb_ddl_fix
from superset.utils.encrypt import EncryptedFieldFactory
from superset.utils.feature_flag_manager import FeatureFlagManager
from superset.utils.machine_auth import MachineAuthProviderFactory
from superset.utils.profiler import SupersetProfiler

# Apply MariaDB DDL fix early in the import chain
try:
apply_mariadb_ddl_fix()
except Exception as ex:
logging.exception("Applying MariaDB DDL fix failed; continuing without patch: %s", ex)

class ResultsBackendManager:
def __init__(self) -> None:
Expand Down
22 changes: 21 additions & 1 deletion superset/utils/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,10 @@
from __future__ import annotations

import logging
from typing import TYPE_CHECKING
from typing import Any, TYPE_CHECKING

from flask import current_app as app
from sqlalchemy.sql import compiler

from superset.constants import EXAMPLES_DB_UUID

Expand Down Expand Up @@ -84,3 +85,22 @@ def remove_database(database: Database) -> None:

db.session.delete(database)
db.session.flush()


def apply_mariadb_ddl_fix() -> None:
"""
Fix MariaDB "NO CYCLE" syntax issue - MariaDB uses "NOCYCLE" (no space).

This fix will be included in SQLAlchemy v2.1.0.
See: https://github.com/sqlalchemy/sqlalchemy/blob/rel_2_1_0b1/lib/sqlalchemy/dialects/mysql/_mariadb_shim.py
"""
original_visit_create_sequence = compiler.DDLCompiler.visit_create_sequence

def patched_visit_create_sequence(self: Any, create: Any, **kw: Any) -> str:
text = original_visit_create_sequence(self, create, **kw)
dialect_name = getattr(self.dialect, "name", "") or ""
if "mariadb" in dialect_name.lower():
return text.replace("NO CYCLE", "NOCYCLE")
return text

compiler.DDLCompiler.visit_create_sequence = patched_visit_create_sequence
53 changes: 53 additions & 0 deletions tests/unit_tests/utils/test_database.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Tests for superset.utils.database module."""

import pytest
from sqlalchemy import Sequence
from sqlalchemy.dialects import mysql, postgresql
from sqlalchemy.schema import CreateSequence
from sqlalchemy.sql.compiler import DDLCompiler

from superset.utils.database import apply_mariadb_ddl_fix


@pytest.fixture(scope="module", autouse=True)
def setup_mariadb_ddl_fix():
"""Apply MariaDB DDL fix once per module before tests run."""
apply_mariadb_ddl_fix()


def test_mariadb_nocycle_fix_applied():
"""Test that 'NO CYCLE' is replaced with 'NOCYCLE' for MariaDB dialect."""
dialect = mysql.dialect()
dialect.name = "mariadb"
ddl_compiler = DDLCompiler(dialect, None)
seq = Sequence("test_seq", cycle=False)

result = ddl_compiler.visit_create_sequence(CreateSequence(seq))
assert "NOCYCLE" in result
assert "NO CYCLE" not in result


def test_nocycle_fix_not_applied_for_postgresql():
"""Test that 'NO CYCLE' is NOT replaced for PostgreSQL dialect."""
dialect = postgresql.dialect()
compiler = DDLCompiler(dialect, None)
seq = Sequence("test_seq", cycle=False)

result = compiler.visit_create_sequence(CreateSequence(seq))
assert "NO CYCLE" in result