Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding fnmatch type regex to SFTPSensor #24084

Merged
19 changes: 19 additions & 0 deletions airflow/providers/sftp/hooks/sftp.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import datetime
import stat
import warnings
from fnmatch import fnmatch
from typing import Any, Dict, List, Optional, Tuple

import pysftp
Expand Down Expand Up @@ -329,3 +330,21 @@ def test_connection(self) -> Tuple[bool, str]:
return True, "Connection successfully tested"
except Exception as e:
return False, str(e)

def get_file_by_pattern(self, path: str, fnmatch_pattern: str) -> str:
    """
    Return the first entry of ``path`` that matches an fnmatch-style pattern.

    Entries are tested in the order ``self.list_directory(path)`` yields them, and the
    matching entry is returned exactly as it was listed.

    :param path: path to be checked
    :param fnmatch_pattern: The pattern that will be matched with `fnmatch`
    :return: string containing the first found file, or an empty string if none matched
    """
    # NOTE(review): presumably list_directory yields bare names rather than full paths;
    # if so, callers must join the result with ``path`` themselves -- confirm.
    for file_name in self.list_directory(path):
        if fnmatch(file_name, fnmatch_pattern):
            return file_name

    return ""
18 changes: 15 additions & 3 deletions airflow/providers/sftp/sensors/sftp.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ class SFTPSensor(BaseSensorOperator):
Waits for a file or directory to be present on SFTP.

