From bf70019ebea708681b97df3107fe427e0d4fe926 Mon Sep 17 00:00:00 2001 From: DANGerous Date: Thu, 16 Mar 2023 22:31:47 -0700 Subject: [PATCH] [td] Support writing SELECT statements only (#2220) * [td] Support writing SELECT statements only * add docs --- docs/guides/sql-blocks.mdx | 16 +++++- .../data_preparation/models/block/__init__.py | 29 +++-------- .../models/block/sql/__init__.py | 33 +++++++++--- .../models/block/sql/utils/shared.py | 52 +++++++++++++++++-- 4 files changed, 97 insertions(+), 33 deletions(-) diff --git a/docs/guides/sql-blocks.mdx b/docs/guides/sql-blocks.mdx index 42ae037a82b1..88373431ff08 100644 --- a/docs/guides/sql-blocks.mdx +++ b/docs/guides/sql-blocks.mdx @@ -97,7 +97,21 @@ Then, it’ll insert a single row into that table. #### Required SQL statements -When writing raw SQL, you must at least have an `INSERT` statement. The `CREATE TABLE` statement is optional. +When writing raw SQL, you must have at least 1 of the following statements: + +- `SELECT` + ```sql + SELECT 1; + ``` +- `INSERT` + ```sql + INSERT INTO some_table + SELECT 1; + ``` +- `CREATE TABLE` + ```sql + CREATE TABLE some_table (id BIGINT); + ``` #### Multiple SQL statements diff --git a/mage_ai/data_preparation/models/block/__init__.py b/mage_ai/data_preparation/models/block/__init__.py index 76d3a3cb1411..8570828ac00e 100644 --- a/mage_ai/data_preparation/models/block/__init__.py +++ b/mage_ai/data_preparation/models/block/__init__.py @@ -51,7 +51,6 @@ import json import os import pandas as pd -import re import simplejson import sys import time @@ -353,34 +352,22 @@ def table_name(self) -> str: @property def full_table_name(self) -> str: from mage_ai.data_preparation.models.block.sql.utils.shared import ( - extract_and_replace_text_between_strings, + extract_create_statement_table_name, + extract_insert_statement_table_names, ) if not self.content: return None - statement_partial, _ = extract_and_replace_text_between_strings( - self.content, - 'create', - r'\(', - ) - - if 
not statement_partial: - matches = re.findall( - r'insert(?: overwrite)*(?: into)*[\s]+([\w.]+)', - self.content, - re.IGNORECASE, - ) - if len(matches) >= 1: - return matches[len(matches) - 1] - else: - return None + table_name = extract_create_statement_table_name(self.content) + if table_name: + return table_name - if not statement_partial: + matches = extract_insert_statement_table_names(self.content) + if len(matches) == 0: return None - parts = statement_partial[:len(statement_partial) - 1].strip().split(' ') - return parts[-1] + return matches[len(matches) - 1] @classmethod def after_create(self, block: 'Block', **kwargs): diff --git a/mage_ai/data_preparation/models/block/sql/__init__.py b/mage_ai/data_preparation/models/block/sql/__init__.py index a042e0fd380d..be3cb6abb242 100644 --- a/mage_ai/data_preparation/models/block/sql/__init__.py +++ b/mage_ai/data_preparation/models/block/sql/__init__.py @@ -9,6 +9,7 @@ trino, ) from mage_ai.data_preparation.models.block.sql.utils.shared import ( + has_create_or_insert_statement, interpolate_vars, ) from mage_ai.data_preparation.models.constants import BlockType @@ -389,11 +390,15 @@ def split_query_string(query_string: str) -> List[str]: arr = [] for query in queries: query = query.strip() + if not query: + continue + + lines = query.split('\n') + query = '\n'.join(list(filter(lambda x: not x.startswith('--'), lines))) + query = query.strip() + query = re.sub(MAGE_SEMI_COLON, ';', query) + if query: - lines = query.split('\n') - query = '\n'.join(list(filter(lambda x: not x.startswith('--'), lines))) - query = query.strip() - query = re.sub(MAGE_SEMI_COLON, ';', query) arr.append(query) return arr @@ -409,11 +414,23 @@ def execute_raw_sql( queries = [] fetch_query_at_indexes = [] - for query in split_query_string(query_string): - queries.append(query) - fetch_query_at_indexes.append(False) + has_create_or_insert = has_create_or_insert_statement(query_string) - if should_query: + for query in 
split_query_string(query_string): + if has_create_or_insert: + queries.append(query) + fetch_query_at_indexes.append(False) + else: + if should_query: + query = f"""SELECT * +FROM ( + {query} +) AS {block.table_name}__limit +LIMIT 1000""" + queries.append(query) + fetch_query_at_indexes.append(True) + + if should_query and has_create_or_insert: queries.append(f'SELECT * FROM {block.full_table_name} LIMIT 1000') fetch_query_at_indexes.append(block.full_table_name) diff --git a/mage_ai/data_preparation/models/block/sql/utils/shared.py b/mage_ai/data_preparation/models/block/sql/utils/shared.py index bae4e833945a..b3eef1023371 100644 --- a/mage_ai/data_preparation/models/block/sql/utils/shared.py +++ b/mage_ai/data_preparation/models/block/sql/utils/shared.py @@ -47,23 +47,33 @@ def __replace_func(db, schema, tn): for idx, upstream_block in enumerate(block.upstream_blocks): matcher1 = '{} df_{} {}'.format('{{', idx + 1, '}}') - if BlockLanguage.SQL == upstream_block.type: + is_sql = BlockLanguage.SQL == upstream_block.language + if is_sql: configuration = upstream_block.configuration else: configuration = block.configuration + use_raw_sql = configuration.get('use_raw_sql') database = configuration.get('data_provider_database', '') schema = configuration.get('data_provider_schema', '') + replace_with = __replace_func(database, schema, upstream_block.table_name) + upstream_block_content = upstream_block.content + if is_sql and use_raw_sql and not has_create_or_insert_statement(upstream_block_content): + upstream_query = interpolate_input(upstream_block, upstream_block_content) + replace_with = f"""( + {upstream_query} +) AS {upstream_block.table_name}""" + query = re.sub( '{}[ ]*df_{}[ ]*{}'.format(r'\{\{', idx + 1, r'\}\}'), - __replace_func(database, schema, upstream_block.table_name), + replace_with, query, ) query = query.replace( f'{matcher1}', - __replace_func(database, schema, upstream_block.table_name), + replace_with, ) return query @@ -170,3 +180,39 @@ def 
extract_and_replace_text_between_strings( new_text = text[0:max(start_idx - 1, 0)] + replace_string + text[end_idx + 1:] return extracted_text, new_text + + +def remove_comments(text: str) -> str: + lines = text.split('\n') + return '\n'.join(line for line in lines if not line.startswith('--')) + + +def extract_create_statement_table_name(text: str) -> str: + statement_partial, _ = extract_and_replace_text_between_strings( + remove_comments(text), + r'create table(?: if not exists)*', + r'\(', + ) + if not statement_partial: + return None + + parts = statement_partial[:len(statement_partial) - 1].strip().split(' ') + return parts[-1] + + +def extract_insert_statement_table_names(text: str) -> List[str]: + matches = re.findall( + r'insert(?: overwrite)*(?: into)*[\s]+([\w.]+)', + remove_comments(text), + re.IGNORECASE, + ) + return matches + + +def has_create_or_insert_statement(text: str) -> bool: + table_name = extract_create_statement_table_name(text) + if table_name: + return True + + matches = extract_insert_statement_table_names(text) + return len(matches) >= 1