diff --git a/airflow/config_templates/config.yml b/airflow/config_templates/config.yml index 8e65ef78ab8fe..a70ddecc68150 100644 --- a/airflow/config_templates/config.yml +++ b/airflow/config_templates/config.yml @@ -1255,6 +1255,18 @@ type: string example: ~ default: "airflow@example.com" + - name: smtp_timeout + description: ~ + version_added: ~ + type: int + example: ~ + default: "30" + - name: smtp_retry_limit + description: ~ + version_added: ~ + type: int + example: ~ + default: "5" - name: sentry description: | Sentry (https://docs.sentry.io) integration. Here you can supply diff --git a/airflow/config_templates/default_airflow.cfg b/airflow/config_templates/default_airflow.cfg index 1d670d4c4ba61..3cf316b285d6c 100644 --- a/airflow/config_templates/default_airflow.cfg +++ b/airflow/config_templates/default_airflow.cfg @@ -611,6 +611,8 @@ smtp_ssl = False # smtp_password = smtp_port = 25 smtp_mail_from = airflow@example.com +smtp_timeout = 30 +smtp_retry_limit = 5 [sentry] diff --git a/airflow/config_templates/default_test.cfg b/airflow/config_templates/default_test.cfg index 824565d70faf8..767176d7fdade 100644 --- a/airflow/config_templates/default_test.cfg +++ b/airflow/config_templates/default_test.cfg @@ -87,6 +87,8 @@ smtp_user = airflow smtp_port = 25 smtp_password = airflow smtp_mail_from = airflow@example.com +smtp_retry_limit = 5 +smtp_timeout = 30 [celery] celery_app_name = airflow.executors.celery_executor diff --git a/airflow/utils/email.py b/airflow/utils/email.py index 17d3b6e97bd7c..8e4359bd0ebf9 100644 --- a/airflow/utils/email.py +++ b/airflow/utils/email.py @@ -168,6 +168,8 @@ def send_mime_email(e_from: str, e_to: List[str], mime_msg: MIMEMultipart, dryru smtp_port = conf.getint('smtp', 'SMTP_PORT') smtp_starttls = conf.getboolean('smtp', 'SMTP_STARTTLS') smtp_ssl = conf.getboolean('smtp', 'SMTP_SSL') + smtp_retry_limit = conf.getint('smtp', 'SMTP_RETRY_LIMIT') + smtp_timeout = conf.getint('smtp', 'SMTP_TIMEOUT') smtp_user = None smtp_password = None @@ -178,14 +180,23 @@ def send_mime_email(e_from: str, e_to: List[str], mime_msg: MIMEMultipart, dryru log.debug("No user/password found for SMTP, so logging in with no authentication.") if not dryrun: - conn = smtplib.SMTP_SSL(smtp_host, smtp_port) if smtp_ssl else smtplib.SMTP(smtp_host, smtp_port) - if smtp_starttls: - conn.starttls() - if smtp_user and smtp_password: - conn.login(smtp_user, smtp_password) - log.info("Sent an alert email to %s", e_to) - conn.sendmail(e_from, e_to, mime_msg.as_string()) - conn.quit() + for attempt in range(1, smtp_retry_limit + 1): + log.info("Email alerting: attempt %s", str(attempt)) + try: + conn = _get_smtp_connection(smtp_host, smtp_port, smtp_timeout, smtp_ssl) + except smtplib.SMTPServerDisconnected: + if attempt < smtp_retry_limit: + continue + raise + + if smtp_starttls: + conn.starttls() + if smtp_user and smtp_password: + conn.login(smtp_user, smtp_password) + log.info("Sent an alert email to %s", e_to) + conn.sendmail(e_from, e_to, mime_msg.as_string()) + conn.quit() + break def get_email_address_list(addresses: Union[str, Iterable[str]]) -> List[str]: @@ -202,6 +213,14 @@ def get_email_address_list(addresses: Union[str, Iterable[str]]) -> List[str]: raise TypeError(f"Unexpected argument type: Received '{received_type}'.") +def _get_smtp_connection(host: str, port: int, timeout: int, with_ssl: bool) -> smtplib.SMTP: + return ( + smtplib.SMTP_SSL(host=host, port=port, timeout=timeout) + if with_ssl + else smtplib.SMTP(host=host, port=port, timeout=timeout) + ) + + def _get_email_list_from_str(addresses: str) -> List[str]: delimiters = [",", ";"] for delimiter in delimiters: diff --git a/tests/utils/test_email.py b/tests/utils/test_email.py index 8081fc85a74cf..8966e975d3647 100644 --- a/tests/utils/test_email.py +++ b/tests/utils/test_email.py @@ -21,6 +21,7 @@ from email.mime.application import MIMEApplication from email.mime.multipart import MIMEMultipart from email.mime.text import MIMEText +from smtplib import SMTPServerDisconnected from unittest import mock from airflow import utils @@ -118,7 +119,6 @@ def test_build_mime_message(self): self.assertEqual(msg['To'], ','.join(recipients)) -@conf_vars({('smtp', 'SMTP_SSL'): 'False'}) class TestEmailSmtp(unittest.TestCase): @mock.patch('airflow.utils.email.send_mime_email') def test_send_smtp(self, mock_send_mime): @@ -127,10 +127,10 @@ def test_send_smtp(self, mock_send_mime): attachment.seek(0) utils.email.send_email_smtp('to', 'subject', 'content', files=[attachment.name]) self.assertTrue(mock_send_mime.called) - call_args = mock_send_mime.call_args[0] - self.assertEqual(conf.get('smtp', 'SMTP_MAIL_FROM'), call_args[0]) - self.assertEqual(['to'], call_args[1]) - msg = call_args[2] + _, call_args = mock_send_mime.call_args + self.assertEqual(conf.get('smtp', 'SMTP_MAIL_FROM'), call_args['e_from']) + self.assertEqual(['to'], call_args['e_to']) + msg = call_args['mime_msg'] self.assertEqual('subject', msg['Subject']) self.assertEqual(conf.get('smtp', 'SMTP_MAIL_FROM'), msg['From']) self.assertEqual(2, len(msg.get_payload())) @@ -143,8 +143,8 @@ def test_send_smtp(self, mock_send_mime): def test_send_smtp_with_multibyte_content(self, mock_send_mime): utils.email.send_email_smtp('to', 'subject', '🔥', mime_charset='utf-8') self.assertTrue(mock_send_mime.called) - call_args = mock_send_mime.call_args[0] - msg = call_args[2] + _, call_args = mock_send_mime.call_args + msg = call_args['mime_msg'] mimetext = MIMEText('🔥', 'mixed', 'utf-8') self.assertEqual(mimetext.get_payload(), msg.get_payload()[0].get_payload()) @@ -155,10 +155,10 @@ def test_send_bcc_smtp(self, mock_send_mime): attachment.seek(0) utils.email.send_email_smtp('to', 'subject', 'content', files=[attachment.name], cc='cc', bcc='bcc') self.assertTrue(mock_send_mime.called) - call_args = mock_send_mime.call_args[0] - self.assertEqual(conf.get('smtp', 'SMTP_MAIL_FROM'), call_args[0]) - self.assertEqual(['to', 'cc', 'bcc'], call_args[1]) - msg = call_args[2] + _, call_args = mock_send_mime.call_args + self.assertEqual(conf.get('smtp', 'SMTP_MAIL_FROM'), call_args['e_from']) + self.assertEqual(['to', 'cc', 'bcc'], call_args['e_to']) + msg = call_args['mime_msg'] self.assertEqual('subject', msg['Subject']) self.assertEqual(conf.get('smtp', 'SMTP_MAIL_FROM'), msg['From']) self.assertEqual(2, len(msg.get_payload())) @@ -173,13 +173,14 @@ def test_send_bcc_smtp(self, mock_send_mime): @mock.patch('smtplib.SMTP') def test_send_mime(self, mock_smtp, mock_smtp_ssl): mock_smtp.return_value = mock.Mock() - mock_smtp_ssl.return_value = mock.Mock() msg = MIMEMultipart() utils.email.send_mime_email('from', 'to', msg, dryrun=False) mock_smtp.assert_called_once_with( - conf.get('smtp', 'SMTP_HOST'), - conf.getint('smtp', 'SMTP_PORT'), + host=conf.get('smtp', 'SMTP_HOST'), + port=conf.getint('smtp', 'SMTP_PORT'), + timeout=conf.getint('smtp', 'SMTP_TIMEOUT'), ) + self.assertFalse(mock_smtp_ssl.called) self.assertTrue(mock_smtp.return_value.starttls.called) mock_smtp.return_value.login.assert_called_once_with( conf.get('smtp', 'SMTP_USER'), @@ -191,21 +192,20 @@ def test_send_mime(self, mock_smtp, mock_smtp_ssl): @mock.patch('smtplib.SMTP_SSL') @mock.patch('smtplib.SMTP') def test_send_mime_ssl(self, mock_smtp, mock_smtp_ssl): - mock_smtp.return_value = mock.Mock() mock_smtp_ssl.return_value = mock.Mock() with conf_vars({('smtp', 'smtp_ssl'): 'True'}): utils.email.send_mime_email('from', 'to', MIMEMultipart(), dryrun=False) self.assertFalse(mock_smtp.called) mock_smtp_ssl.assert_called_once_with( - conf.get('smtp', 'SMTP_HOST'), - conf.getint('smtp', 'SMTP_PORT'), + host=conf.get('smtp', 'SMTP_HOST'), + port=conf.getint('smtp', 'SMTP_PORT'), + timeout=conf.getint('smtp', 'SMTP_TIMEOUT'), ) @mock.patch('smtplib.SMTP_SSL') @mock.patch('smtplib.SMTP') def test_send_mime_noauth(self, mock_smtp, mock_smtp_ssl): mock_smtp.return_value = mock.Mock() - mock_smtp_ssl.return_value = mock.Mock() with conf_vars( { ('smtp', 'smtp_user'): None, @@ -215,8 +215,9 @@ def test_send_mime_noauth(self, mock_smtp, mock_smtp_ssl): utils.email.send_mime_email('from', 'to', MIMEMultipart(), dryrun=False) self.assertFalse(mock_smtp_ssl.called) mock_smtp.assert_called_once_with( - conf.get('smtp', 'SMTP_HOST'), - conf.getint('smtp', 'SMTP_PORT'), + host=conf.get('smtp', 'SMTP_HOST'), + port=conf.getint('smtp', 'SMTP_PORT'), + timeout=conf.getint('smtp', 'SMTP_TIMEOUT'), ) self.assertFalse(mock_smtp.login.called) @@ -226,3 +227,89 @@ def test_send_mime_dryrun(self, mock_smtp, mock_smtp_ssl): utils.email.send_mime_email('from', 'to', MIMEMultipart(), dryrun=True) self.assertFalse(mock_smtp.called) self.assertFalse(mock_smtp_ssl.called) + + @mock.patch('smtplib.SMTP_SSL') + @mock.patch('smtplib.SMTP') + def test_send_mime_complete_failure(self, mock_smtp: mock, mock_smtp_ssl): + mock_smtp.side_effect = SMTPServerDisconnected() + msg = MIMEMultipart() + with self.assertRaises(SMTPServerDisconnected): + utils.email.send_mime_email('from', 'to', msg, dryrun=False) + + mock_smtp.assert_any_call( + host=conf.get('smtp', 'SMTP_HOST'), + port=conf.getint('smtp', 'SMTP_PORT'), + timeout=conf.getint('smtp', 'SMTP_TIMEOUT'), + ) + self.assertEqual(mock_smtp.call_count, conf.getint('smtp', 'SMTP_RETRY_LIMIT')) + self.assertFalse(mock_smtp_ssl.called) + self.assertFalse(mock_smtp.return_value.starttls.called) + self.assertFalse(mock_smtp.return_value.login.called) + self.assertFalse(mock_smtp.return_value.sendmail.called) + self.assertFalse(mock_smtp.return_value.quit.called) + + @mock.patch('smtplib.SMTP_SSL') + @mock.patch('smtplib.SMTP') + def test_send_mime_ssl_complete_failure(self, mock_smtp, mock_smtp_ssl): + mock_smtp_ssl.side_effect = SMTPServerDisconnected() + msg = MIMEMultipart() + with conf_vars({('smtp', 'smtp_ssl'): 'True'}): + with self.assertRaises(SMTPServerDisconnected): + utils.email.send_mime_email('from', 'to', msg, dryrun=False) + + mock_smtp_ssl.assert_any_call( + host=conf.get('smtp', 'SMTP_HOST'), + port=conf.getint('smtp', 'SMTP_PORT'), + timeout=conf.getint('smtp', 'SMTP_TIMEOUT'), + ) + self.assertEqual(mock_smtp_ssl.call_count, conf.getint('smtp', 'SMTP_RETRY_LIMIT')) + self.assertFalse(mock_smtp.called) + self.assertFalse(mock_smtp_ssl.return_value.starttls.called) + self.assertFalse(mock_smtp_ssl.return_value.login.called) + self.assertFalse(mock_smtp_ssl.return_value.sendmail.called) + self.assertFalse(mock_smtp_ssl.return_value.quit.called) + + @mock.patch('smtplib.SMTP_SSL') + @mock.patch('smtplib.SMTP') + def test_send_mime_custom_timeout_retrylimit(self, mock_smtp, mock_smtp_ssl): + mock_smtp.side_effect = SMTPServerDisconnected() + msg = MIMEMultipart() + + custom_retry_limit = 10 + custom_timeout = 60 + + with conf_vars( + { + ('smtp', 'smtp_retry_limit'): str(custom_retry_limit), + ('smtp', 'smtp_timeout'): str(custom_timeout), + } + ): + with self.assertRaises(SMTPServerDisconnected): + utils.email.send_mime_email('from', 'to', msg, dryrun=False) + + mock_smtp.assert_any_call( + host=conf.get('smtp', 'SMTP_HOST'), port=conf.getint('smtp', 'SMTP_PORT'), timeout=custom_timeout + ) + self.assertFalse(mock_smtp_ssl.called) + self.assertEqual(mock_smtp.call_count, 10) + + @mock.patch('smtplib.SMTP_SSL') + @mock.patch('smtplib.SMTP') + def test_send_mime_partial_failure(self, mock_smtp, mock_smtp_ssl): + final_mock = mock.Mock() + side_effects = [SMTPServerDisconnected(), SMTPServerDisconnected(), final_mock] + mock_smtp.side_effect = side_effects + msg = MIMEMultipart() + + utils.email.send_mime_email('from', 'to', msg, dryrun=False) + + mock_smtp.assert_any_call( + host=conf.get('smtp', 'SMTP_HOST'), + port=conf.getint('smtp', 'SMTP_PORT'), + timeout=conf.getint('smtp', 'SMTP_TIMEOUT'), + ) + self.assertEqual(mock_smtp.call_count, side_effects.index(final_mock) + 1) + self.assertFalse(mock_smtp_ssl.called) + self.assertTrue(final_mock.starttls.called) + final_mock.sendmail.assert_called_once_with('from', 'to', msg.as_string()) + self.assertTrue(final_mock.quit.called)