Skip to content

Commit

Permalink
[td] Support writing SELECT statements only (mage-ai#2220)
Browse files Browse the repository at this point in the history
* [td] Support writing SELECT statements only

* add docs
  • Loading branch information
tommydangerous authored Mar 17, 2023
1 parent 01e581d commit bf70019
Show file tree
Hide file tree
Showing 4 changed files with 97 additions and 33 deletions.
16 changes: 15 additions & 1 deletion docs/guides/sql-blocks.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,21 @@ Then, it’ll insert a single row into that table.

#### Required SQL statements

When writing raw SQL, you must at least have an `INSERT` statement. The `CREATE TABLE` statement is optional.
When writing raw SQL, you must at least 1 of the following statements:

- `SELECT`
```sql
SELECT 1;
```
- `INSERT`
```sql
INSERT INTO some_table
SELECT 1;
```
- `CREATE TABLE`
```sql
CREATE TABLE some_table (id BIGINT);
```

#### Multiple SQL statements

Expand Down
29 changes: 8 additions & 21 deletions mage_ai/data_preparation/models/block/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,6 @@
import json
import os
import pandas as pd
import re
import simplejson
import sys
import time
Expand Down Expand Up @@ -353,34 +352,22 @@ def table_name(self) -> str:
@property
def full_table_name(self) -> str:
from mage_ai.data_preparation.models.block.sql.utils.shared import (
extract_and_replace_text_between_strings,
extract_create_statement_table_name,
extract_insert_statement_table_names,
)

if not self.content:
return None

statement_partial, _ = extract_and_replace_text_between_strings(
self.content,
'create',
r'\(',
)

if not statement_partial:
matches = re.findall(
r'insert(?: overwrite)*(?: into)*[\s]+([\w.]+)',
self.content,
re.IGNORECASE,
)
if len(matches) >= 1:
return matches[len(matches) - 1]
else:
return None
table_name = extract_create_statement_table_name(self.content)
if table_name:
return table_name

if not statement_partial:
matches = extract_insert_statement_table_names(self.content)
if len(matches) == 0:
return None

parts = statement_partial[:len(statement_partial) - 1].strip().split(' ')
return parts[-1]
return matches[len(matches) - 1]

@classmethod
def after_create(self, block: 'Block', **kwargs):
Expand Down
33 changes: 25 additions & 8 deletions mage_ai/data_preparation/models/block/sql/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
trino,
)
from mage_ai.data_preparation.models.block.sql.utils.shared import (
has_create_or_insert_statement,
interpolate_vars,
)
from mage_ai.data_preparation.models.constants import BlockType
Expand Down Expand Up @@ -389,11 +390,15 @@ def split_query_string(query_string: str) -> List[str]:
arr = []
for query in queries:
query = query.strip()
if not query:
continue

lines = query.split('\n')
query = '\n'.join(list(filter(lambda x: not x.startswith('--'), lines)))
query = query.strip()
query = re.sub(MAGE_SEMI_COLON, ';', query)

if query:
lines = query.split('\n')
query = '\n'.join(list(filter(lambda x: not x.startswith('--'), lines)))
query = query.strip()
query = re.sub(MAGE_SEMI_COLON, ';', query)
arr.append(query)

return arr
Expand All @@ -409,11 +414,23 @@ def execute_raw_sql(
queries = []
fetch_query_at_indexes = []

for query in split_query_string(query_string):
queries.append(query)
fetch_query_at_indexes.append(False)
has_create_or_insert = has_create_or_insert_statement(query_string)

if should_query:
for query in split_query_string(query_string):
if has_create_or_insert:
queries.append(query)
fetch_query_at_indexes.append(False)
else:
if should_query:
query = f"""SELECT *
FROM (
{query}
) AS {block.table_name}__limit
LIMIT 1000"""
queries.append(query)
fetch_query_at_indexes.append(True)

if should_query and has_create_or_insert:
queries.append(f'SELECT * FROM {block.full_table_name} LIMIT 1000')
fetch_query_at_indexes.append(block.full_table_name)

Expand Down
52 changes: 49 additions & 3 deletions mage_ai/data_preparation/models/block/sql/utils/shared.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,23 +47,33 @@ def __replace_func(db, schema, tn):
for idx, upstream_block in enumerate(block.upstream_blocks):
matcher1 = '{} df_{} {}'.format('{{', idx + 1, '}}')

if BlockLanguage.SQL == upstream_block.type:
is_sql = BlockLanguage.SQL == upstream_block.language
if is_sql:
configuration = upstream_block.configuration
else:
configuration = block.configuration
use_raw_sql = configuration.get('use_raw_sql')

database = configuration.get('data_provider_database', '')
schema = configuration.get('data_provider_schema', '')

replace_with = __replace_func(database, schema, upstream_block.table_name)
upstream_block_content = upstream_block.content
if is_sql and use_raw_sql and not has_create_or_insert_statement(upstream_block_content):
upstream_query = interpolate_input(upstream_block, upstream_block_content)
replace_with = f"""(
{upstream_query}
) AS {upstream_block.table_name}"""

query = re.sub(
'{}[ ]*df_{}[ ]*{}'.format(r'\{\{', idx + 1, r'\}\}'),
__replace_func(database, schema, upstream_block.table_name),
replace_with,
query,
)

query = query.replace(
f'{matcher1}',
__replace_func(database, schema, upstream_block.table_name),
replace_with,
)

return query
Expand Down Expand Up @@ -170,3 +180,39 @@ def extract_and_replace_text_between_strings(
new_text = text[0:max(start_idx - 1, 0)] + replace_string + text[end_idx + 1:]

return extracted_text, new_text


def remove_comments(text: str) -> str:
lines = text.split('\n')
return '\n'.join(line for line in lines if not line.startswith('--'))


def extract_create_statement_table_name(text: str) -> str:
statement_partial, _ = extract_and_replace_text_between_strings(
remove_comments(text),
r'create table(?: if not exists)*',
r'\(',
)
if not statement_partial:
return None

parts = statement_partial[:len(statement_partial) - 1].strip().split(' ')
return parts[-1]


def extract_insert_statement_table_names(text: str) -> List[str]:
matches = re.findall(
r'insert(?: overwrite)*(?: into)*[\s]+([\w.]+)',
remove_comments(text),
re.IGNORECASE,
)
return matches


def has_create_or_insert_statement(text: str) -> bool:
table_name = extract_create_statement_table_name(text)
if table_name:
return True

matches = extract_insert_statement_table_names(text)
return len(matches) >= 1

0 comments on commit bf70019

Please sign in to comment.