Skip to content

Commit

Permalink
Support all arguments in MySQL Tasks' run methods and support SSL con…
Browse files Browse the repository at this point in the history
…figuration (PrefectHQ#4545)

* update mysql task to have full options from task run method and also support ssl options

* make existing init position args into keyword args

* run black formatter on mysql task

* fix indentation on mysQLFetch args docstring

* fix documentation, remove arg docs that arent real

* default run kwargs to None since theyll get loaded from init if available

* update ssl docstring to specify minimum options required for an ssl connection
  • Loading branch information
tchoedak authored Jun 1, 2021
1 parent 198fc27 commit 438413e
Showing 1 changed file with 87 additions and 21 deletions.
108 changes: 87 additions & 21 deletions src/prefect/tasks/mysql/mysql.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,20 +20,23 @@ class MySQLExecute(Task):
- query (str, optional): query to execute against database
- commit (bool, optional): set to True to commit transaction, defaults to false
- charset (str, optional): charset you want to use (defaults to utf8mb4)
- ssl (dict, optional): A dict of arguments similar to mysql_ssl_set()’s
parameters used for establishing encrypted connections using SSL
- **kwargs (Any, optional): additional keyword arguments to pass to the
Task constructor
"""

def __init__(
self,
db_name: str,
user: str,
password: str,
host: str,
db_name: str = None,
user: str = None,
password: str = None,
host: str = None,
port: int = 3306,
query: str = None,
commit: bool = False,
charset: str = "utf8mb4",
ssl: dict = None,
**kwargs: Any,
):
self.db_name = db_name
Expand All @@ -44,17 +47,48 @@ def __init__(
self.query = query
self.commit = commit
self.charset = charset
self.ssl = ssl
super().__init__(**kwargs)

@defaults_from_attrs("query", "commit", "charset")
def run(self, query: str, commit: bool = False, charset: str = "utf8mb4") -> int:
@defaults_from_attrs(
"db_name",
"user",
"password",
"host",
"port",
"query",
"commit",
"charset",
"ssl",
)
def run(
self,
db_name: str = None,
user: str = None,
password: str = None,
host: str = None,
port: int = None,
query: str = None,
commit: bool = None,
charset: str = None,
ssl: dict = None,
) -> int:
"""
Task run method. Executes a query against MySQL database.
Args:
- db_name (str): name of MySQL database
- user (str): user name used to authenticate
- password (str): password used to authenticate
- host (str): database host address
- port (int, optional): port used to connect to MySQL database, defaults to 3307
if not provided
- query (str, optional): query to execute against database
- commit (bool, optional): set to True to commit transaction, defaults to False
- charset (str, optional): charset of the query, defaults to "utf8mb4"
- commit (bool, optional): set to True to commit transaction, defaults to false
- charset (str, optional): charset you want to use (defaults to "utf8mb4")
- ssl (dict, optional): A dict of arguments similar to mysql_ssl_set()’s
parameters used for establishing encrypted connections using SSL. To connect
with SSL, at least `ssl_ca`, `ssl_cert`, and `ssl_key` must be specified.
Returns:
- executed (int): number of affected rows
Expand All @@ -72,6 +106,7 @@ def run(self, query: str, commit: bool = False, charset: str = "utf8mb4") -> int
db=self.db_name,
charset=self.charset,
port=self.port,
ssl=ssl,
)

try:
Expand Down Expand Up @@ -111,23 +146,27 @@ class MySQLFetch(Task):
- cursor_type (Union[str, Callable], optional): The cursor type to use.
Can be `'cursor'` (the default), `'dictcursor'`, `'sscursor'`, `'ssdictcursor'`,
or a full cursor class.
- ssl (dict, optional): A dict of arguments similar to mysql_ssl_set()’s
parameters used for establishing encrypted connections using SSL. To connect
with SSL, at least `ssl_ca`, `ssl_cert`, and `ssl_key` must be specified.
- **kwargs (Any, optional): additional keyword arguments to pass to the
Task constructor
"""

def __init__(
self,
db_name: str,
user: str,
password: str,
host: str,
db_name: str = None,
user: str = None,
password: str = None,
host: str = None,
port: int = 3306,
fetch: str = "one",
fetch_count: int = 10,
query: str = None,
commit: bool = False,
charset: str = "utf8mb4",
cursor_type: Union[str, Callable] = "cursor",
ssl: dict = None,
**kwargs: Any,
):
self.db_name = db_name
Expand All @@ -141,34 +180,60 @@ def __init__(
self.commit = commit
self.charset = charset
self.cursor_type = cursor_type
self.ssl = ssl
super().__init__(**kwargs)

@defaults_from_attrs(
"fetch", "fetch_count", "query", "commit", "charset", "cursor_type"
"db_name",
"user",
"password",
"host",
"port",
"fetch",
"fetch_count",
"query",
"commit",
"charset",
"cursor_type",
"ssl",
)
def run(
self,
query: str,
fetch: str = "one",
fetch_count: int = 10,
commit: bool = False,
charset: str = "utf8mb4",
cursor_type: Union[str, Callable] = "cursor",
db_name: str = None,
user: str = None,
password: str = None,
host: str = None,
port: int = None,
fetch: str = None,
fetch_count: int = None,
query: str = None,
commit: bool = None,
charset: str = None,
cursor_type: Union[str, Callable] = None,
ssl: dict = None,
) -> Any:
"""
Task run method. Executes a query against MySQL database and fetches results.
Args:
- db_name (str): name of MySQL database
- user (str): user name used to authenticate
- password (str): password used to authenticate
- host (str): database host address
- port (int, optional): port used to connect to MySQL database, defaults to 3307 if not
provided
- fetch (str, optional): one of "one" "many" or "all", used to determine how many
results to fetch from executed query
- fetch_count (int, optional): if fetch = 'many', determines the number of results
to fetch, defaults to 10
- fetch_count (int, optional): if fetch = 'many', determines the number of results to
fetch, defaults to 10
- query (str, optional): query to execute against database
- commit (bool, optional): set to True to commit transaction, defaults to false
- charset (str, optional): charset of the query, defaults to "utf8mb4"
- cursor_type (Union[str, Callable], optional): The cursor type to use.
Can be `'cursor'` (the default), `'dictcursor'`, `'sscursor'`, `'ssdictcursor'`,
or a full cursor class.
- ssl (dict, optional): A dict of arguments similar to mysql_ssl_set()’s
parameters used for establishing encrypted connections using SSL
Returns:
- results (tuple or list of tuples): records from provided query
Expand Down Expand Up @@ -212,6 +277,7 @@ def run(
charset=self.charset,
port=self.port,
cursorclass=cursor_class,
ssl=ssl,
)

try:
Expand Down

0 comments on commit 438413e

Please sign in to comment.