Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
130 changes: 115 additions & 15 deletions pytest_sftpserver/sftp/content_provider.py
Original file line number Diff line number Diff line change
@@ -1,45 +1,120 @@
# encoding: utf-8
from __future__ import absolute_import, division, print_function

from six import binary_type, integer_types, string_types
from collections import defaultdict
import posixpath
import time

from six import string_types, binary_type, integer_types


class ContentProvider(object):
def __init__(self, content_object=None):
self.content_object = content_object

def get(self, path):
return self._find_object_for_path(path)
@property
def content_object(self):
return self._content_object

@content_object.setter
def content_object(self, data):
self._content_object = data
self._rebuild_stat_dict()

def _rebuild_stat_dict(self):
# values: [atime, mtime]
# (alphabetical order for easier remembering)
self._st_times = defaultdict(lambda : [time.time()] * 2)
the_time = time.time()
# This assumes that we start at '/'. Not sure that is
# a valid assumtion.
for path, name in self.recursive_list('/'):
self._st_times[(path, name)] = [the_time] * 2

# in rare situations, I don't want to record a change in
# atime for fetching an object -- such as using "get" to
# test if the file exists.
def get(self, path, atime_change=True):
obj = self._find_object_for_path(path)
if obj is not None and atime_change:
path, name = self._get_path_components(path)
# update atime, but not mtime
self._update_times(path, name, [time.time(), None])
return obj

def _update_times(self, path, name, times):
if len(times) != 2:
raise ValueError("'times' argument must be a length-2 list. Got %r"
% (times,))
st_times = self._st_times[(path, name)]
# Mutate the list in-place
if times[0] is not None:
st_times[0] = times[0]
if times[1] is not None:
st_times[1] = times[1]

def put(self, path, data, times=None):
if times is None:
# So don't override atime or mtime...
times = [None, None]
# Cast it as a list
times = list(times)

def put(self, path, data):
path, name = self._get_path_components(path)
dirpath, dirname = self._get_path_components(path)
obj = self._find_object_for_path(path)

if times[1] is None:
# ... but do update mtime
times[1] = time.time()

if isinstance(obj, dict):
if name not in obj:
self._update_times(dirpath, dirname, [None, time.time()])
obj[name] = data
self._update_times(path, name, times)
return True
elif isinstance(obj, list) and name.isdigit():
# Need to be done *before* casting to integer
# because code elsewhere won't cast to int
# before fetching from the dictionary.
self._update_times(path, name, times)
name = int(name)
if name > len(obj) - 1:
obj.append(data)
self._update_times(dirpath, dirname, [None, time.time()])
obj[name] = data
return True
try:
if not hasattr(obj, name):
self._update_times(dirpath, dirname, [None, time.time()])
setattr(obj, name, data)
return True
except (TypeError, AttributeError):
pass
return False
self._st_times.pop((dirpath, dirname), None)
return False
else:
self._update_times(path, name, times)
return True

def remove(self, path):
path, name = self._get_path_components(path)
dirpath, dirname = self._get_path_components(path)
obj = self._find_object_for_path(path)
if isinstance(obj, dict):
try:
del obj[name]
return True
except (KeyError, AttributeError):
pass
return False
else:
self._st_times.pop((path, name), None)
self._update_times(dirpath, dirname, [None, time.time()])
return True
elif isinstance(obj, list) and name.isdigit():
# Need to be done *before* casting to integer
# because code elsewhere won't cast to int
# before fetching from the dictionary.
self._st_times.pop((path, name), None)
self._update_times(dirpath, dirname, [None, time.time()])
name = int(name)
if name < len(obj):
del obj[name]
Expand All @@ -49,28 +124,53 @@ def remove(self, path):
else:
try:
delattr(obj, name)
return True
except (TypeError, AttributeError):
pass
return False
return False
else:
self._st_times.pop((path, name), None)
self._update_times(dirpath, dirname, [None, time.time()])
return True

def list(self, path):
obj = self._find_object_for_path(path)
if obj is not None:
dirpath, dirname = self._get_path_components(path)
self._update_times(dirpath, dirname, [time.time(), None])

if isinstance(obj, dict):
return obj.keys()
elif isinstance(obj, (list, tuple)):
return [str(i) for i in range(len(obj))]
else:
return [n for n in dir(obj) if not n.startswith("__")]

def recursive_list(self, path):
subpath, subname = self._get_path_components(path)
yield subpath if subpath != '/' else '', subname
if self.is_dir(path):
for name in self.list(path):
fullname = posixpath.join(path, name)
for subpath, subname in self.recursive_list(fullname):
yield subpath, subname

def is_dir(self, path):
return not isinstance(self.get(path), string_types + integer_types)
# Using _find_object_for_path to avoid attribute-setting
# that .get() does.
return not isinstance(self._find_object_for_path(path),
string_types + integer_types)

def get_size(self, path):
try:
return len(self.get(path))
return len(self.get(path, atime_change=False))
except TypeError:
return len(str(self.get(path)))
# This casting to string is ok. If the value was a binary string
# then we wouldn't have a TypeError anyway. This is just to
# cast non-string-likes into a string to get a usable length.
# FIXME: maybe report the data's memory size?
return len(str(self.get(path, atime_change=False)))
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just noticed that this would be wrong in py3k in light of #15 (which doesn't fix this).


