Skip to content

Commit

Permalink
Refactor vertica_to_mysql to make it more 'mypy' friendly
Browse files Browse the repository at this point in the history
Part of #19891

MyPy was confused by the logic in this method (and humans could
be too) because there were some implicit relations between bulk_load
and tmpfile. This refactor separates the bulk-load and non-bulk-load
paths (extracting the common parts) and makes them more obvious.

Thanks MyPy for flagging this one.
  • Loading branch information
potiuk committed Jan 5, 2022
1 parent e8b5ab9 commit 9020f23
Showing 1 changed file with 51 additions and 42 deletions.
93 changes: 51 additions & 42 deletions airflow/providers/mysql/transfers/vertica_to_mysql.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,63 +94,72 @@ def execute(self, context: 'Context'):
vertica = VerticaHook(vertica_conn_id=self.vertica_conn_id)
mysql = MySqlHook(mysql_conn_id=self.mysql_conn_id)

tmpfile = None
result = None
if self.bulk_load:
self._bulk_load_transfer(mysql, vertica)
else:
self._non_bulk_load_transfer(mysql, vertica)

selected_columns = []
if self.mysql_postoperator:
self.log.info("Running MySQL postoperator...")
mysql.run(self.mysql_postoperator)

count = 0
self.log.info("Done")

def _non_bulk_load_transfer(self, mysql, vertica):
with closing(vertica.get_conn()) as conn:
with closing(conn.cursor()) as cursor:
cursor.execute(self.sql)
selected_columns = [d.name for d in cursor.description]
self.log.info("Selecting rows from Vertica...")
self.log.info(self.sql)

if self.bulk_load:
with NamedTemporaryFile("w") as tmpfile:
self.log.info("Selecting rows from Vertica to local file %s...", tmpfile.name)
self.log.info(self.sql)
result = cursor.fetchall()
count = len(result)

csv_writer = csv.writer(tmpfile, delimiter='\t', encoding='utf-8')
for row in cursor.iterate():
csv_writer.writerow(row)
count += 1
self.log.info("Selected rows from Vertica %s", count)
self._run_preoperator(mysql)
try:
self.log.info("Inserting rows into MySQL...")
mysql.insert_rows(table=self.mysql_table, rows=result, target_fields=selected_columns)
self.log.info("Inserted rows into MySQL %s", count)
except (MySQLdb.Error, MySQLdb.Warning):
self.log.info("Inserted rows into MySQL 0")
raise

tmpfile.flush()
else:
self.log.info("Selecting rows from Vertica...")
def _bulk_load_transfer(self, mysql, vertica):
count = 0
with closing(vertica.get_conn()) as conn:
with closing(conn.cursor()) as cursor:
cursor.execute(self.sql)
selected_columns = [d.name for d in cursor.description]
with NamedTemporaryFile("w") as tmpfile:
self.log.info("Selecting rows from Vertica to local file %s...", tmpfile.name)
self.log.info(self.sql)

result = cursor.fetchall()
count = len(result)

self.log.info("Selected rows from Vertica %s", count)

if self.mysql_preoperator:
self.log.info("Running MySQL preoperator...")
mysql.run(self.mysql_preoperator)
csv_writer = csv.writer(tmpfile, delimiter='\t', encoding='utf-8')
for row in cursor.iterate():
csv_writer.writerow(row)
count += 1

tmpfile.flush()
self._run_preoperator(mysql)
try:
if self.bulk_load:
self.log.info("Bulk inserting rows into MySQL...")
with closing(mysql.get_conn()) as conn:
with closing(conn.cursor()) as cursor:
cursor.execute(
f"LOAD DATA LOCAL INFILE '{tmpfile.name}' "
f"INTO TABLE {self.mysql_table} "
f"LINES TERMINATED BY '\r\n' ({', '.join(selected_columns)})"
)
conn.commit()
tmpfile.close()
else:
self.log.info("Inserting rows into MySQL...")
mysql.insert_rows(table=self.mysql_table, rows=result, target_fields=selected_columns)
self.log.info("Bulk inserting rows into MySQL...")
with closing(mysql.get_conn()) as conn:
with closing(conn.cursor()) as cursor:
cursor.execute(
f"LOAD DATA LOCAL INFILE '{tmpfile.name}' "
f"INTO TABLE {self.mysql_table} "
f"LINES TERMINATED BY '\r\n' ({', '.join(selected_columns)})"
)
conn.commit()
tmpfile.close()
self.log.info("Inserted rows into MySQL %s", count)
except (MySQLdb.Error, MySQLdb.Warning):
self.log.info("Inserted rows into MySQL 0")
raise

if self.mysql_postoperator:
self.log.info("Running MySQL postoperator...")
mysql.run(self.mysql_postoperator)

self.log.info("Done")
def _run_preoperator(self, mysql):
if self.mysql_preoperator:
self.log.info("Running MySQL preoperator...")
mysql.run(self.mysql_preoperator)

0 comments on commit 9020f23

Please sign in to comment.