SNOW-1507519: Add support for column_order='name' in append mode write.save_as_table (snowflakedb#1872)
sfc-gh-aling authored Jul 12, 2024
1 parent 63cc6a0 commit 60345ed
Showing 4 changed files with 193 additions and 53 deletions.
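
In short, the commit lets the local testing (mock) backend honor `column_order="name"` when appending to an existing table, matching incoming columns to the target schema by name instead of by position. A minimal usage sketch, mirroring the new tests below — the table name is illustrative, and the session is assumed to be created with the `local_testing` config documented for the mock backend:

```python
from snowflake.snowpark import Session
from snowflake.snowpark.types import IntegerType, StructField, StructType

# Local-testing (mock) session; no Snowflake connection is made.
session = Session.builder.config("local_testing", True).create()

# Create an empty target table with columns A and B.
session.create_dataframe(
    [],
    schema=StructType([StructField("a", IntegerType()), StructField("b", IntegerType())]),
).write.save_as_table("my_table", table_type="temporary")

# Incoming data declares its columns in the opposite order (b, a).
df = session.create_dataframe([[1, 2]], schema=["b", "a"])

# With column_order="name", values are matched to the target columns by name,
# so A receives 2 and B receives 1 (before this change the mock backend raised
# NotImplementedError for this case).
df.write.save_as_table("my_table", mode="append", column_order="name")

print(session.table("my_table").collect())  # expected: [Row(A=2, B=1)]
```
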
7 changes: 7 additions & 0 deletions CHANGELOG.md
@@ -20,13 +20,20 @@
### Snowpark pandas API Updates

#### New Features

- Added partial support for `Series.str.translate` where the values in the `table` are single-codepoint strings.
- Added support for `DataFrame.corr`.
- Added support for `limit` parameter when `method` parameter is used in `fillna`.

#### Bug Fixes
- Fixed an issue when using `np.where` and `df.where` when the scalar `other` is the literal 0.

### Snowpark Local Testing Updates

#### New Features

- Added support for the `column_order` parameter to method `DataFrameWriter.save_as_table`.

## 1.19.0 (2024-06-25)

### Snowpark Python API Updates
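The Snowpark pandas entries in the hunk above track native pandas semantics; a rough illustration, shown with stock pandas objects for brevity (under Snowpark pandas the same methods are called on its own DataFrame/Series wrappers):

```python
import pandas as pd

# Series.str.translate with single-codepoint string values in the table
print(pd.Series(["abc", "bca"]).str.translate({ord("a"): "x"}))  # ['xbc', 'bcx']

# DataFrame.corr: pairwise column correlations
df = pd.DataFrame({"x": [1.0, 2.0, 3.0], "y": [1.0, 4.0, 9.0]})
print(df.corr())

# fillna with both method= and limit=: forward-fill at most one gap per run
print(pd.Series([1.0, None, None, 4.0]).fillna(method="ffill", limit=1))

# df.where with the scalar `other` being the literal 0
print(df.where(df > 1.5, 0))
```
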
68 changes: 52 additions & 16 deletions src/snowflake/snowpark/mock/_connection.py
@@ -121,7 +121,11 @@ def read_table(self, name: Union[str, Iterable[str]]) -> TableEmulator:
)

def write_table(
self, name: Union[str, Iterable[str]], table: TableEmulator, mode: SaveMode
self,
name: Union[str, Iterable[str]],
table: TableEmulator,
mode: SaveMode,
column_names: Optional[List[str]] = None,
) -> Row:
for column in table.columns:
if not table[column].sf_type.nullable and table[column].isnull().any():
@@ -133,21 +137,49 @@ def write_table(
name = get_fully_qualified_name(name, current_schema, current_database)
table = copy(table)
if mode == SaveMode.APPEND:
# Fix append by index
if name in self.table_registry:
target_table = self.table_registry[name]
input_schema = table.columns.to_list()
existing_schema = target_table.columns.to_list()

if len(table.columns.to_list()) != len(
target_table.columns.to_list()
):
raise SnowparkLocalTestingException(
f"Cannot append because incoming data has different schema {table.columns.to_list()} than existing table { target_table.columns.to_list()}"
)
if not column_names: # append with column_order being index
if len(input_schema) != len(existing_schema):
raise SnowparkLocalTestingException(
f"Cannot append because incoming data has different schema {input_schema} than existing table {existing_schema}"
)
# temporarily align the column names of both dataframe to be col indexes 0, 1, ... N - 1
table.columns = range(table.shape[1])
target_table.columns = range(target_table.shape[1])
else: # append with column_order being name
if invalid_cols := set(input_schema) - set(existing_schema):
identifiers = "', '".join(
unquote_if_quoted(id) for id in invalid_cols
)
raise SnowparkLocalTestingException(
f"table contains invalid identifier '{identifiers}'"
)
invalid_non_nullable_cols = []
for missing_col in set(existing_schema) - set(input_schema):
if target_table[missing_col].sf_type.nullable:
table[missing_col] = None
table.sf_types[missing_col] = target_table[
missing_col
].sf_type
else:
invalid_non_nullable_cols.append(missing_col)
if invalid_non_nullable_cols:
identifiers = "', '".join(
unquote_if_quoted(id)
for id in invalid_non_nullable_cols
)
raise SnowparkLocalTestingException(
f"NULL result in a non-nullable column '{identifiers}'"
)

table.columns = target_table.columns
self.table_registry[name] = pandas.concat(
[target_table, table], ignore_index=True
)
self.table_registry[name].columns = existing_schema
self.table_registry[name].sf_types = target_table.sf_types
else:
self.table_registry[name] = table
@@ -571,10 +603,12 @@ def execute(
)
row = row_struct(
*[
Decimal("{0:.{1}f}".format(v, sf_types[i].datatype.scale))
if isinstance(sf_types[i].datatype, DecimalType)
and v is not None
else v
(
Decimal("{0:.{1}f}".format(v, sf_types[i].datatype.scale))
if isinstance(sf_types[i].datatype, DecimalType)
and v is not None
else v
)
for i, v in enumerate(pdr)
]
)
@@ -636,9 +670,11 @@ def get_result_and_metadata(
attrs = [
Attribute(
name=quote_name(column_name.strip()),
datatype=column_data.sf_type
if column_data.sf_type
else res.sf_types[column_name],
datatype=(
column_data.sf_type
if column_data.sf_type
else res.sf_types[column_name]
),
)
for column_name, column_data in res.items()
]
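The append-by-name branch added to `write_table` above is essentially column bookkeeping on the underlying pandas frames: reject incoming columns absent from the target, backfill missing nullable target columns with NULLs, error out on missing non-nullable ones, then concatenate in the target's column order. A condensed, standalone sketch of the same idea — the helper name and the plain `nullable` dict are illustrative; the real code carries nullability in `sf_type` metadata on `TableEmulator`:

```python
import pandas as pd


def append_by_name(target: pd.DataFrame, incoming: pd.DataFrame, nullable: dict) -> pd.DataFrame:
    incoming = incoming.copy()

    # Incoming columns must all exist on the target.
    extra = set(incoming.columns) - set(target.columns)
    if extra:
        raise ValueError(f"table contains invalid identifier(s): {sorted(extra)}")

    # Target columns missing from the incoming data: backfill if nullable, else fail.
    missing_non_nullable = []
    for missing in set(target.columns) - set(incoming.columns):
        if nullable.get(missing, True):
            incoming[missing] = None
        else:
            missing_non_nullable.append(missing)
    if missing_non_nullable:
        raise ValueError(f"NULL result in a non-nullable column: {sorted(missing_non_nullable)}")

    # Reorder incoming columns to the target's order, then append.
    incoming = incoming[list(target.columns)]
    return pd.concat([target, incoming], ignore_index=True)


# Target has A, B, C (C nullable); incoming supplies B and A in a different order.
target = pd.DataFrame({"A": [10], "B": [20], "C": ["x"]})
incoming = pd.DataFrame({"B": [1], "A": [2]})
print(append_by_name(target, incoming, {"A": False, "B": False, "C": True}))
```
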
20 changes: 9 additions & 11 deletions src/snowflake/snowpark/mock/_plan.py
@@ -1076,16 +1076,12 @@ def outer_join(base_df):
if isinstance(source_plan, MockFileOperation):
return execute_file_operation(source_plan, analyzer)
if isinstance(source_plan, SnowflakeCreateTable):
if source_plan.column_names is not None:
analyzer.session._conn.log_not_supported_error(
external_feature_name="Inserting data into table by matching columns",
internal_feature_name=type(source_plan).__name__,
parameters_info={"source_plan.column_names": "True"},
raise_error=NotImplementedError,
)
res_df = execute_mock_plan(source_plan.query, expr_to_alias)
return entity_registry.write_table(
source_plan.table_name, res_df, source_plan.mode
source_plan.table_name,
res_df,
source_plan.mode,
column_names=source_plan.column_names,
)
if isinstance(source_plan, UnresolvedRelation):
entity_name = source_plan.name
@@ -1122,9 +1118,11 @@ def outer_join(base_df):
)

return res_df.sample(
n=None
if source_plan.row_count is None
else min(source_plan.row_count, len(res_df)),
n=(
None
if source_plan.row_count is None
else min(source_plan.row_count, len(res_df))
),
frac=source_plan.probability_fraction,
random_state=source_plan.seed,
)
151 changes: 125 additions & 26 deletions tests/integ/scala/test_dataframe_writer_suite.py
@@ -6,10 +6,12 @@

import pytest

import snowflake.connector.errors
from snowflake.snowpark import Row
from snowflake.snowpark._internal.utils import TempObjectType, parse_table_name
from snowflake.snowpark.exceptions import SnowparkSQLException
from snowflake.snowpark.functions import col, parse_json
from snowflake.snowpark.mock.exceptions import SnowparkLocalTestingException
from snowflake.snowpark.types import (
DoubleType,
IntegerType,
@@ -20,10 +22,6 @@
from tests.utils import TestFiles, Utils


@pytest.mark.skipif(
"config.getoption('local_testing_mode', default=False)",
reason="FEAT: support truncate and column_order in save as table",
)
def test_write_with_target_column_name_order(session, local_testing_mode):
table_name = Utils.random_table_name()
empty_df = session.create_dataframe(
@@ -41,7 +39,7 @@ def test_write_with_target_column_name_order(session, local_testing_mode):

# By default, it is by index
df1.write.save_as_table(table_name, mode="append", table_type="temp")
Utils.check_answer(session.table(table_name), [Row(1, 2)])
Utils.check_answer(session.table(table_name), [Row(**{"A": 1, "B": 2})])

# Explicitly use "index"
empty_df.write.save_as_table(
@@ -50,7 +48,7 @@ def test_write_with_target_column_name_order(session, local_testing_mode):
df1.write.save_as_table(
table_name, mode="append", column_order="index", table_type="temp"
)
Utils.check_answer(session.table(table_name), [Row(1, 2)])
Utils.check_answer(session.table(table_name), [Row(**{"A": 1, "B": 2})])

# use order by "name"
empty_df.write.save_as_table(
@@ -59,28 +57,34 @@ def test_write_with_target_column_name_order(session, local_testing_mode):
df1.write.save_as_table(
table_name, mode="append", column_order="name", table_type="temp"
)
Utils.check_answer(session.table(table_name), [Row(2, 1)])
Utils.check_answer(session.table(table_name), [Row(**{"A": 2, "B": 1})])

# If target table doesn't exists, "order by name" is not actually used.
# If target table doesn't exist, "order by name" is not actually used.
Utils.drop_table(session, table_name)
df1.write.saveAsTable(table_name, mode="append", column_order="name")
Utils.check_answer(session.table(table_name), [Row(1, 2)])
df1.write.save_as_table(table_name, mode="append", column_order="name")
        # NOTE: Order is different in the check below because the table
        # returns columns in the order of the schema of `df1`
Utils.check_answer(session.table(table_name), [Row(**{"B": 1, "A": 2})])
finally:
session.table(table_name).drop_table()

# column name and table name with special characters
special_table_name = '"test table name"'
Utils.create_table(
session, special_table_name, '"a a" int, "b b" int', is_temporary=True
)
try:
df2 = session.create_dataframe([(1, 2)]).to_df("b b", "a a")
df2.write.save_as_table(
special_table_name, mode="append", column_order="name", table_type="temp"
if not local_testing_mode:
# column name and table name with special characters
special_table_name = '"test table name"'
Utils.create_table(
session, special_table_name, '"a a" int, "b b" int', is_temporary=True
)
Utils.check_answer(session.table(special_table_name), [Row(2, 1)])
finally:
Utils.drop_table(session, special_table_name)
try:
df2 = session.create_dataframe([(1, 2)]).to_df("b b", "a a")
df2.write.save_as_table(
special_table_name,
mode="append",
column_order="name",
table_type="temp",
)
Utils.check_answer(session.table(special_table_name), [Row(2, 1)])
finally:
Utils.drop_table(session, special_table_name)


@pytest.mark.xfail(
@@ -105,10 +109,6 @@ def test_write_with_target_table_autoincrement(
Utils.drop_table(session, table_name)


@pytest.mark.skipif(
"config.getoption('local_testing_mode', default=False)",
reason="FEAT: Inserting data into table by matching columns is not supported",
)
def test_negative_write_with_target_column_name_order(session):
table_name = Utils.random_table_name()
session.create_dataframe(
@@ -139,6 +139,105 @@ def test_negative_write_with_target_column_name_order(session):
session.table(table_name).drop_table()


def test_write_with_target_column_name_order_all_kinds_of_dataframes_without_truncates(
session,
):
table_name = Utils.random_table_name()

session.create_dataframe(
[],
schema=StructType(
[StructField("a", IntegerType()), StructField("b", IntegerType())]
),
).write.save_as_table(table_name, table_type="temporary")

try:
large_df = session.create_dataframe([[1, 2]] * 1024, schema=["b", "a"])
large_df.write.save_as_table(
table_name, mode="append", column_order="name", table_type="temp"
)
rows = session.table(table_name).collect()
assert len(rows) == 1024
for row in rows:
assert row["B"] == 1 and row["A"] == 2
finally:
session.table(table_name).drop_table()


def test_write_with_target_column_name_order_with_nullable_column(
session, local_testing_mode
):
table_name, non_nullable_table_name = (
Utils.random_table_name(),
Utils.random_table_name(),
)

session.create_dataframe(
[],
schema=StructType(
[
StructField("a", IntegerType()),
StructField("b", IntegerType()),
StructField("c", StringType(), nullable=True),
StructField("d", StringType(), nullable=True),
]
),
).write.save_as_table(table_name, table_type="temporary")

session.create_dataframe(
[],
schema=StructType(
[
StructField("a", IntegerType()),
StructField("b", StringType(), nullable=False),
]
),
).write.save_as_table(non_nullable_table_name, table_type="temporary")
try:
df1 = session.create_dataframe([[1, 2], [3, 4]], schema=["b", "a"])

df1.write.save_as_table(
table_name, mode="append", table_type="temp", column_order="name"
)
Utils.check_answer(
session.table(table_name),
[
Row(
**{
"A": 2,
"B": 1,
"C": None,
"D": None,
}
),
Row(
**{
"A": 4,
"B": 3,
"C": None,
"D": None,
}
),
],
)

df2 = session.create_dataframe([[1], [2]], schema=["a"])
with pytest.raises(
SnowparkLocalTestingException
if local_testing_mode
else snowflake.connector.errors.IntegrityError
):
df2.write.save_as_table(
non_nullable_table_name,
mode="append",
table_type="temp",
column_order="name",
)
finally:
session.table(table_name).drop_table()
session.table(non_nullable_table_name).drop_table()


@pytest.mark.skipif(
"config.getoption('local_testing_mode', default=False)",
reason="FEAT: Inserting data into table by matching columns is not supported",