:param path: Remote file or directory path
:param file_pattern: The pattern that will be used to match the file (fnmatch format, i.e. shell-style wildcards, not a regular expression)
potiuk marked this conversation as resolved.
Show resolved Hide resolved
:param sftp_conn_id: The connection to run the sensor against
:param newer_than: DateTime for which the file or file path should be newer than, comparison is inclusive
"""
Expand All @@ -47,22 +48,33 @@ def __init__(
self,
*,
path: str,
file_pattern: str = "",
newer_than: Optional[datetime] = None,
sftp_conn_id: str = 'sftp_default',
**kwargs,
) -> None:
super().__init__(**kwargs)
self.path = path
self.file_pattern = file_pattern
self.hook: Optional[SFTPHook] = None
self.sftp_conn_id = sftp_conn_id
self.newer_than: Optional[datetime] = newer_than
self.actual_file_to_check = self.path

def poke(self, context: 'Context') -> bool:
self.hook = SFTPHook(self.sftp_conn_id)
self.log.info('Poking for %s', self.path)
self.log.info(f"Poking for {self.path}, with pattern {self.file_pattern}")

if self.file_pattern:
file_from_pattern = self.hook.get_file_by_pattern(self.path, self.file_pattern)
if file_from_pattern:
self.actual_file_to_check = file_from_pattern
else:
return False

try:
mod_time = self.hook.get_mod_time(self.path)
self.log.info('Found File %s last modified: %s', str(self.path), str(mod_time))
mod_time = self.hook.get_mod_time(self.actual_file_to_check)
self.log.info('Found File %s last modified: %s', str(self.actual_file_to_check), str(mod_time))
except OSError as e:
if e.errno != SFTP_NO_SUCH_FILE:
raise e
Expand Down
38 changes: 35 additions & 3 deletions tests/providers/sftp/hooks/test_sftp.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@ def generate_host_key(pkey: paramiko.PKey):
# Names of the fixture directories and files that setUp creates under TMP_PATH.
TMP_DIR_FOR_TESTS = 'tests_sftp_hook_dir'
SUB_DIR = "sub_dir"  # created inside TMP_DIR_FOR_TESTS
TMP_FILE_FOR_TESTS = 'test_file.txt'
ANOTHER_FILE_FOR_TESTS = 'test_file_1.txt'  # second .txt fixture, used by the pattern-matching tests
LOG_FILE_FOR_TESTS = 'test_log.log'  # the only .log fixture

SFTP_CONNECTION_USER = "root"

Expand All @@ -60,13 +62,18 @@ def update_connection(self, login, session=None):
session.commit()
return old_login

def _create_additional_test_file(self, file_name):
    """Append the text ``'Test file'`` to ``file_name`` inside ``TMP_PATH``, creating it if needed."""
    target = os.path.join(TMP_PATH, file_name)
    with open(target, 'a') as handle:
        handle.write('Test file')

def setUp(self):
    """
    Switch the connection login to ``SFTP_CONNECTION_USER`` and build the fixture tree.

    The previous login is kept in ``self.old_login`` so it can be restored afterwards.
    """
    self.old_login = self.update_connection(SFTP_CONNECTION_USER)
    self.hook = SFTPHook()
    os.makedirs(os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS, SUB_DIR))

    # Write each top-level fixture file exactly once. The files are opened in append
    # mode, so a second write to the same name would double its content.
    for file_name in [TMP_FILE_FOR_TESTS, ANOTHER_FILE_FOR_TESTS, LOG_FILE_FOR_TESTS]:
        self._create_additional_test_file(file_name)
    with open(os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS, SUB_DIR, TMP_FILE_FOR_TESTS), 'a') as file:
        file.write('Test file')

Expand Down Expand Up @@ -353,7 +360,32 @@ def test_deprecation_ftp_conn_id(self, mock_get_connection):
# Default is 'sftp_default
assert SFTPHook().ssh_conn_id == 'sftp_default'

def test_get_suffix_pattern_match(self):
    """A ``*.txt`` pattern finds some file in ``TMP_PATH``."""
    output = self.hook.get_file_by_pattern(TMP_PATH, "*.txt")
    # assertTrue's second argument is the failure message, not an expected value, so
    # only truthiness is checked here; the message is passed by keyword to make that plain.
    # NOTE(review): tighten to assertEqual only if TMP_PATH is known to hold no other .txt files.
    self.assertTrue(output, msg=f"no file matching '*.txt' found (expected e.g. {TMP_FILE_FOR_TESTS})")

def test_get_prefix_pattern_match(self):
    """A ``test*`` pattern finds some entry in ``TMP_PATH``."""
    output = self.hook.get_file_by_pattern(TMP_PATH, "test*")
    # assertTrue's second argument is the failure message, not an expected value, so
    # only truthiness is checked here; the message is passed by keyword to make that plain.
    # NOTE(review): tighten to assertEqual only if TMP_PATH is known to hold no other test* entries.
    self.assertTrue(output, msg=f"no entry matching 'test*' found (expected e.g. {TMP_FILE_FOR_TESTS})")

def test_get_pattern_not_match(self):
    """A ``*.text`` pattern yields a falsy result."""
    self.assertFalse(self.hook.get_file_by_pattern(TMP_PATH, "*.text"))

def test_get_several_pattern_match(self):
    """A ``*.log`` pattern returns ``LOG_FILE_FOR_TESTS``."""
    matched = self.hook.get_file_by_pattern(TMP_PATH, "*.log")
    self.assertEqual(LOG_FILE_FOR_TESTS, matched)

def test_get_first_pattern_match(self):
    """``test_*.txt`` fits both .txt fixtures; the hook is expected to return ``TMP_FILE_FOR_TESTS``."""
    matched = self.hook.get_file_by_pattern(TMP_PATH, "test_*.txt")
    self.assertEqual(TMP_FILE_FOR_TESTS, matched)

def test_get_middle_pattern_match(self):
    """``*_file_*.txt`` is expected to return ``ANOTHER_FILE_FOR_TESTS``."""
    matched = self.hook.get_file_by_pattern(TMP_PATH, "*_file_*.txt")
    self.assertEqual(ANOTHER_FILE_FOR_TESTS, matched)

def tearDown(self):
    """Delete the fixture directory and files under ``TMP_PATH`` and restore the saved login."""
    shutil.rmtree(os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS))
    # Remove each fixture file exactly once: a second os.remove on an already
    # deleted path raises FileNotFoundError.
    for file_name in [TMP_FILE_FOR_TESTS, ANOTHER_FILE_FOR_TESTS, LOG_FILE_FOR_TESTS]:
        os.remove(os.path.join(TMP_PATH, file_name))
    self.update_connection(self.old_login)
28 changes: 28 additions & 0 deletions tests/providers/sftp/sensors/test_sftp.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,3 +97,31 @@ def test_naive_datetime(self, sftp_hook_mock):
output = sftp_sensor.poke(context)
sftp_hook_mock.return_value.get_mod_time.assert_called_once_with('/path/to/file/1970-01-01.txt')
assert not output

@patch('airflow.providers.sftp.sensors.sftp.SFTPHook')
def test_file_with_pattern_parameter_call(self, sftp_hook_mock):
    """``poke`` calls ``get_file_by_pattern`` once with the sensor's path and pattern."""
    mocked_hook = sftp_hook_mock.return_value
    mocked_hook.get_mod_time.return_value = '19700101000000'
    sensor = SFTPSensor(task_id='unit_test', path='/path/to/file/', file_pattern="*.txt")
    poke_result = sensor.poke({'ds': '1970-01-01'})
    mocked_hook.get_file_by_pattern.assert_called_once_with('/path/to/file/', '*.txt')
    assert poke_result

@patch('airflow.providers.sftp.sensors.sftp.SFTPHook')
def test_file_present_with_pattern(self, sftp_hook_mock):
    """When the pattern lookup returns a file, ``poke`` checks that file's mod time and succeeds."""
    mocked_hook = sftp_hook_mock.return_value
    mocked_hook.get_mod_time.return_value = '19700101000000'
    mocked_hook.get_file_by_pattern.return_value = '/path/to/file/text_file.txt'
    sensor = SFTPSensor(task_id='unit_test', path='/path/to/file/', file_pattern="*.txt")
    poke_result = sensor.poke({'ds': '1970-01-01'})
    mocked_hook.get_mod_time.assert_called_once_with('/path/to/file/text_file.txt')
    assert poke_result

@patch('airflow.providers.sftp.sensors.sftp.SFTPHook')
def test_file_not_present_with_pattern(self, sftp_hook_mock):
    """When the pattern lookup returns an empty string, ``poke`` reports a falsy result."""
    mocked_hook = sftp_hook_mock.return_value
    mocked_hook.get_mod_time.return_value = '19700101000000'
    mocked_hook.get_file_by_pattern.return_value = ""
    sensor = SFTPSensor(task_id='unit_test', path='/path/to/file/', file_pattern="*.txt")
    poke_result = sensor.poke({'ds': '1970-01-01'})
    assert not poke_result