Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 66 additions & 4 deletions src/lakebench/benchmarks/elt_bench/elt_bench.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from typing import Optional
from ..base import BaseBenchmark
from ...utils.query_utils import transpile_and_qualify_query, get_table_name_from_ddl

from .engine_impl.spark import SparkELTBench
from .engine_impl.duckdb import DuckDBELTBench
Expand All @@ -14,6 +15,8 @@
from ...engines.polars import Polars
from ...engines.sail import Sail

from ..tpcds.tpcds import TPCDS
import importlib.resources
import posixpath


Expand Down Expand Up @@ -115,6 +118,60 @@ def run(self, mode: str = 'light'):
raise NotImplementedError("Full mode is not implemented yet.")
else:
raise ValueError(f"Mode '{mode}' is not supported. Supported modes: {self.MODE_REGISTRY}.")

def _prepare_schema(self, tables: list[str]):


self.engine.create_schema_if_not_exists(drop_before_create=True)

engine_class_name = self.engine.__class__.__name__.lower()
parent_class_name = self.engine.__class__.__bases__[0].__name__.lower()
benchmark_name = 'tpcds'
engine_root_lib_name = self.engine.__class__.__module__.split('.')[0]
from_dialect = self.engine.SQLGLOT_DIALECT
self.DDL_FILE_NAME = TPCDS.DDL_FILE_NAME

try:
# Try to load engine-specific query first
with importlib.resources.path(
f"{engine_root_lib_name}.benchmarks.{benchmark_name}.resources.ddl.{engine_class_name}",
self.DDL_FILE_NAME
) as ddl_path:
with open(ddl_path, 'r') as ddl_file:
ddl = ddl_file.read()
except (ModuleNotFoundError, FileNotFoundError):
# Try parent engine class name if engine-specific fails
try:
with importlib.resources.path(
f"lakebench.benchmarks.{benchmark_name}.resources.ddl.{parent_class_name}",
self.DDL_FILE_NAME
) as ddl_path:
with open(ddl_path, 'r') as ddl_file:
ddl = ddl_file.read()
except (ModuleNotFoundError, FileNotFoundError):
# Fall back to canonical query
with importlib.resources.path(
f"lakebench.benchmarks.{benchmark_name}.resources.ddl.canonical",
self.DDL_FILE_NAME
) as ddl_path:
with open(ddl_path, 'r') as ddl_file:
ddl = ddl_file.read()
from_dialect = 'spark'

statements = [s for s in ddl.split(';') if len(s) > 7]
for statement in statements:
prepped_ddl = transpile_and_qualify_query(
query=statement,
from_dialect=from_dialect,
to_dialect=self.engine.SQLGLOT_DIALECT,
catalog=getattr(self.engine, 'catalog_name', None),
schema=getattr(self.engine, 'schema_name', None)
)
table_name = get_table_name_from_ddl(prepped_ddl)
# only create tables that are in the specified list
if table_name in tables:
self.engine._create_empty_table(table_name=table_name, ddl=prepped_ddl)


def run_light_mode(self):
"""
Expand All @@ -128,15 +185,20 @@ def run_light_mode(self):
----------
None
"""
tables = [
'store_sales', 'date_dim', 'store', 'item', 'customer'
]

self.mode = 'light'
self.engine.create_schema_if_not_exists(drop_before_create=True)

for table_name in ('store_sales', 'date_dim', 'store', 'item', 'customer'):
if self.engine.SUPPORTS_SCHEMA_PREP:
self._prepare_schema(tables=tables)

for table_name in tables:
with self.timer(phase="Read parquet, write delta (x5)", test_item=table_name, engine=self.engine) as tc:
tc.execution_telemetry = self.engine.load_parquet_to_delta(
parquet_folder_uri=posixpath.join(self.input_parquet_folder_uri, f"{table_name}/"),
table_name=table_name,
table_is_precreated=False,
table_is_precreated=True,
context_decorator=tc.context_decorator
)
with self.timer(phase="Create fact table", test_item='total_sales_fact', engine=self.engine):
Expand Down
2 changes: 1 addition & 1 deletion src/lakebench/engines/fabric_spark.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def __init__(
)

self.version: str = f"{self.spark.sparkContext.version} (vhd_name=={self.spark.conf.get('spark.synapse.vhd.name')})"
self.cost_per_vcore_hour = cost_per_vcore_hour or getattr(self, '_FABRIC_USD_COST_PER_VCORE_HOUR', None)
self.cost_per_vcore_hour = cost_per_vcore_hour or getattr(self, '_autocalc_usd_cost_per_vcore_hour', None)
self.cost_per_hour = self.get_total_cores() * self.cost_per_vcore_hour

url = self.spark.sparkContext.uiWebUrl
Expand Down