127 changes: 126 additions & 1 deletion superset/db_engine_specs/mysql.py
@@ -15,15 +15,17 @@
# specific language governing permissions and limitations
# under the License.
import contextlib
import logging
import re
from datetime import datetime
from decimal import Decimal
from re import Pattern
from typing import Any, Callable, Optional
from urllib import parse

import pandas as pd
from flask_babel import gettext as __
from sqlalchemy import types
from sqlalchemy import Integer, types
from sqlalchemy.dialects.mysql import (
BIT,
DECIMAL,
@@ -45,9 +47,13 @@
DatabaseCategory,
)
from superset.errors import SupersetErrorType
from superset.models.core import Database
from superset.models.sql_lab import Query
from superset.sql.parse import Table
from superset.utils.core import GenericDataType

logger = logging.getLogger(__name__)

# Regular expressions to catch custom errors
CONNECTION_ACCESS_DENIED_REGEX = re.compile(
"Access denied for user '(?P<username>.*?)'@'(?P<hostname>.*?)'"
@@ -401,3 +407,122 @@ def cancel_query(cls, cursor: Any, query: Query, cancel_query_id: str) -> bool:
return False

return True

@classmethod
def _requires_primary_key(cls, engine: Any) -> bool:
"""
Check if the MySQL database requires a primary key for table creation.

:param engine: SQLAlchemy engine
:return: True if primary key is required
"""
try:
with engine.connect() as conn:
result = conn.execute(
"SELECT @@session.sql_require_primary_key"
).scalar()
return bool(result)
except Exception: # pylint: disable=broad-except
# If we can't determine the setting, assume it's not required
# to maintain backward compatibility
return False

@classmethod
def df_to_sql(
cls,
database: Database,
table: Table,
df: pd.DataFrame,
to_sql_kwargs: dict[str, Any],
) -> None:
"""
Upload data from a Pandas DataFrame to a MySQL database.

Automatically adds a primary key column when the database requires it
(e.g., sql_require_primary_key = ON) and the table is being created.

:param database: The database to upload the data to
:param table: The table to upload the data to
:param df: The dataframe with data to be uploaded
:param to_sql_kwargs: The kwargs to be passed to pandas.DataFrame.to_sql
"""
with cls.get_engine(
database,
catalog=table.catalog,
schema=table.schema,
) as engine:
# Check if we need to add a primary key
if_exists = to_sql_kwargs.get("if_exists", "fail")
needs_primary_key = (
if_exists == "fail"  # Only for new table creation
and not to_sql_kwargs.get("index", False)  # No index column exists
and cls._requires_primary_key(engine)  # Database enforces a primary key
)
Comment on lines +456 to +458

Unclosed parenthesis in conditional expression

Line 456 has an unclosed opening parenthesis ( in the needs_primary_key assignment. The condition that starts on line 456 is missing a closing parenthesis before the assignment completes on line 458.

Code suggestion
Check the AI-generated fix before applying
Suggested change
-            needs_primary_key = (
-                if_exists == "fail"  # Only for new table creation
-                and not to_sql_kwargs.get("index", False)  # No index column exists
+            needs_primary_key = (
+                if_exists == "fail"  # Only for new table creation
+                and not to_sql_kwargs.get("index", False)  # No index column exists
+            )

Code Review Run #528e4b


if needs_primary_key:
# Add an auto-incrementing primary key column
pk_column_name = "__superset_upload_id__"
# Ensure the column name doesn't conflict with existing columns
while pk_column_name in df.columns:
pk_column_name = f"_{pk_column_name}"

# Create a copy of the dataframe with the primary key column
df_with_pk = df.copy()
df_with_pk.insert(0, pk_column_name, range(1, len(df) + 1))

# Use pandas to create table, then alter it to add PRIMARY KEY
# This is a two-step process because pandas doesn't support
# PRIMARY KEY or AUTO_INCREMENT in dtype specifications

to_sql_kwargs_temp = {**to_sql_kwargs}
to_sql_kwargs_temp["name"] = table.table
if table.schema:
to_sql_kwargs_temp["schema"] = table.schema

if (
engine.dialect.supports_multivalues_insert
or cls.supports_multivalues_insert
):
to_sql_kwargs_temp["method"] = "multi"

logger.info(
"Adding primary key column '%s' for CSV upload to %s.%s "
"(sql_require_primary_key enabled)",
pk_column_name,
table.schema or "default",
table.table,
)

# Write data with pandas
df_with_pk.to_sql(con=engine, **to_sql_kwargs_temp)

# Now alter the table to add PRIMARY KEY and AUTO_INCREMENT
full_table_name = (
f"`{table.schema}`.`{table.table}`"
if table.schema
else f"`{table.table}`"
)

try:
with engine.begin() as conn: # Use transaction
# Add AUTO_INCREMENT and PRIMARY KEY
alter_sql = (
f"ALTER TABLE {full_table_name} "
f"MODIFY COLUMN `{pk_column_name}` INTEGER AUTO_INCREMENT, "
f"ADD PRIMARY KEY (`{pk_column_name}`)"
)
conn.execute(alter_sql)
except Exception as ex:
logger.error(
"Failed to add PRIMARY KEY constraint to %s: %s",
full_table_name,
str(ex),
)
# Clean up the table if ALTER failed
try:
with engine.begin() as conn:
conn.execute(f"DROP TABLE IF EXISTS {full_table_name}")
except Exception:  # pylint: disable=broad-except
pass  # best-effort cleanup; the original error is already logged

return

# Call parent implementation for normal case
super(MySQLEngineSpec, cls).df_to_sql(
database, table, df, to_sql_kwargs
)
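
For context, a minimal standalone sketch of the two-step flow implemented in df_to_sql above: pandas creates and populates the table with a surrogate key column, then a follow-up ALTER promotes that column to an AUTO_INCREMENT primary key. The connection URI, table name, and sample data below are illustrative assumptions, not values from this PR.

import pandas as pd
from sqlalchemy import create_engine, text

engine = create_engine("mysql+mysqldb://user:password@localhost:3306/uploads")  # placeholder URI
df = pd.DataFrame({"name": ["Alice", "Bob"], "age": [30, 25]})

# Step 1: add a surrogate key column and let pandas create and populate the table.
df.insert(0, "__superset_upload_id__", range(1, len(df) + 1))
df.to_sql("people", con=engine, if_exists="fail", index=False)

# Step 2: promote the surrogate column to an AUTO_INCREMENT primary key; pandas
# cannot express PRIMARY KEY or AUTO_INCREMENT through to_sql() dtype arguments.
with engine.begin() as conn:
    conn.execute(text(
        "ALTER TABLE `people` "
        "MODIFY COLUMN `__superset_upload_id__` INTEGER AUTO_INCREMENT, "
        "ADD PRIMARY KEY (`__superset_upload_id__`)"
    ))
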
236 changes: 236 additions & 0 deletions tests/integration_tests/databases/commands/upload_mysql_pk_test.py
@@ -0,0 +1,236 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

"""
Tests for CSV upload with MySQL databases requiring primary keys.
"""

from __future__ import annotations

import contextlib
from unittest.mock import MagicMock, Mock, patch

import pandas as pd
import pytest

from superset.commands.database.uploaders.base import UploadCommand
from superset.commands.database.uploaders.csv_reader import CSVReader
from superset.db_engine_specs.mysql import MySQLEngineSpec
from superset.models.core import Database
from superset.sql.parse import Table


def test_mysql_requires_primary_key_detection():
"""Test that _requires_primary_key correctly detects MySQL setting"""
# Create a mock engine
mock_engine = Mock()
mock_conn = Mock()
mock_result = Mock()

# Test case: sql_require_primary_key is ON (1)
mock_result.scalar.return_value = 1
mock_conn.execute.return_value = mock_result
mock_conn.__enter__ = Mock(return_value=mock_conn)
mock_conn.__exit__ = Mock(return_value=False)
mock_engine.connect.return_value = mock_conn

assert MySQLEngineSpec._requires_primary_key(mock_engine) is True

# Test case: sql_require_primary_key is OFF (0)
mock_result.scalar.return_value = 0
assert MySQLEngineSpec._requires_primary_key(mock_engine) is False

# Test case: Query fails (e.g., older MySQL version without the setting)
mock_conn.execute.side_effect = Exception("Unknown system variable")
assert MySQLEngineSpec._requires_primary_key(mock_engine) is False


def test_mysql_df_to_sql_adds_primary_key():
"""Test that df_to_sql adds a primary key when required"""
# Create test data
df = pd.DataFrame({
"name": ["Alice", "Bob", "Charlie"],
"age": [30, 25, 35],
"city": ["NYC", "LA", "SF"]
})

# Create mock database and table
mock_database = Mock(spec=Database)
mock_table = Table(table="test_table", schema=None)

# Mock engine with sql_require_primary_key enabled
mock_engine = Mock()
mock_conn = Mock()
mock_result = Mock()
mock_result.scalar.return_value = 1 # Primary key required
mock_conn.execute.return_value = mock_result
mock_conn.commit = Mock()
mock_conn.__enter__ = Mock(return_value=mock_conn)
mock_conn.__exit__ = Mock(return_value=False)
mock_engine.connect.return_value = mock_conn

# Mock engine.begin() for transaction context
mock_transaction = Mock()
mock_transaction.__enter__ = Mock(return_value=mock_conn)
mock_transaction.__exit__ = Mock(return_value=False)
mock_engine.begin.return_value = mock_transaction

mock_engine.dialect.supports_multivalues_insert = True

# Track to_sql calls
to_sql_called = False
captured_df = None
captured_kwargs = None

def capture_to_sql(con, **kwargs):
nonlocal to_sql_called, captured_df, captured_kwargs
to_sql_called = True
captured_df = kwargs.get("name") # Just capture what we can
captured_kwargs = kwargs

# Mock df.to_sql to capture the call
with patch.object(pd.DataFrame, "to_sql", side_effect=capture_to_sql):
with patch.object(
MySQLEngineSpec,
"get_engine",
return_value=contextlib.nullcontext(mock_engine)
):
to_sql_kwargs = {
"if_exists": "fail",
"index": False,
"chunksize": 1000,
}

MySQLEngineSpec.df_to_sql(
mock_database,
mock_table,
df,
to_sql_kwargs
)

# Verify to_sql was called
assert to_sql_called
assert "test_table" in str(captured_kwargs)

# Verify ALTER TABLE was called to add PRIMARY KEY
execute_calls = [call for call in mock_conn.execute.call_args_list]
# First call is checking sql_require_primary_key
# Second call should be ALTER TABLE
assert len(execute_calls) >= 2
alter_call = execute_calls[1][0][0]
assert "ALTER TABLE" in alter_call
assert "AUTO_INCREMENT" in alter_call
assert "PRIMARY KEY" in alter_call


def test_mysql_df_to_sql_skips_primary_key_when_not_required():
"""Test that df_to_sql doesn't add primary key when not required"""
df = pd.DataFrame({
"name": ["Alice", "Bob"],
"age": [30, 25]
})

mock_database = Mock(spec=Database)
mock_table = Table(table="test_table", schema=None)

# Mock engine with sql_require_primary_key disabled
mock_engine = Mock()
mock_conn = Mock()
mock_result = Mock()
mock_result.scalar.return_value = 0 # Primary key NOT required
mock_conn.execute.return_value = mock_result
mock_conn.__enter__ = Mock(return_value=mock_conn)
mock_conn.__exit__ = Mock(return_value=False)
mock_engine.connect.return_value = mock_conn
mock_engine.dialect.supports_multivalues_insert = True

# Use parent's df_to_sql
with patch.object(
MySQLEngineSpec,
"get_engine",
return_value=contextlib.nullcontext(mock_engine)
):
with patch("superset.db_engine_specs.base.BaseEngineSpec.df_to_sql") as mock_parent:
to_sql_kwargs = {
"if_exists": "fail",
"index": False,
}

MySQLEngineSpec.df_to_sql(
mock_database,
mock_table,
df,
to_sql_kwargs
)

# Verify parent's df_to_sql was called (normal path)
mock_parent.assert_called_once()
# Verify no ALTER TABLE was executed
# Only one call should be to check sql_require_primary_key
assert mock_conn.execute.call_count == 1


def test_mysql_df_to_sql_handles_column_name_conflicts():
"""Test that primary key column name is adjusted if it conflicts"""
df = pd.DataFrame({
"__superset_upload_id__": ["existing_data"],
"name": ["Alice"]
})

mock_database = Mock(spec=Database)
mock_table = Table(table="test_table", schema=None)

mock_engine = Mock()
mock_conn = Mock()
mock_result = Mock()
mock_result.scalar.return_value = 1 # Primary key required
mock_conn.execute.return_value = mock_result
mock_conn.commit = Mock()
mock_conn.__enter__ = Mock(return_value=mock_conn)
mock_conn.__exit__ = Mock(return_value=False)
mock_engine.connect.return_value = mock_conn
mock_engine.dialect.supports_multivalues_insert = True

captured_df_columns = None

def capture_to_sql(con, **kwargs):
# Can't directly access df from kwargs in mock, but we can check ALTER
pass

with patch.object(pd.DataFrame, "to_sql", side_effect=capture_to_sql):
with patch.object(
MySQLEngineSpec,
"get_engine",
return_value=contextlib.nullcontext(mock_engine)
):
to_sql_kwargs = {
"if_exists": "fail",
"index": False,
}

MySQLEngineSpec.df_to_sql(
mock_database,
mock_table,
df,
to_sql_kwargs
)

# Verify ALTER TABLE uses a different column name (prefixed with _)
execute_calls = [call for call in mock_conn.execute.call_args_list]
if len(execute_calls) >= 2:
alter_call = execute_calls[1][0][0]
# The column name should be modified to avoid conflict
assert "__superset_upload_id__" in alter_call or "___superset_upload_id__" in alter_call

Test assertion too permissive

The assertion allows either the original or prefixed column name, but the code always prefixes when there's a conflict, so the test should verify the exact adjusted name.

Code suggestion
Check the AI-generated fix before applying
 -        assert "___superset_upload_id__" in alter_call
 +        assert "___superset_upload_id__" in alter_call and "__superset_upload_id__" not in alter_call

Code Review Run #528e4b
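
As a side note for anyone reproducing this outside the mocked tests: the new code path only triggers against a MySQL 8.0.13+ server with sql_require_primary_key enabled. A minimal sketch, assuming a local server and an account allowed to change the session variable (the URI is a placeholder):

from sqlalchemy import create_engine, text

engine = create_engine("mysql+mysqldb://root:password@localhost:3306/test")  # placeholder URI

with engine.connect() as conn:
    # Changing sql_require_primary_key per session is a restricted operation and
    # may require SYSTEM_VARIABLES_ADMIN or SESSION_VARIABLES_ADMIN privileges.
    conn.execute(text("SET SESSION sql_require_primary_key = ON"))
    flag = conn.execute(text("SELECT @@session.sql_require_primary_key")).scalar()
    print(flag)  # 1 -> CREATE TABLE statements without a primary key are rejected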

