Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add WriteFile command #89

Merged
merged 13 commits into from
Apr 11, 2023
Next Next commit
add write file command
  • Loading branch information
leo-schick committed Mar 8, 2023
commit 0c5eaa0d7ab9441de5bfe0f6b9374e7a9c25929d
2 changes: 2 additions & 0 deletions docs/commands.rst
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ Files commands

.. autoclass:: ReadScriptOutput

.. autoclass:: WriteFile


Python commands
---------------
Expand Down
72 changes: 63 additions & 9 deletions mara_pipelines/commands/files.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import pathlib
import shlex
import sys
from typing import List, Tuple, Dict
from typing import List, Tuple, Dict, Union, Callable

import enum

Expand Down Expand Up @@ -83,10 +83,10 @@ def __init__(self, file_name: str, compression: Compression, target_table: str,
self.timezone = timezone
self.file_format = file_format

def db_alias(self):
def db_alias(self) -> str:
return self._db_alias or config.default_db_alias()

def shell_command(self):
def shell_command(self) -> str:
copy_from_stdin_command = mara_db.shell.copy_from_stdin_command(
self.db_alias(), csv_format=self.csv_format, target_table=self.target_table,
skip_header=self.skip_header,
Expand All @@ -104,7 +104,7 @@ def shell_command(self):
# Bigquery loading does not support streaming data through pipes
return copy_from_stdin_command + f" {pathlib.Path(config.data_dir()) / self.file_name}"

def mapper_file_path(self):
def mapper_file_path(self) -> pathlib.Path:
return self.parent.parent.base_path() / self.mapper_script_file_name

def html_doc_items(self) -> List[Tuple[str, str]]:
Expand Down Expand Up @@ -141,10 +141,10 @@ def __init__(self, sqlite_file_name: str, target_table: str,
self.timezone = timezone

@property
def db_alias(self):
def db_alias(self) -> str:
return self._db_alias or config.default_db_alias()

def shell_command(self):
def shell_command(self) -> str:
return (sql._SQLCommand.shell_command(self)
+ ' | ' + mara_db.shell.copy_command(
mara_db.dbs.SQLiteDB(file_name=pathlib.Path(config.data_dir()).absolute() / self.sqlite_file_name),
Expand Down Expand Up @@ -188,10 +188,10 @@ def __init__(self, file_name: str, target_table: str, make_unique: bool = False,
self.timezone = timezone
self.pipe_format = pipe_format

def db_alias(self):
def db_alias(self) -> str:
return self._db_alias or config.default_db_alias()

def shell_command(self):
def shell_command(self) -> str:
return f'{shlex.quote(sys.executable)} "{self.file_path()}" \\\n' \
+ (' | sort -u \\\n' if self.make_unique else '') \
+ ' | ' + mara_db.shell.copy_from_stdin_command(
Expand All @@ -200,7 +200,7 @@ def shell_command(self):
null_value_string=self.null_value_string, timezone=self.timezone,
pipe_format=self.pipe_format)

def file_path(self):
def file_path(self) -> pathlib.Path:
return self.parent.parent.base_path() / self.file_name

def html_doc_items(self) -> List[Tuple[str, str]]:
Expand All @@ -219,3 +219,57 @@ def html_doc_items(self) -> List[Tuple[str, str]]:
_.tt[json.dumps(self.null_value_string) if self.null_value_string is not None else None]),
('time zone', _.tt[self.timezone]),
(_.i['shell command'], html.highlight_syntax(self.shell_command(), 'bash'))]


class WriteFile(sql._SQLCommand):
    """Writes the output of a SQL statement to a local file"""

    def __init__(self, dest_file_name: str, sql_statement: Union[Callable, str] = None, sql_file_name: str = None,
                 replace: Dict[str, str] = None, db_alias: str = None,
                 compression: Compression = Compression.NONE,
                 delimiter_char: str = None, header: bool = True,
                 format: formats.Format = formats.CsvFormat()) -> None:
        """
        Writes the output of a sql query to a file in a specific format.

        Args:
            dest_file_name: destination file name, relative to the configured data directory
            sql_statement: The statement to run as a string
            sql_file_name: The name of the file to run (relative to the directory of the parent pipeline)
            replace: A set of replacements to perform against the sql query `{'replace': 'with', ..}`
            db_alias: db on which the SQL statement shall run
            compression: compression of the written file; currently only Compression.NONE is supported
            delimiter_char: delimiter character in CSV to separate fields. Default: ','
            header: If a CSV header shall be added
            format: the file format in which the query output shall be written

        Raises:
            ValueError: when a compression other than Compression.NONE is requested
        """
        if compression != Compression.NONE:
            raise ValueError('Currently WriteFile only supports compression NONE')
        sql._SQLCommand.__init__(self, sql_statement, sql_file_name, replace)

        self.dest_file_name = dest_file_name
        self._db_alias = db_alias
        self.compression = compression
        self.header = header
        self.delimiter_char = delimiter_char
        self.format = format

    @property
    def db_alias(self) -> str:
        # fall back to the globally configured default when no explicit alias was given
        return self._db_alias or config.default_db_alias()

    def shell_command(self) -> str:
        # pipe the SQL statement through the db's copy-to-stdout command and redirect
        # the result into the destination file inside the data directory
        command = super().shell_command() \
            + ' | ' + mara_db.shell.copy_to_stdout_command(
                self.db_alias, header=self.header, footer=None, delimiter_char=self.delimiter_char,
                csv_format=None, pipe_format=self.format) + ' \\\n'
        return command \
            + f' > "{pathlib.Path(config.data_dir()) / self.dest_file_name}"'

    def html_doc_items(self) -> List[Tuple[str, str]]:
        """Renders the configuration of this command for the web UI"""
        return [('db', _.tt[self.db_alias])
                ] \
               + sql._SQLCommand.html_doc_items(self, self.db_alias) \
               + [('format', _.tt[self.format]),
                  ('destination file name', _.tt[self.dest_file_name]),
                  ('delimiter char', _.tt[self.delimiter_char]),
                  ('header', _.tt[str(self.header)]),
                  (_.i['shell command'], html.highlight_syntax(self.shell_command(), 'bash'))]
22 changes: 22 additions & 0 deletions tests/db_test_helper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
import sqlalchemy
from mara_db import dbs


def db_is_responsive(db: dbs.DB) -> bool:
    """Returns True when the DB is available on the given port, otherwise False"""
    engine = sqlalchemy.create_engine(db.sqlalchemy_url, pool_pre_ping=True)

    try:
        with engine.connect():
            return True
    except Exception:
        # narrowed from a bare `except:` so that KeyboardInterrupt/SystemExit
        # still abort the test run instead of being reported as "not responsive"
        return False
    finally:
        # release the pooled connections — the throwaway engine would otherwise leak them
        engine.dispose()


def db_replace_placeholders(db: dbs.DB, docker_ip: str, docker_port: int) -> dbs.DB:
    """Replaces the internal placeholders with the docker ip and docker port"""
    # 'DOCKER_IP' and -1 are the sentinel values used in tests/local_config.py
    host_is_placeholder = db.host == 'DOCKER_IP'
    port_is_placeholder = db.port == -1
    if host_is_placeholder:
        db.host = docker_ip
    if port_is_placeholder:
        db.port = docker_port
    return db
8 changes: 8 additions & 0 deletions tests/docker-compose.yml
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,11 @@ services:
POSTGRES_HOST_AUTH_METHOD: md5
ports:
- "5432"

mssql:
image: mcr.microsoft.com/mssql/server:2019-latest
environment:
- ACCEPT_EULA=Y
- SA_PASSWORD=YourStrong@Passw0rd
ports:
- "1433"
10 changes: 10 additions & 0 deletions tests/local_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
# This file contains secrets used by the tests

from mara_db import dbs

# supported placeholders
# host='DOCKER_IP' will be replaced with the ip address given from pytest-docker
# port=-1 will be replaced with the port given from pytest-docker

POSTGRES_DB = dbs.PostgreSQLDB(host='DOCKER_IP', port=-1, user="mara", password="mara", database="mara")
Copy link
Member

@jankatins jankatins Mar 7, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why not pick that from an env var and add a fixture to set the env var? Or just create this in the fixture, where you have all the information (and return the DB)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yee... I am not 100% sure about that. This is actually a copy from the test suite of the mara_db project. See local_config.py.example.

The main purpose that I defined it in a separate config file is that you have the option to simply activate / disable the tests for specific database engines. Especially for test against cloud services this is handy. I don't want to share the credentials for my cloud with everybody but at the same time share the option for those who want to test their changes against the cloud ;-)

Removing this and integrating it into the fixture makes sense for now, but maybe not in the future...

MSSQL_SQLCMD_DB = dbs.SqlcmdSQLServerDB(host='DOCKER_IP', port=-1, user='sa', password='YourStrong@Passw0rd', database='master', trust_server_certificate=True)
99 changes: 99 additions & 0 deletions tests/postgres/test_postgres.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
import os
import pathlib
import pytest
import typing as t

from mara_app.monkey_patch import patch
import mara_db.config
import mara_db.format
import mara_pipelines.config
from mara_pipelines.commands.files import WriteFile
from mara_pipelines.commands.sql import ExecuteSQL
from mara_pipelines.pipelines import Pipeline, Task
from mara_pipelines.ui.cli import run_pipeline

from tests.db_test_helper import db_is_responsive, db_replace_placeholders
from tests.local_config import POSTGRES_DB


if not POSTGRES_DB:
pytest.skip("skipping PostgreSQL tests: variable POSTGRES_DB not set", allow_module_level=True)


# Configuration of test pipeline
patch(mara_db.config.databases)(lambda: {'dwh': POSTGRES_DB})
patch(mara_pipelines.config.default_db_alias)(lambda: 'dwh')


@pytest.fixture(scope="session")
def postgres_db(docker_ip, docker_services):
    """Ensures that a PostgreSQL server is running on docker.

    Returns the `mara_db.dbs.DB` configuration with the docker host/port placeholders
    resolved. (The previous `t.Tuple[str, int]` return annotation was wrong — the
    function returns the DB object, not a (host, port) tuple.)
    """

    docker_port = docker_services.port_for("postgres", 5432)
    db = db_replace_placeholders(POSTGRES_DB, docker_ip, docker_port)

    # here we need to wait until the PostgreSQL port is available.
    docker_services.wait_until_responsive(
        timeout=30.0, pause=0.1, check=lambda: db_is_responsive(db)
    )

    return db


@pytest.mark.dependency()
def test_postgres_command_WriteFile(postgres_db):
    """Runs a pipeline that writes a table to a CSV and a TSV file and checks both files exist."""

    pipeline = Pipeline(
        id='test_postgres_command_write_file',
        description="",
        base_path=pathlib.Path(__file__).parent)

    # NOTE: the original DDL used SQL Server syntax (`IDENTITY(1,1)`) against PostgreSQL,
    # misspelled the column names (LongTest vs. the LongText used in the INSERT), lacked a
    # comma between the column definitions and ended the VALUES list with a trailing comma.
    pipeline.add(
        Task(id='initial_ddl',
             description="",
             commands=[ExecuteSQL("""
DROP TABLE IF EXISTS "test_postgres_command_WriteFile";

CREATE TABLE "test_postgres_command_WriteFile"
(
    Id INT GENERATED ALWAYS AS IDENTITY,
    LongText1 TEXT,
    LongText2 TEXT
);
INSERT INTO "test_postgres_command_WriteFile" (
    LongText1, LongText2
) VALUES
('Hello', 'World!'),
('He lo', ' orld! '),
('Hello\t', ', World! ');
""")]))

    pipeline.add(
        Task(id='write_file_csv',
             description="Write content of table to file",
             commands=[WriteFile(dest_file_name='_tmp-write-file.csv',
                                 sql_statement="""SELECT * FROM "test_postgres_command_WriteFile";""",
                                 delimiter_char=',', header=False)]))

    pipeline.add(
        Task(id='write_file_tsv',
             description="Write content of table to file",
             commands=[WriteFile(dest_file_name='_tmp-write-file.tsv',
                                 sql_statement="""SELECT * FROM "test_postgres_command_WriteFile";""",
                                 delimiter_char='\t', header=False)]))

    assert run_pipeline(pipeline)

    # both output files must have been created; delete them afterwards to keep the test re-runnable
    files = [
        str((pipeline.base_path / '_tmp-write-file.csv').absolute()),
        str((pipeline.base_path / '_tmp-write-file.tsv').absolute())
    ]

    file_not_found = []
    for file in files:
        if not os.path.exists(file):
            file_not_found.append(file)
        else:
            os.remove(file)

    assert not file_not_found
5 changes: 5 additions & 0 deletions tests/postgres/test_postgres_ddl.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
-- Test fixture table for the PostgreSQL test pipeline
CREATE TABLE IF NOT EXISTS names
(
    id INT,
    name TEXT
);