def get_times(self, path):
return self._st_times.get(self._get_path_components(path), None)

def _find_object_for_path(self, path):
if not self.content_object:
Expand Down
52 changes: 32 additions & 20 deletions pytest_sftpserver/sftp/interface.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
# encoding: utf-8
from __future__ import absolute_import, division, print_function

import calendar
from os import O_CREAT
import posixpath
import stat
from datetime import datetime
from os import O_CREAT
import time

from paramiko import AUTH_SUCCESSFUL, OPEN_SUCCEEDED, ServerInterface
from paramiko.sftp import SFTP_FAILURE, SFTP_NO_SUCH_FILE, SFTP_OK
Expand All @@ -18,25 +17,35 @@


class VirtualSFTPHandle(SFTPHandle):
def __init__(self, path, content_provider, flags=0):
def __init__(self, path, content_provider, flags=0, attr=None):
super(VirtualSFTPHandle, self).__init__()
self.path = path
self.content_provider = content_provider
if self.content_provider.get(self.path) is None and flags and flags & O_CREAT == O_CREAT:
if (self.content_provider.get(self.path, atime_change=False) is None
and flags and flags & O_CREAT == O_CREAT):
if attr is not None:
times = [getattr(attr, 'st_atime', None),
getattr(attr, 'st_mtime', None)]
else:
times = None
# Create new empty "file"
self.content_provider.put(path, "")
self.content_provider.put(path, "", times)

def close(self):
return SFTP_OK

def chattr(self, attr):
if self.content_provider.get(self.path) is None:
if self.content_provider.get(self.path, atime_change=False) is None:
return SFTP_NO_SUCH_FILE

if hasattr(attr, 'st_atime') or hasattr(attr, 'st_mtime'):
times = self.content_provider.get_times(self.path)
# Mutates the stored times list in-place
times[0] = getattr(attr, 'st_atime', times[0])
times[1] = getattr(attr, 'st_mtime', times[1])
return SFTP_OK

def write(self, offset, data):
content = self.content_provider.get(self.path)
content = self.content_provider.get(self.path, atime_change=False)

if content is None:
return SFTP_OK if self.content_provider.put(self.path, data) else SFTP_NO_SUCH_FILE
Expand All @@ -55,18 +64,16 @@ def write(self, offset, data):
return SFTP_OK if self.content_provider.put(self.path, content) else SFTP_FAILURE

def read(self, offset, length):
if self.content_provider.get(self.path) is None:
if self.content_provider.get(self.path, atime_change=False) is None:
return SFTP_NO_SUCH_FILE

end = offset + length
return self.content_provider.get(self.path)[offset:end]

def stat(self):
if self.content_provider.get(self.path) is None:
if self.content_provider.get(self.path, atime_change=False) is None:
return SFTP_NO_SUCH_FILE

mtime = calendar.timegm(datetime.now().timetuple())

sftp_attrs = SFTPAttributes()
sftp_attrs.st_size = self.content_provider.get_size(self.path)
sftp_attrs.st_uid = 0
Expand All @@ -77,8 +84,8 @@ def stat(self):
| stat.S_IRWXU
| (stat.S_IFDIR if self.content_provider.is_dir(self.path) else stat.S_IFREG)
)
sftp_attrs.st_atime = mtime
sftp_attrs.st_mtime = mtime
(sftp_attrs.st_atime,
sftp_attrs.st_mtime) = self.content_provider.get_times(self.path)
sftp_attrs.filename = posixpath.basename(self.path)
return sftp_attrs

Expand All @@ -97,18 +104,20 @@ def list_folder(self, path):

@abspath
def open(self, path, flags, attr):
return VirtualSFTPHandle(path, self.content_provider, flags=flags)
return VirtualSFTPHandle(path, self.content_provider, flags=flags,
attr=attr)

@abspath
def remove(self, path):
return SFTP_OK if self.content_provider.remove(path) else SFTP_NO_SUCH_FILE

@abspath
def rename(self, oldpath, newpath):
content = self.content_provider.get(oldpath)
content = self.content_provider.get(oldpath, atime_change=False)
if not content:
return SFTP_NO_SUCH_FILE
res = self.content_provider.put(newpath, content)
oldtimes = self.content_provider.get_times(oldpath)
res = self.content_provider.put(newpath, content, oldtimes)
if res:
res = res and self.content_provider.remove(oldpath)
return SFTP_OK if res else SFTP_FAILURE
Expand All @@ -119,9 +128,12 @@ def rmdir(self, path):

@abspath
def mkdir(self, path, attr):
if self.content_provider.get(path) is not None:
if self.content_provider.get(path, atime_change=False) is not None:
return SFTP_FAILURE
return SFTP_OK if self.content_provider.put(path, {}) else SFTP_FAILURE
times = [getattr(attr, 'st_atime', None),
getattr(attr, 'st_mtime', None)]
return (SFTP_OK if self.content_provider.put(path, {}, times)
else SFTP_FAILURE)

@abspath
def stat(self, path):
Expand Down
Loading