# Methods of the operator defined in
# airflow/providers/mysql/transfers/vertica_to_mysql.py (post-change version of
# the patch). The enclosing class header is not part of this excerpt, so the
# methods are written without class-level indentation.


def execute(self, context: 'Context'):
    """
    Copy the result of ``self.sql`` from Vertica into ``self.mysql_table``.

    Dispatches to the bulk-load path or the row-insert path depending on
    ``self.bulk_load``, then runs ``self.mysql_postoperator`` if one is set.

    :param context: task execution context; not read by this method.
    """
    vertica = VerticaHook(vertica_conn_id=self.vertica_conn_id)
    mysql = MySqlHook(mysql_conn_id=self.mysql_conn_id)

    if self.bulk_load:
        self._bulk_load_transfer(mysql, vertica)
    else:
        self._non_bulk_load_transfer(mysql, vertica)

    if self.mysql_postoperator:
        self.log.info("Running MySQL postoperator...")
        mysql.run(self.mysql_postoperator)

    self.log.info("Done")


def _non_bulk_load_transfer(self, mysql, vertica):
    """
    Fetch all selected rows into memory and insert them via ``insert_rows``.

    :param mysql: hook used to run the preoperator and insert the rows.
    :param vertica: hook providing the Vertica connection to select from.
    :raises MySQLdb.Error, MySQLdb.Warning: re-raised after logging when the
        insert fails.
    """
    with closing(vertica.get_conn()) as conn:
        with closing(conn.cursor()) as cursor:
            cursor.execute(self.sql)
            # Column names of the result set become the MySQL target fields.
            selected_columns = [d.name for d in cursor.description]
            self.log.info("Selecting rows from Vertica...")
            self.log.info(self.sql)

            result = cursor.fetchall()
            count = len(result)

            self.log.info("Selected rows from Vertica %s", count)

    self._run_preoperator(mysql)
    try:
        self.log.info("Inserting rows into MySQL...")
        mysql.insert_rows(table=self.mysql_table, rows=result, target_fields=selected_columns)
        self.log.info("Inserted rows into MySQL %s", count)
    except (MySQLdb.Error, MySQLdb.Warning):
        self.log.info("Inserted rows into MySQL 0")
        raise


def _bulk_load_transfer(self, mysql, vertica):
    """
    Stream the selected rows to a temporary tab-separated file and load it
    into MySQL with ``LOAD DATA LOCAL INFILE``.

    :param mysql: hook used to run the preoperator and the bulk load.
    :param vertica: hook providing the Vertica connection to select from.
    :raises MySQLdb.Error, MySQLdb.Warning: re-raised after logging when the
        load fails.
    """
    count = 0
    # The temporary file is removed as soon as it is closed, so it has to stay
    # open until LOAD DATA has read it; everything below runs inside this block.
    # The encoding belongs on the file object (csv.writer accepts no
    # ``encoding`` argument), and newline="" stops the text layer from
    # rewriting the "\r\n" row terminators that csv emits.
    with NamedTemporaryFile("w", encoding="utf-8", newline="") as tmpfile:
        with closing(vertica.get_conn()) as conn:
            with closing(conn.cursor()) as cursor:
                cursor.execute(self.sql)
                # Column names of the result set become the MySQL column list.
                selected_columns = [d.name for d in cursor.description]
                self.log.info("Selecting rows from Vertica to local file %s...", tmpfile.name)
                self.log.info(self.sql)

                # NOTE(review): csv's default dialect quotes fields containing
                # tabs, quotes or newlines, while the LOAD DATA statement below
                # declares no ENCLOSED BY - confirm such values load as intended.
                csv_writer = csv.writer(tmpfile, delimiter='\t')
                for row in cursor.iterate():
                    csv_writer.writerow(row)
                    count += 1

        # Push buffered rows to disk before MySQL opens the file by name.
        tmpfile.flush()

        self._run_preoperator(mysql)
        try:
            self.log.info("Bulk inserting rows into MySQL...")
            with closing(mysql.get_conn()) as conn:
                with closing(conn.cursor()) as cursor:
                    cursor.execute(
                        f"LOAD DATA LOCAL INFILE '{tmpfile.name}' "
                        f"INTO TABLE {self.mysql_table} "
                        f"LINES TERMINATED BY '\r\n' ({', '.join(selected_columns)})"
                    )
                    conn.commit()
            self.log.info("Inserted rows into MySQL %s", count)
        except (MySQLdb.Error, MySQLdb.Warning):
            self.log.info("Inserted rows into MySQL 0")
            raise


def _run_preoperator(self, mysql):
    """
    Run ``self.mysql_preoperator`` through ``mysql.run`` when one is set.

    :param mysql: hook used to execute the preoperator statement.
    """
    if self.mysql_preoperator:
        self.log.info("Running MySQL preoperator...")
        mysql.run(self.mysql_preoperator)