diff --git a/pynvim/__init__.py b/pynvim/__init__.py index 9fc0534b..9e6e6f63 100644 --- a/pynvim/__init__.py +++ b/pynvim/__init__.py @@ -6,14 +6,20 @@ import os import sys from types import SimpleNamespace as Version +from typing import List, cast, overload from pynvim.api import Nvim, NvimError -from pynvim.msgpack_rpc import (ErrorResponse, child_session, socket_session, - stdio_session, tcp_session) +from pynvim.msgpack_rpc import (ErrorResponse, Session, TTransportType, child_session, + socket_session, stdio_session, tcp_session) from pynvim.plugin import (Host, autocmd, command, decode, encoding, function, plugin, rpc_export, shutdown_hook) from pynvim.util import VERSION +if sys.version_info < (3, 8): + from typing_extensions import Literal +else: + from typing import Literal + __all__ = ('tcp_session', 'socket_session', 'stdio_session', 'child_session', 'start_host', 'autocmd', 'command', 'encoding', 'decode', @@ -22,7 +28,7 @@ 'ErrorResponse') -def start_host(session=None): +def start_host(session: Session = None) -> None: """Promote the current process into python plugin host for Nvim. Start msgpack-rpc event loop for `session`, listening for Nvim requests @@ -77,8 +83,30 @@ def start_host(session=None): host.start(plugins) -def attach(session_type, address=None, port=None, - path=None, argv=None, decode=True): +@overload +def attach(session_type: Literal['tcp'], address: str, port: int = 7450) -> Nvim: ... + + +@overload +def attach(session_type: Literal['socket'], *, path: str) -> Nvim: ... + + +@overload +def attach(session_type: Literal['child'], *, argv: List[str]) -> Nvim: ... + + +@overload +def attach(session_type: Literal['stdio']) -> Nvim: ... + + +def attach( + session_type: TTransportType, + address: str = None, + port: int = 7450, + path: str = None, + argv: List[str] = None, + decode: Literal[True] = True +) -> Nvim: """Provide a nicer interface to create python api sessions. Previous machinery to create python api sessions is still there. This only @@ -107,11 +135,13 @@ def attach(session_type, address=None, port=None, """ - session = (tcp_session(address, port) if session_type == 'tcp' else - socket_session(path) if session_type == 'socket' else - stdio_session() if session_type == 'stdio' else - child_session(argv) if session_type == 'child' else - None) + session = ( + tcp_session(cast(str, address), port) if session_type == 'tcp' else + socket_session(cast(str, path)) if session_type == 'socket' else + stdio_session() if session_type == 'stdio' else + child_session(cast(List[str], argv)) if session_type == 'child' else + None + ) if not session: raise Exception('Unknown session type "%s"' % session_type) @@ -119,7 +149,7 @@ def attach(session_type, address=None, port=None, return Nvim.from_session(session).with_decode(decode) -def setup_logging(name): +def setup_logging(name: str) -> None: """Setup logging according to environment variables.""" logger = logging.getLogger(__name__) if 'NVIM_PYTHON_LOG_FILE' in os.environ: @@ -141,13 +171,3 @@ def setup_logging(name): logger.warning('Invalid NVIM_PYTHON_LOG_LEVEL: %r, using INFO.', env_log_level) logger.setLevel(level) - - -# Required for python 2.6 -class NullHandler(logging.Handler): - def emit(self, record): - pass - - -if not logging.root.handlers: - logging.root.addHandler(NullHandler()) diff --git a/pynvim/api/buffer.py b/pynvim/api/buffer.py index cb73196d..4d4b4daf 100644 --- a/pynvim/api/buffer.py +++ b/pynvim/api/buffer.py @@ -1,12 +1,33 @@ """API for working with a Nvim Buffer.""" +from typing import (Any, Iterator, List, Optional, TYPE_CHECKING, Tuple, Union, cast, + overload) + from pynvim.api.common import Remote from pynvim.compat import check_async +if TYPE_CHECKING: + from pynvim.api import Nvim + + +__all__ = ('Buffer',) + + +@overload +def adjust_index(idx: int, default: int = None) -> int: + ... + -__all__ = ('Buffer') +@overload +def adjust_index(idx: Optional[int], default: int) -> int: + ... -def adjust_index(idx, default=None): +@overload +def adjust_index(idx: Optional[int], default: int = None) -> Optional[int]: + ... + + +def adjust_index(idx: Optional[int], default: int = None) -> Optional[int]: """Convert from python indexing convention to nvim indexing convention.""" if idx is None: return default @@ -21,12 +42,25 @@ class Buffer(Remote): """A remote Nvim buffer.""" _api_prefix = "nvim_buf_" + _session: "Nvim" + + def __init__(self, session: "Nvim", code_data: Tuple[int, Any]): + """Initialize from Nvim and code_data immutable object.""" + super().__init__(session, code_data) - def __len__(self): + def __len__(self) -> int: """Return the number of lines contained in a Buffer.""" return self.request('nvim_buf_line_count') - def __getitem__(self, idx): + @overload + def __getitem__(self, idx: int) -> str: # noqa: D105 + ... + + @overload + def __getitem__(self, idx: slice) -> List[str]: # noqa: D105 + ... + + def __getitem__(self, idx: Union[int, slice]) -> Union[str, List[str]]: """Get a buffer line or slice by integer index. Indexes may be negative to specify positions from the end of the @@ -43,7 +77,19 @@ def __getitem__(self, idx): end = adjust_index(idx.stop, -1) return self.request('nvim_buf_get_lines', start, end, False) - def __setitem__(self, idx, item): + @overload + def __setitem__(self, idx: int, item: Optional[str]) -> None: # noqa: D105 + ... + + @overload + def __setitem__( # noqa: D105 + self, idx: slice, item: Optional[Union[List[str], str]] + ) -> None: + ... + + def __setitem__( + self, idx: Union[int, slice], item: Union[None, str, List[str]] + ) -> None: """Replace a buffer line or slice by integer index. Like with `__getitem__`, indexes may be negative. @@ -52,15 +98,21 @@ def __setitem__(self, idx, item): the whole buffer. """ if not isinstance(idx, slice): + assert not isinstance(item, list) i = adjust_index(idx) lines = [item] if item is not None else [] return self.request('nvim_buf_set_lines', i, i + 1, True, lines) - lines = item if item is not None else [] + if item is None: + lines = [] + elif isinstance(item, str): + lines = [item] + else: + lines = item start = adjust_index(idx.start, 0) end = adjust_index(idx.stop, -1) return self.request('nvim_buf_set_lines', start, end, False, lines) - def __iter__(self): + def __iter__(self) -> Iterator[str]: """Iterate lines of a buffer. This will retrieve all lines locally before iteration starts. This @@ -72,51 +124,81 @@ def __iter__(self): for line in lines: yield line - def __delitem__(self, idx): + def __delitem__(self, idx: Union[int, slice]) -> None: """Delete line or slice of lines from the buffer. This is the same as __setitem__(idx, []) """ self.__setitem__(idx, None) - def __ne__(self, other): + def __ne__(self, other: Any) -> bool: """Test inequality of Buffers. Necessary for Python 2 compatibility. """ return not self.__eq__(other) - def append(self, lines, index=-1): + def append( + self, lines: Union[str, bytes, List[Union[str, bytes]]], index: int = -1 + ) -> None: """Append a string or list of lines to the buffer.""" if isinstance(lines, (str, bytes)): lines = [lines] return self.request('nvim_buf_set_lines', index, index, True, lines) - def mark(self, name): + def mark(self, name: str) -> Tuple[int, int]: """Return (row, col) tuple for a named mark.""" - return self.request('nvim_buf_get_mark', name) + return cast(Tuple[int, int], tuple(self.request('nvim_buf_get_mark', name))) - def range(self, start, end): + def range(self, start: int, end: int) -> "Range": """Return a `Range` object, which represents part of the Buffer.""" return Range(self, start, end) - def add_highlight(self, hl_group, line, col_start=0, - col_end=-1, src_id=-1, async_=None, - **kwargs): + def add_highlight( + self, + hl_group: str, + line: int, + col_start: int = 0, + col_end: int = -1, + src_id: int = -1, + async_: bool = None, + **kwargs: Any + ) -> int: """Add a highlight to the buffer.""" async_ = check_async(async_, kwargs, src_id != 0) - return self.request('nvim_buf_add_highlight', src_id, hl_group, - line, col_start, col_end, async_=async_) - - def clear_highlight(self, src_id, line_start=0, line_end=-1, async_=None, - **kwargs): + return self.request( + "nvim_buf_add_highlight", + src_id, + hl_group, + line, + col_start, + col_end, + async_=async_, + ) + + def clear_highlight( + self, + src_id: int, + line_start: int = 0, + line_end: int = -1, + async_: bool = None, + **kwargs: Any + ) -> None: """Clear highlights from the buffer.""" async_ = check_async(async_, kwargs, True) - self.request('nvim_buf_clear_highlight', src_id, - line_start, line_end, async_=async_) - - def update_highlights(self, src_id, hls, clear_start=0, clear_end=-1, - clear=False, async_=True): + self.request( + "nvim_buf_clear_highlight", src_id, line_start, line_end, async_=async_ + ) + + def update_highlights( + self, + src_id: int, + hls: List[Union[Tuple[str, int], Tuple[str, int, int, int]]], + clear_start: Optional[int] = None, + clear_end: int = -1, + clear: bool = False, + async_: bool = True, + ) -> None: """Add or update highlights in batch to avoid unnecessary redraws. A `src_id` must have been allocated prior to use of this function. Use @@ -135,40 +217,52 @@ def update_highlights(self, src_id, hls, clear_start=0, clear_end=-1, if clear and clear_start is None: clear_start = 0 lua = self._session._get_lua_private() - lua.update_highlights(self, src_id, hls, clear_start, clear_end, - async_=async_) + lua.update_highlights(self, src_id, hls, clear_start, clear_end, async_=async_) @property - def name(self): + def name(self) -> str: """Get the buffer name.""" return self.request('nvim_buf_get_name') @name.setter - def name(self, value): + def name(self, value: str) -> None: """Set the buffer name. BufFilePre/BufFilePost are triggered.""" return self.request('nvim_buf_set_name', value) @property - def valid(self): + def valid(self) -> bool: """Return True if the buffer still exists.""" return self.request('nvim_buf_is_valid') @property - def number(self): + def loaded(self) -> bool: + """Return True if the buffer is valid and loaded.""" + return self.request('nvim_buf_is_loaded') + + @property + def number(self) -> int: """Get the buffer number.""" return self.handle class Range(object): - def __init__(self, buffer, start, end): + def __init__(self, buffer: Buffer, start: int, end: int): self._buffer = buffer self.start = start - 1 self.end = end - 1 - def __len__(self): + def __len__(self) -> int: return self.end - self.start + 1 - def __getitem__(self, idx): + @overload + def __getitem__(self, idx: int) -> str: + ... + + @overload + def __getitem__(self, idx: slice) -> List[str]: + ... + + def __getitem__(self, idx: Union[int, slice]) -> Union[str, List[str]]: if not isinstance(idx, slice): return self._buffer[self._normalize_index(idx)] start = self._normalize_index(idx.start) @@ -179,8 +273,19 @@ def __getitem__(self, idx): end = self.end + 1 return self._buffer[start:end] - def __setitem__(self, idx, lines): + @overload + def __setitem__(self, idx: int, lines: Optional[str]) -> None: + ... + + @overload + def __setitem__(self, idx: slice, lines: Optional[List[str]]) -> None: + ... + + def __setitem__( + self, idx: Union[int, slice], lines: Union[None, str, List[str]] + ) -> None: if not isinstance(idx, slice): + assert not isinstance(lines, list) self._buffer[self._normalize_index(idx)] = lines return start = self._normalize_index(idx.start) @@ -191,17 +296,27 @@ def __setitem__(self, idx, lines): end = self.end self._buffer[start:end + 1] = lines - def __iter__(self): + def __iter__(self) -> Iterator[str]: for i in range(self.start, self.end + 1): yield self._buffer[i] - def append(self, lines, i=None): + def append( + self, lines: Union[str, bytes, List[Union[str, bytes]]], i: int = None + ) -> None: i = self._normalize_index(i) if i is None: i = self.end + 1 self._buffer.append(lines, i) - def _normalize_index(self, index): + @overload + def _normalize_index(self, index: int) -> int: + ... + + @overload + def _normalize_index(self, index: None) -> None: + ... + + def _normalize_index(self, index: Optional[int]) -> Optional[int]: if index is None: return None if index < 0: diff --git a/pynvim/api/common.py b/pynvim/api/common.py index 21f5b2da..101df017 100644 --- a/pynvim/api/common.py +++ b/pynvim/api/common.py @@ -1,18 +1,35 @@ """Code shared between the API classes.""" import functools +import sys +from abc import ABC, abstractmethod +from typing import (Any, Callable, Generic, Iterator, List, Optional, Tuple, TypeVar, + Union, overload) from msgpack import unpackb +if sys.version_info < (3, 8): + from typing_extensions import Literal, Protocol +else: + from typing import Literal, Protocol from pynvim.compat import unicode_errors_default __all__ = () +T = TypeVar('T') +TDecodeMode = Union[Literal[True], str] + + class NvimError(Exception): pass -class Remote(object): +class IRemote(Protocol): + def request(self, name: str, *args: Any, **kwargs: Any) -> Any: + raise NotImplementedError + + +class Remote(ABC): """Base class for Nvim objects(buffer/window/tabpage). @@ -21,7 +38,7 @@ class Remote(object): object handle into consideration. """ - def __init__(self, session, code_data): + def __init__(self, session: IRemote, code_data: Tuple[int, Any]): """Initialize from session and code_data immutable object. The `code_data` contains serialization information required for @@ -37,23 +54,28 @@ def __init__(self, session, code_data): self.options = RemoteMap(self, self._api_prefix + 'get_option', self._api_prefix + 'set_option') - def __repr__(self): + @property + @abstractmethod + def _api_prefix(self) -> str: + raise NotImplementedError() + + def __repr__(self) -> str: """Get text representation of the object.""" return '<%s(handle=%r)>' % ( self.__class__.__name__, self.handle, ) - def __eq__(self, other): + def __eq__(self, other: Any) -> bool: """Return True if `self` and `other` are the same object.""" return (hasattr(other, 'code_data') and other.code_data == self.code_data) - def __hash__(self): + def __hash__(self) -> int: """Return hash based on remote object id.""" return self.code_data.__hash__() - def request(self, name, *args, **kwargs): + def request(self, name: str, *args: Any, **kwargs: Any) -> Any: """Wrapper for nvim.request.""" return self._session.request(name, self, *args, **kwargs) @@ -62,17 +84,20 @@ class RemoteApi(object): """Wrapper to allow api methods to be called like python methods.""" - def __init__(self, obj, api_prefix): + def __init__(self, obj: IRemote, api_prefix: str): """Initialize a RemoteApi with object and api prefix.""" self._obj = obj self._api_prefix = api_prefix - def __getattr__(self, name): + def __getattr__(self, name: str) -> Callable[..., Any]: """Return wrapper to named api method.""" return functools.partial(self._obj.request, self._api_prefix + name) -def transform_keyerror(exc): +E = TypeVar('E', bound=Exception) + + +def transform_keyerror(exc: E) -> Union[E, KeyError]: if isinstance(exc, NvimError): if exc.args[0].startswith('Key not found:'): return KeyError(exc.args[0]) @@ -94,7 +119,13 @@ class RemoteMap(object): _set = None _del = None - def __init__(self, obj, get_method, set_method=None, del_method=None): + def __init__( + self, + obj: IRemote, + get_method: str, + set_method: str = None, + del_method: str = None + ): """Initialize a RemoteMap with session, getter/setter.""" self._get = functools.partial(obj.request, get_method) if set_method: @@ -102,20 +133,20 @@ def __init__(self, obj, get_method, set_method=None, del_method=None): if del_method: self._del = functools.partial(obj.request, del_method) - def __getitem__(self, key): + def __getitem__(self, key: str) -> Any: """Return a map value by key.""" try: return self._get(key) except NvimError as exc: raise transform_keyerror(exc) - def __setitem__(self, key, value): + def __setitem__(self, key: str, value: Any) -> None: """Set a map value by key(if the setter was provided).""" if not self._set: raise TypeError('This dict is read-only') self._set(key, value) - def __delitem__(self, key): + def __delitem__(self, key: str) -> None: """Delete a map value by associating None with the key.""" if not self._del: raise TypeError('This dict is read-only') @@ -124,7 +155,7 @@ def __delitem__(self, key): except NvimError as exc: raise transform_keyerror(exc) - def __contains__(self, key): + def __contains__(self, key: str) -> bool: """Check if key is present in the map.""" try: self._get(key) @@ -132,7 +163,13 @@ def __contains__(self, key): except Exception: return False - def get(self, key, default=None): + @overload + def get(self, key: str, default: T) -> T: ... + + @overload + def get(self, key: str, default: Optional[T] = None) -> Optional[T]: ... + + def get(self, key: str, default: Optional[T] = None) -> Optional[T]: """Return value for key if present, else a default value.""" try: return self.__getitem__(key) @@ -140,11 +177,11 @@ def get(self, key, default=None): return default -class RemoteSequence(object): +class RemoteSequence(Generic[T]): """Represents a sequence of objects stored in Nvim. - This class is used to wrap msgapck-rpc functions that work on Nvim + This class is used to wrap msgpack-rpc functions that work on Nvim sequences(of lines, buffers, windows and tabpages) with an API that is similar to the one provided by the python-vim interface. @@ -157,36 +194,46 @@ class RemoteSequence(object): locally(iteration, indexing, counting, etc). """ - def __init__(self, session, method): + def __init__(self, session: IRemote, method: str): """Initialize a RemoteSequence with session, method.""" self._fetch = functools.partial(session.request, method) - def __len__(self): + def __len__(self) -> int: """Return the length of the remote sequence.""" return len(self._fetch()) - def __getitem__(self, idx): + @overload + def __getitem__(self, idx: int) -> T: ... + + @overload + def __getitem__(self, idx: slice) -> List[T]: ... + + def __getitem__(self, idx: Union[slice, int]) -> Union[T, List[T]]: """Return a sequence item by index.""" if not isinstance(idx, slice): return self._fetch()[idx] return self._fetch()[idx.start:idx.stop] - def __iter__(self): + def __iter__(self) -> Iterator[T]: """Return an iterator for the sequence.""" items = self._fetch() for item in items: yield item - def __contains__(self, item): + def __contains__(self, item: T) -> bool: """Check if an item is present in the sequence.""" return item in self._fetch() -def _identity(obj, session, method, kind): - return obj +@overload +def decode_if_bytes(obj: bytes, mode: TDecodeMode = True) -> str: ... + + +@overload +def decode_if_bytes(obj: T, mode: TDecodeMode = True) -> Union[T, str]: ... -def decode_if_bytes(obj, mode=True): +def decode_if_bytes(obj: T, mode: TDecodeMode = True) -> Union[T, str]: """Decode obj if it is bytes.""" if mode is True: mode = unicode_errors_default @@ -195,7 +242,7 @@ def decode_if_bytes(obj, mode=True): return obj -def walk(fn, obj, *args, **kwargs): +def walk(fn: Callable[..., Any], obj: Any, *args: Any, **kwargs: Any) -> Any: """Recursively walk an object graph applying `fn`/`args` to objects.""" if type(obj) in [list, tuple]: return list(walk(fn, o, *args) for o in obj) diff --git a/pynvim/api/nvim.py b/pynvim/api/nvim.py index 3a26c2a4..ce9c33f7 100644 --- a/pynvim/api/nvim.py +++ b/pynvim/api/nvim.py @@ -5,17 +5,27 @@ from functools import partial from traceback import format_stack from types import SimpleNamespace +from typing import (Any, AnyStr, Callable, Dict, Iterator, List, Optional, + TYPE_CHECKING, Union) from msgpack import ExtType from pynvim.api.buffer import Buffer from pynvim.api.common import (NvimError, Remote, RemoteApi, RemoteMap, RemoteSequence, - decode_if_bytes, walk) + TDecodeMode, decode_if_bytes, walk) from pynvim.api.tabpage import Tabpage from pynvim.api.window import Window from pynvim.util import format_exc_skip -__all__ = ('Nvim') +if TYPE_CHECKING: + from pynvim.msgpack_rpc import Session + +if sys.version_info < (3, 8): + from typing_extensions import Literal +else: + from typing import Literal + +__all__ = ['Nvim'] os_chdir = os.chdir @@ -69,7 +79,7 @@ class Nvim(object): """ @classmethod - def from_session(cls, session): + def from_session(cls, session: 'Session') -> 'Nvim': """Create a new Nvim instance for a Session instance. This method must be called to create the first Nvim instance, since it @@ -90,13 +100,20 @@ def from_session(cls, session): return cls(session, channel_id, metadata, types) @classmethod - def from_nvim(cls, nvim): + def from_nvim(cls, nvim: 'Nvim') -> 'Nvim': """Create a new Nvim instance from an existing instance.""" return cls(nvim._session, nvim.channel_id, nvim.metadata, nvim.types, nvim._decode, nvim._err_cb) - def __init__(self, session, channel_id, metadata, types, - decode=False, err_cb=None): + def __init__( + self, + session: 'Session', + channel_id: int, + metadata: Dict[str, Any], + types: Dict[int, Any], + decode: TDecodeMode = True, + err_cb: Callable[[str], None] = None + ): """Initialize a new Nvim instance. This method is module-private.""" self._session = session self.channel_id = channel_id @@ -109,18 +126,23 @@ def __init__(self, session, channel_id, metadata, types, self.vvars = RemoteMap(self, 'nvim_get_vvar', None, None) self.options = RemoteMap(self, 'nvim_get_option', 'nvim_set_option') self.buffers = Buffers(self) - self.windows = RemoteSequence(self, 'nvim_list_wins') - self.tabpages = RemoteSequence(self, 'nvim_list_tabpages') + self.windows: RemoteSequence[Window] = RemoteSequence(self, 'nvim_list_wins') + self.tabpages: RemoteSequence[Tabpage] = RemoteSequence( + self, 'nvim_list_tabpages' + ) self.current = Current(self) self.session = CompatibilitySession(self) self.funcs = Funcs(self) self.lua = LuaFuncs(self) self.error = NvimError self._decode = decode - self._err_cb = err_cb + if err_cb is None: + self._err_cb: Callable[[str], Any] = lambda _: None + else: + self._err_cb = err_cb self.loop = self._session.loop._loop - def _from_nvim(self, obj, decode=None): + def _from_nvim(self, obj: Any, decode: TDecodeMode = None) -> Any: if decode is None: decode = self._decode if type(obj) is ExtType: @@ -130,18 +152,18 @@ def _from_nvim(self, obj, decode=None): obj = decode_if_bytes(obj, decode) return obj - def _to_nvim(self, obj): + def _to_nvim(self, obj: Any) -> Any: if isinstance(obj, Remote): return ExtType(*obj.code_data) return obj - def _get_lua_private(self): + def _get_lua_private(self) -> 'LuaFuncs': if not getattr(self._session, "_has_lua", False): self.exec_lua(lua_module, self.channel_id) - self._session._has_lua = True + self._session._has_lua = True # type: ignore[attr-defined] return getattr(self.lua, "_pynvim_{}".format(self.channel_id)) - def request(self, name, *args, **kwargs): + def request(self, name: str, *args: Any, **kwargs: Any) -> Any: r"""Send an API request or notification to nvim. It is rarely needed to call this function directly, as most API @@ -177,7 +199,7 @@ def request(self, name, *args, **kwargs): res = self._session.request(name, *args, **kwargs) return walk(self._from_nvim, res, decode=decode) - def next_message(self): + def next_message(self) -> Any: """Block until a message(request or notification) is available. If any messages were previously enqueued, return the first in queue. @@ -187,8 +209,13 @@ def next_message(self): if msg: return walk(self._from_nvim, msg) - def run_loop(self, request_cb, notification_cb, - setup_cb=None, err_cb=None): + def run_loop( + self, + request_cb: Optional[Callable[[str, List[Any]], Any]], + notification_cb: Optional[Callable[[str, List[Any]], Any]], + setup_cb: Callable[[], None] = None, + err_cb: Callable[[str], Any] = None + ) -> None: """Run the event loop to receive requests and notifications from Nvim. This should not be called from a plugin running in the host, which @@ -198,11 +225,11 @@ def run_loop(self, request_cb, notification_cb, err_cb = sys.stderr.write self._err_cb = err_cb - def filter_request_cb(name, args): + def filter_request_cb(name: str, args: Any) -> Any: name = self._from_nvim(name) args = walk(self._from_nvim, args) try: - result = request_cb(name, args) + result = request_cb(name, args) # type: ignore[misc] except Exception: msg = ("error caught in request handler '{} {}'\n{}\n\n" .format(name, args, format_exc_skip(1))) @@ -210,11 +237,11 @@ def filter_request_cb(name, args): raise return walk(self._to_nvim, result) - def filter_notification_cb(name, args): + def filter_notification_cb(name: str, args: Any) -> None: name = self._from_nvim(name) args = walk(self._from_nvim, args) try: - notification_cb(name, args) + notification_cb(name, args) # type: ignore[misc] except Exception: msg = ("error caught in notification handler '{} {}'\n{}\n\n" .format(name, args, format_exc_skip(1))) @@ -223,31 +250,33 @@ def filter_notification_cb(name, args): self._session.run(filter_request_cb, filter_notification_cb, setup_cb) - def stop_loop(self): + def stop_loop(self) -> None: """Stop the event loop being started with `run_loop`.""" self._session.stop() - def close(self): + def close(self) -> None: """Close the nvim session and release its resources.""" self._session.close() - def __enter__(self): + def __enter__(self) -> 'Nvim': """Enter nvim session as a context manager.""" return self - def __exit__(self, *exc_info): + def __exit__(self, *exc_info: Any) -> None: """Exit nvim session as a context manager. Closes the event loop. """ self.close() - def with_decode(self, decode=True): + def with_decode(self, decode: Literal[True] = True) -> 'Nvim': """Initialize a new Nvim instance.""" return Nvim(self._session, self.channel_id, self.metadata, self.types, decode, self._err_cb) - def ui_attach(self, width, height, rgb=None, **kwargs): + def ui_attach( + self, width: int, height: int, rgb: bool = None, **kwargs: Any + ) -> None: """Register as a remote UI. After this method is called, the client will receive redraw @@ -258,42 +287,42 @@ def ui_attach(self, width, height, rgb=None, **kwargs): options['rgb'] = rgb return self.request('nvim_ui_attach', width, height, options) - def ui_detach(self): + def ui_detach(self) -> None: """Unregister as a remote UI.""" return self.request('nvim_ui_detach') - def ui_try_resize(self, width, height): + def ui_try_resize(self, width: int, height: int) -> None: """Notify nvim that the client window has resized. If possible, nvim will send a redraw request to resize. """ return self.request('ui_try_resize', width, height) - def subscribe(self, event): + def subscribe(self, event: str) -> None: """Subscribe to a Nvim event.""" return self.request('nvim_subscribe', event) - def unsubscribe(self, event): + def unsubscribe(self, event: str) -> None: """Unsubscribe to a Nvim event.""" return self.request('nvim_unsubscribe', event) - def command(self, string, **kwargs): + def command(self, string: str, **kwargs: Any) -> None: """Execute a single ex command.""" return self.request('nvim_command', string, **kwargs) - def command_output(self, string): + def command_output(self, string: str) -> str: """Execute a single ex command and return the output.""" return self.request('nvim_command_output', string) - def eval(self, string, **kwargs): + def eval(self, string: str, **kwargs: Any) -> Any: """Evaluate a vimscript expression.""" return self.request('nvim_eval', string, **kwargs) - def call(self, name, *args, **kwargs): + def call(self, name: str, *args: Any, **kwargs: Any) -> Any: """Call a vimscript function.""" return self.request('nvim_call_function', name, args, **kwargs) - def exec_lua(self, code, *args, **kwargs): + def exec_lua(self, code: str, *args: Any, **kwargs: Any) -> Any: """Execute lua code. Additional parameters are available as `...` inside the lua chunk. @@ -314,18 +343,18 @@ def exec_lua(self, code, *args, **kwargs): """ return self.request('nvim_execute_lua', code, args, **kwargs) - def strwidth(self, string): + def strwidth(self, string: str) -> int: """Return the number of display cells `string` occupies. Tab is counted as one cell. """ return self.request('nvim_strwidth', string) - def list_runtime_paths(self): + def list_runtime_paths(self) -> List[str]: """Return a list of paths contained in the 'runtimepath' option.""" return self.request('nvim_list_runtime_paths') - def foreach_rtp(self, cb): + def foreach_rtp(self, cb: Callable[[str], Any]) -> None: """Invoke `cb` for each path in 'runtimepath'. Call the given callable for each path in 'runtimepath' until either @@ -333,19 +362,19 @@ def foreach_rtp(self, cb): are no longer paths. If stopped in case callable returned non-None, vim.foreach_rtp function returns the value returned by callable. """ - for path in self.request('nvim_list_runtime_paths'): + for path in self.list_runtime_paths(): try: if cb(path) is not None: break except Exception: break - def chdir(self, dir_path): + def chdir(self, dir_path: str) -> None: """Run os.chdir, then all appropriate vim stuff.""" os_chdir(dir_path) return self.request('nvim_set_current_dir', dir_path) - def feedkeys(self, keys, options='', escape_csi=True): + def feedkeys(self, keys: str, options: str = '', escape_csi: bool = True) -> None: """Push `keys` to Nvim user input buffer. Options can be a string with the following character flags: @@ -356,7 +385,7 @@ def feedkeys(self, keys, options='', escape_csi=True): """ return self.request('nvim_feedkeys', keys, options, escape_csi) - def input(self, bytes): + def input(self, bytes: AnyStr) -> int: """Push `bytes` to Nvim low level input buffer. Unlike `feedkeys()`, this uses the lowest level input buffer and the @@ -366,8 +395,13 @@ def input(self, bytes): """ return self.request('nvim_input', bytes) - def replace_termcodes(self, string, from_part=False, do_lt=True, - special=True): + def replace_termcodes( + self, + string: str, + from_part: bool = False, + do_lt: bool = True, + special: bool = True + ) -> str: r"""Replace any terminal code strings by byte sequences. The returned sequences are Nvim's internal representation of keys, @@ -383,14 +417,14 @@ def replace_termcodes(self, string, from_part=False, do_lt=True, return self.request('nvim_replace_termcodes', string, from_part, do_lt, special) - def out_write(self, msg, **kwargs): + def out_write(self, msg: str, **kwargs: Any) -> None: r"""Print `msg` as a normal message. The message is buffered (won't display) until linefeed ("\n"). """ return self.request('nvim_out_write', msg, **kwargs) - def err_write(self, msg, **kwargs): + def err_write(self, msg: str, **kwargs: Any) -> None: r"""Print `msg` as an error message. The message is buffered (won't display) until linefeed ("\n"). @@ -403,11 +437,11 @@ def err_write(self, msg, **kwargs): return return self.request('nvim_err_write', msg, **kwargs) - def _thread_invalid(self): + def _thread_invalid(self) -> bool: return (self._session._loop_thread is not None and threading.current_thread() != self._session._loop_thread) - def quit(self, quit_command='qa!'): + def quit(self, quit_command: str = 'qa!') -> None: """Send a quit command to Nvim. By default, the quit command is 'qa!' which will make Nvim quit without @@ -421,11 +455,11 @@ def quit(self, quit_command='qa!'): # ignore it. pass - def new_highlight_source(self): + def new_highlight_source(self) -> int: """Return new src_id for use with Buffer.add_highlight.""" return self.current.buffer.add_highlight("", 0, src_id=0) - def async_call(self, fn, *args, **kwargs): + def async_call(self, fn: Callable[..., Any], *args: Any, **kwargs: Any) -> None: """Schedule `fn` to be called by the event loop soon. This function is thread-safe, and is the only way code not @@ -437,7 +471,7 @@ def async_call(self, fn, *args, **kwargs): """ call_point = ''.join(format_stack(None, 5)[:-1]) - def handler(): + def handler() -> None: try: fn(*args, **kwargs) except Exception as err: @@ -460,26 +494,26 @@ class Buffers(object): Conforms to *python-buffers*. """ - def __init__(self, nvim): + def __init__(self, nvim: Nvim): """Initialize a Buffers object with Nvim object `nvim`.""" self._fetch_buffers = nvim.api.list_bufs - def __len__(self): + def __len__(self) -> int: """Return the count of buffers.""" return len(self._fetch_buffers()) - def __getitem__(self, number): + def __getitem__(self, number: int) -> Buffer: """Return the Buffer object matching buffer number `number`.""" for b in self._fetch_buffers(): if b.number == number: return b raise KeyError(number) - def __contains__(self, b): + def __contains__(self, b: Buffer) -> bool: """Return whether Buffer `b` is a known valid buffer.""" return isinstance(b, Buffer) and b.valid - def __iter__(self): + def __iter__(self) -> Iterator[Buffer]: """Return an iterator over the list of buffers.""" return iter(self._fetch_buffers()) @@ -488,7 +522,7 @@ class CompatibilitySession(object): """Helper class for API compatibility.""" - def __init__(self, nvim): + def __init__(self, nvim: Nvim): self.threadsafe_call = nvim.async_call @@ -496,44 +530,44 @@ class Current(object): """Helper class for emulating vim.current from python-vim.""" - def __init__(self, session): + def __init__(self, session: Nvim): self._session = session self.range = None @property - def line(self): + def line(self) -> str: return self._session.request('nvim_get_current_line') @line.setter - def line(self, line): + def line(self, line: str) -> None: return self._session.request('nvim_set_current_line', line) @line.deleter - def line(self): + def line(self) -> None: return self._session.request('nvim_del_current_line') @property - def buffer(self): + def buffer(self) -> Buffer: return self._session.request('nvim_get_current_buf') @buffer.setter - def buffer(self, buffer): + def buffer(self, buffer: Union[Buffer, int]) -> None: return self._session.request('nvim_set_current_buf', buffer) @property - def window(self): + def window(self) -> Window: return self._session.request('nvim_get_current_win') @window.setter - def window(self, window): + def window(self, window: Union[Window, int]) -> None: return self._session.request('nvim_set_current_win', window) @property - def tabpage(self): + def tabpage(self) -> Tabpage: return self._session.request('nvim_get_current_tabpage') @tabpage.setter - def tabpage(self, tabpage): + def tabpage(self, tabpage: Union[Tabpage, int]) -> None: return self._session.request('nvim_set_current_tabpage', tabpage) @@ -541,10 +575,10 @@ class Funcs(object): """Helper class for functional vimscript interface.""" - def __init__(self, nvim): + def __init__(self, nvim: Nvim): self._nvim = nvim - def __getattr__(self, name): + def __getattr__(self, name: str) -> Callable[..., Any]: return partial(self._nvim.call, name) @@ -552,16 +586,16 @@ class LuaFuncs(object): """Wrapper to allow lua functions to be called like python methods.""" - def __init__(self, nvim, name=""): + def __init__(self, nvim: Nvim, name: str = ""): self._nvim = nvim self.name = name - def __getattr__(self, name): + def __getattr__(self, name: str) -> 'LuaFuncs': """Return wrapper to named api method.""" prefix = self.name + "." if self.name else "" return LuaFuncs(self._nvim, prefix + name) - def __call__(self, *args, **kwargs): + def __call__(self, *args: Any, **kwargs: Any) -> Any: # first new function after keyword rename, be a bit noisy if 'async' in kwargs: raise ValueError('"async" argument is not allowed. ' diff --git a/pynvim/api/tabpage.py b/pynvim/api/tabpage.py index f8a6bba5..244be9d9 100644 --- a/pynvim/api/tabpage.py +++ b/pynvim/api/tabpage.py @@ -1,8 +1,13 @@ """API for working with Nvim tabpages.""" +from typing import Any, TYPE_CHECKING, Tuple + from pynvim.api.common import Remote, RemoteSequence +from pynvim.api.window import Window +if TYPE_CHECKING: + from pynvim.api.nvim import Nvim -__all__ = ('Tabpage') +__all__ = ['Tabpage'] class Tabpage(Remote): @@ -10,26 +15,28 @@ class Tabpage(Remote): _api_prefix = "nvim_tabpage_" - def __init__(self, *args): + def __init__(self, session: 'Nvim', code_data: Tuple[int, Any]): """Initialize from session and code_data immutable object. The `code_data` contains serialization information required for msgpack-rpc calls. It must be immutable for Buffer equality to work. """ - super(Tabpage, self).__init__(*args) - self.windows = RemoteSequence(self, 'nvim_tabpage_list_wins') + super(Tabpage, self).__init__(session, code_data) + self.windows: RemoteSequence[Window] = RemoteSequence( + self, "nvim_tabpage_list_wins" + ) @property - def window(self): + def window(self) -> Window: """Get the `Window` currently focused on the tabpage.""" return self.request('nvim_tabpage_get_win') @property - def valid(self): + def valid(self) -> bool: """Return True if the tabpage still exists.""" return self.request('nvim_tabpage_is_valid') @property - def number(self): + def number(self) -> int: """Get the tabpage number.""" return self.request('nvim_tabpage_get_number') diff --git a/pynvim/api/window.py b/pynvim/api/window.py index 14810609..d0b82903 100644 --- a/pynvim/api/window.py +++ b/pynvim/api/window.py @@ -1,8 +1,14 @@ """API for working with Nvim windows.""" +from typing import TYPE_CHECKING, Tuple, cast + +from pynvim.api.buffer import Buffer from pynvim.api.common import Remote +if TYPE_CHECKING: + from pynvim.api.tabpage import Tabpage + -__all__ = ('Window') +__all__ = ['Window'] class Window(Remote): @@ -12,61 +18,61 @@ class Window(Remote): _api_prefix = "nvim_win_" @property - def buffer(self): + def buffer(self) -> Buffer: """Get the `Buffer` currently being displayed by the window.""" return self.request('nvim_win_get_buf') @property - def cursor(self): + def cursor(self) -> Tuple[int, int]: """Get the (row, col) tuple with the current cursor position.""" - return self.request('nvim_win_get_cursor') + return cast(Tuple[int, int], tuple(self.request('nvim_win_get_cursor'))) @cursor.setter - def cursor(self, pos): + def cursor(self, pos: Tuple[int, int]) -> None: """Set the (row, col) tuple as the new cursor position.""" return self.request('nvim_win_set_cursor', pos) @property - def height(self): + def height(self) -> int: """Get the window height in rows.""" return self.request('nvim_win_get_height') @height.setter - def height(self, height): + def height(self, height: int) -> None: """Set the window height in rows.""" return self.request('nvim_win_set_height', height) @property - def width(self): + def width(self) -> int: """Get the window width in rows.""" return self.request('nvim_win_get_width') @width.setter - def width(self, width): + def width(self, width: int) -> None: """Set the window height in rows.""" return self.request('nvim_win_set_width', width) @property - def row(self): + def row(self) -> int: """0-indexed, on-screen window position(row) in display cells.""" return self.request('nvim_win_get_position')[0] @property - def col(self): + def col(self) -> int: """0-indexed, on-screen window position(col) in display cells.""" return self.request('nvim_win_get_position')[1] @property - def tabpage(self): + def tabpage(self) -> 'Tabpage': """Get the `Tabpage` that contains the window.""" return self.request('nvim_win_get_tabpage') @property - def valid(self): + def valid(self) -> bool: """Return True if the window still exists.""" return self.request('nvim_win_is_valid') @property - def number(self): + def number(self) -> int: """Get the window number.""" return self.request('nvim_win_get_number') diff --git a/pynvim/compat.py b/pynvim/compat.py index c6028fb8..17ce35f8 100644 --- a/pynvim/compat.py +++ b/pynvim/compat.py @@ -1,10 +1,10 @@ """Code for compatibility across Python versions.""" - import warnings from imp import find_module as original_find_module +from typing import Any, Dict, Optional -def find_module(fullname, path): +def find_module(fullname, path): # type: ignore """Compatibility wrapper for imp.find_module. Automatically decodes arguments of find_module, in Python3 @@ -30,7 +30,7 @@ def find_module(fullname, path): NUM_TYPES = (int, float) -def check_async(async_, kwargs, default): +def check_async(async_: Optional[bool], kwargs: Dict[str, Any], default: bool) -> bool: """Return a value of 'async' in kwargs or default when async_ is None. This helper function exists for backward compatibility (See #274). diff --git a/pynvim/msgpack_rpc/__init__.py b/pynvim/msgpack_rpc/__init__.py index 8da22e04..e4efc706 100644 --- a/pynvim/msgpack_rpc/__init__.py +++ b/pynvim/msgpack_rpc/__init__.py @@ -4,8 +4,10 @@ handling some Nvim particularities(server->client requests for example), the code here should work with other msgpack-rpc servers. """ +from typing import Any, List + from pynvim.msgpack_rpc.async_session import AsyncSession -from pynvim.msgpack_rpc.event_loop import EventLoop +from pynvim.msgpack_rpc.event_loop import EventLoop, TTransportType from pynvim.msgpack_rpc.msgpack_stream import MsgpackStream from pynvim.msgpack_rpc.session import ErrorResponse, Session from pynvim.util import get_client_info @@ -15,7 +17,9 @@ 'ErrorResponse') -def session(transport_type='stdio', *args, **kwargs): +def session( + transport_type: TTransportType = 'stdio', *args: Any, **kwargs: Any +) -> Session: loop = EventLoop(transport_type, *args, **kwargs) msgpack_stream = MsgpackStream(loop) async_session = AsyncSession(msgpack_stream) @@ -25,21 +29,21 @@ def session(transport_type='stdio', *args, **kwargs): return session -def tcp_session(address, port=7450): +def tcp_session(address: str, port: int = 7450) -> Session: """Create a msgpack-rpc session from a tcp address/port.""" return session('tcp', address, port) -def socket_session(path): +def socket_session(path: str) -> Session: """Create a msgpack-rpc session from a unix domain socket.""" return session('socket', path) -def stdio_session(): +def stdio_session() -> Session: """Create a msgpack-rpc session from stdin/stdout.""" return session('stdio') -def child_session(argv): +def child_session(argv: List[str]) -> Session: """Create a msgpack-rpc session from a new Nvim instance.""" return session('child', argv) diff --git a/pynvim/msgpack_rpc/event_loop/__init__.py b/pynvim/msgpack_rpc/event_loop/__init__.py index 84cf0812..e94cdbfe 100644 --- a/pynvim/msgpack_rpc/event_loop/__init__.py +++ b/pynvim/msgpack_rpc/event_loop/__init__.py @@ -4,6 +4,7 @@ """ from pynvim.msgpack_rpc.event_loop.asyncio import AsyncioEventLoop as EventLoop +from pynvim.msgpack_rpc.event_loop.base import TTransportType -__all__ = ['EventLoop'] +__all__ = ['EventLoop', 'TTransportType'] diff --git a/pynvim/msgpack_rpc/event_loop/asyncio.py b/pynvim/msgpack_rpc/event_loop/asyncio.py index de3cb0ca..b7843cdf 100644 --- a/pynvim/msgpack_rpc/event_loop/asyncio.py +++ b/pynvim/msgpack_rpc/event_loop/asyncio.py @@ -13,6 +13,8 @@ import os import sys from collections import deque +from signal import Signals +from typing import Any, Callable, Deque, List from pynvim.msgpack_rpc.event_loop.base import BaseEventLoop @@ -27,14 +29,15 @@ # On windows use ProactorEventLoop which support pipes and is backed by the # more powerful IOCP facility # NOTE: we override in the stdio case, because it doesn't work. - loop_cls = asyncio.ProactorEventLoop + loop_cls = asyncio.ProactorEventLoop # type: ignore[attr-defined,misc] class AsyncioEventLoop(BaseEventLoop, asyncio.Protocol, asyncio.SubprocessProtocol): - """`BaseEventLoop` subclass that uses `asyncio` as a backend.""" + _queued_data: Deque[bytes] + def connection_made(self, transport): """Used to signal `asyncio.Protocol` of a successful connection.""" self._transport = transport @@ -46,7 +49,7 @@ def connection_lost(self, exc): """Used to signal `asyncio.Protocol` of a lost connection.""" self._on_error(exc.args[0] if exc else 'EOF') - def data_received(self, data): + def data_received(self, data: bytes) -> None: """Used to signal `asyncio.Protocol` of incoming data.""" if self._on_data: self._on_data(data) @@ -66,30 +69,34 @@ def pipe_data_received(self, fd, data): else: self._queued_data.append(data) - def process_exited(self): + def process_exited(self) -> None: """Used to signal `asyncio.SubprocessProtocol` when the child exits.""" self._on_error('EOF') - def _init(self): + def _init(self) -> None: self._loop = loop_cls() self._queued_data = deque() self._fact = lambda: self self._raw_transport = None - def _connect_tcp(self, address, port): + def _connect_tcp(self, address: str, port: int) -> None: coroutine = self._loop.create_connection(self._fact, address, port) self._loop.run_until_complete(coroutine) - def _connect_socket(self, path): + def _connect_socket(self, path: str) -> None: if os.name == 'nt': - coroutine = self._loop.create_pipe_connection(self._fact, path) + coroutine = self._loop.create_pipe_connection( # type: ignore[attr-defined] + self._fact, path + ) else: coroutine = self._loop.create_unix_connection(self._fact, path) self._loop.run_until_complete(coroutine) - def _connect_stdio(self): + def _connect_stdio(self) -> None: if os.name == 'nt': - pipe = PipeHandle(msvcrt.get_osfhandle(sys.stdin.fileno())) + pipe: Any = PipeHandle( + msvcrt.get_osfhandle(sys.stdin.fileno()) # type: ignore[attr-defined] + ) else: pipe = sys.stdin coroutine = self._loop.connect_read_pipe(self._fact, pipe) @@ -102,43 +109,47 @@ def _connect_stdio(self): os.dup2(sys.stderr.fileno(), sys.stdout.fileno()) if os.name == 'nt': - pipe = PipeHandle(msvcrt.get_osfhandle(rename_stdout)) + pipe = PipeHandle( + msvcrt.get_osfhandle(rename_stdout) # type: ignore[attr-defined] + ) else: pipe = os.fdopen(rename_stdout, 'wb') coroutine = self._loop.connect_write_pipe(self._fact, pipe) self._loop.run_until_complete(coroutine) debug("native stdout connection successful") - def _connect_child(self, argv): + def _connect_child(self, argv: List[str]) -> None: if os.name != 'nt': self._child_watcher = asyncio.get_child_watcher() self._child_watcher.attach_loop(self._loop) coroutine = self._loop.subprocess_exec(self._fact, *argv) self._loop.run_until_complete(coroutine) - def _start_reading(self): + def _start_reading(self) -> None: pass - def _send(self, data): + def _send(self, data: bytes) -> None: self._transport.write(data) - def _run(self): + def _run(self) -> None: while self._queued_data: - self._on_data(self._queued_data.popleft()) + data = self._queued_data.popleft() + if self._on_data is not None: + self._on_data(data) self._loop.run_forever() - def _stop(self): + def _stop(self) -> None: self._loop.stop() - def _close(self): + def _close(self) -> None: if self._raw_transport is not None: self._raw_transport.close() self._loop.close() - def _threadsafe_call(self, fn): + def _threadsafe_call(self, fn: Callable[[], Any]) -> None: self._loop.call_soon_threadsafe(fn) - def _setup_signals(self, signals): + def _setup_signals(self, signals: List[Signals]) -> None: if os.name == 'nt': # add_signal_handler is not supported in win32 self._signals = [] @@ -148,6 +159,6 @@ def _setup_signals(self, signals): for signum in self._signals: self._loop.add_signal_handler(signum, self._on_signal, signum) - def _teardown_signals(self): + def _teardown_signals(self) -> None: for signum in self._signals: self._loop.remove_signal_handler(signum) diff --git a/pynvim/msgpack_rpc/event_loop/base.py b/pynvim/msgpack_rpc/event_loop/base.py index e24f3f8d..86fde9c2 100644 --- a/pynvim/msgpack_rpc/event_loop/base.py +++ b/pynvim/msgpack_rpc/event_loop/base.py @@ -1,8 +1,15 @@ """Common code for event loop implementations.""" import logging import signal +import sys import threading +from abc import ABC, abstractmethod +from typing import Any, Callable, List, Optional, Type, Union +if sys.version_info < (3, 8): + from typing_extensions import Literal +else: + from typing import Literal logger = logging.getLogger(__name__) debug, info, warn = (logger.debug, logger.info, logger.warning,) @@ -14,8 +21,15 @@ default_int_handler = signal.getsignal(signal.SIGINT) main_thread = threading.current_thread() +TTransportType = Union[ + Literal['stdio'], + Literal['socket'], + Literal['tcp'], + Literal['child'] +] -class BaseEventLoop(object): + +class BaseEventLoop(ABC): """Abstract base class for all event loops. @@ -52,7 +66,7 @@ class BaseEventLoop(object): - `_teardown_signals()`: Removes signal listeners set by `_setup_signals` """ - def __init__(self, transport_type, *args): + def __init__(self, transport_type: TTransportType, *args: Any, **kwargs: Any): """Initialize and connect the event loop instance. The only arguments are the transport type and transport-specific @@ -83,37 +97,65 @@ def __init__(self, transport_type, *args): self._transport_type = transport_type self._signames = dict((k, v) for v, k in signal.__dict__.items() if v.startswith('SIG')) - self._on_data = None - self._error = None + self._on_data: Optional[Callable[[bytes], None]] = None + self._error: Optional[BaseException] = None self._init() try: - getattr(self, '_connect_{}'.format(transport_type))(*args) + getattr(self, '_connect_{}'.format(transport_type))(*args, **kwargs) except Exception as e: self.close() raise e self._start_reading() - def connect_tcp(self, address, port): + @abstractmethod + def _init(self) -> None: + raise NotImplementedError() + + @abstractmethod + def _start_reading(self) -> None: + raise NotImplementedError() + + @abstractmethod + def _send(self, data: bytes) -> None: + raise NotImplementedError() + + def connect_tcp(self, address: str, port: int) -> None: """Connect to tcp/ip `address`:`port`. Delegated to `_connect_tcp`.""" info('Connecting to TCP address: %s:%d', address, port) self._connect_tcp(address, port) - def connect_socket(self, path): + @abstractmethod + def _connect_tcp(self, address: str, port: int) -> None: + raise NotImplementedError() + + def connect_socket(self, path: str) -> None: """Connect to socket at `path`. Delegated to `_connect_socket`.""" info('Connecting to %s', path) self._connect_socket(path) - def connect_stdio(self): + @abstractmethod + def _connect_socket(self, path: str) -> None: + raise NotImplementedError() + + def connect_stdio(self) -> None: """Connect using stdin/stdout. Delegated to `_connect_stdio`.""" info('Preparing stdin/stdout for streaming data') self._connect_stdio() + @abstractmethod + def _connect_stdio(self) -> None: + raise NotImplementedError() + def connect_child(self, argv): """Connect a new Nvim instance. Delegated to `_connect_child`.""" info('Spawning a new nvim instance') self._connect_child(argv) - def send(self, data): + @abstractmethod + def _connect_child(self, argv: List[str]) -> None: + raise NotImplementedError() + + def send(self, data: bytes) -> None: """Queue `data` for sending to Nvim.""" debug("Sending '%s'", data) self._send(data) @@ -148,17 +190,25 @@ def run(self, data_cb): signal.signal(signal.SIGINT, default_int_handler) self._on_data = None - def stop(self): + def stop(self) -> None: """Stop the event loop.""" self._stop() debug('Stopped event loop') - def close(self): + @abstractmethod + def _stop(self) -> None: + raise NotImplementedError() + + def close(self) -> None: """Stop the event loop.""" self._close() debug('Closed event loop') - def _on_signal(self, signum): + @abstractmethod + def _close(self) -> None: + raise NotImplementedError() + + def _on_signal(self, signum: signal.Signals) -> None: msg = 'Received {}'.format(self._signames[signum]) debug(msg) if signum == signal.SIGINT and self._transport_type == 'stdio': @@ -166,16 +216,16 @@ def _on_signal(self, signum): # child process. In that case, we don't want to be killed by # ctrl+C return - cls = Exception + cls: Type[BaseException] = Exception if signum == signal.SIGINT: cls = KeyboardInterrupt self._error = cls(msg) self.stop() - def _on_error(self, error): + def _on_error(self, error: str) -> None: debug(error) self._error = OSError(error) self.stop() - def _on_interrupt(self): + def _on_interrupt(self) -> None: self.stop() diff --git a/pynvim/msgpack_rpc/session.py b/pynvim/msgpack_rpc/session.py index dfa22614..e3a8a77e 100644 --- a/pynvim/msgpack_rpc/session.py +++ b/pynvim/msgpack_rpc/session.py @@ -1,18 +1,47 @@ """Synchronous msgpack-rpc session layer.""" import logging +import sys import threading from collections import deque from traceback import format_exc +from typing import (Any, AnyStr, Callable, Deque, List, NamedTuple, Optional, Sequence, + Tuple, Union, cast) import greenlet from pynvim.compat import check_async +from pynvim.msgpack_rpc.async_session import AsyncSession + +if sys.version_info < (3, 8): + from typing_extensions import Literal +else: + from typing import Literal logger = logging.getLogger(__name__) error, debug, info, warn = (logger.error, logger.debug, logger.info, logger.warning,) +class Request(NamedTuple): + """A request from Nvim.""" + + type: Literal['request'] + name: str + args: List[Any] + response: Any + + +class Notification(NamedTuple): + """A notification from Nvim.""" + + type: Literal['notification'] + name: str + args: List[Any] + + +Message = Union[Request, Notification] + + class Session(object): """Msgpack-rpc session layer that uses coroutines for a synchronous API. @@ -22,17 +51,22 @@ class Session(object): from Nvim with a synchronous API. """ - def __init__(self, async_session): + def __init__(self, async_session: AsyncSession): """Wrap `async_session` on a synchronous msgpack-rpc interface.""" self._async_session = async_session - self._request_cb = self._notification_cb = None - self._pending_messages = deque() + self._request_cb: Optional[Callable[[str, List[Any]], None]] = None + self._notification_cb: Optional[Callable[[str, List[Any]], None]] = None + self._pending_messages: Deque[Message] = deque() self._is_running = False - self._setup_exception = None + self._setup_exception: Optional[Exception] = None self.loop = async_session.loop - self._loop_thread = None + self._loop_thread: Optional[threading.Thread] = None + self.error_wrapper: Callable[[Tuple[int, str]], Exception] = \ + lambda e: Exception(e[1]) - def threadsafe_call(self, fn, *args, **kwargs): + def threadsafe_call( + self, fn: Callable[..., Any], *args: Any, **kwargs: Any + ) -> None: """Wrapper around `AsyncSession.threadsafe_call`.""" def handler(): try: @@ -47,7 +81,7 @@ def greenlet_wrapper(): self._async_session.threadsafe_call(greenlet_wrapper) - def next_message(self): + def next_message(self) -> Optional[Message]: """Block until a message(request or notification) is available. If any messages were previously enqueued, return the first in queue. @@ -61,8 +95,9 @@ def next_message(self): self._enqueue_notification_and_stop) if self._pending_messages: return self._pending_messages.popleft() + return None - def request(self, method, *args, **kwargs): + def request(self, method: AnyStr, *args: Any, **kwargs: Any) -> Any: """Send a msgpack-rpc request and block until as response is received. If the event loop is running, this method must have been called by a @@ -102,7 +137,10 @@ def request(self, method, *args, **kwargs): raise self.error_wrapper(err) return rv - def run(self, request_cb, notification_cb, setup_cb=None): + def run(self, + request_cb: Callable[[str, List[Any]], None], + notification_cb: Callable[[str, List[Any]], None], + setup_cb: Callable[[], None] = None) -> None: """Run the event loop to receive requests and notifications from Nvim. Like `AsyncSession.run()`, but `request_cb` and `notification_cb` are @@ -114,9 +152,9 @@ def run(self, request_cb, notification_cb, setup_cb=None): self._setup_exception = None self._loop_thread = threading.current_thread() - def on_setup(): + def on_setup() -> None: try: - setup_cb() + setup_cb() # type: ignore[misc] except Exception as e: self._setup_exception = e self.stop() @@ -127,7 +165,9 @@ def on_setup(): gr.switch() if self._setup_exception: - error('Setup error: {}'.format(self._setup_exception)) + error( # type: ignore[unreachable] + 'Setup error: {}'.format(self._setup_exception) + ) raise self._setup_exception # Process all pending requests and notifications @@ -143,15 +183,17 @@ def on_setup(): if self._setup_exception: raise self._setup_exception - def stop(self): + def stop(self) -> None: """Stop the event loop.""" self._async_session.stop() - def close(self): + def close(self) -> None: """Close the event loop.""" self._async_session.close() - def _yielding_request(self, method, args): + def _yielding_request( + self, method: AnyStr, args: Sequence[Any] + ) -> Tuple[Tuple[int, str], Any]: gr = greenlet.getcurrent() parent = gr.parent @@ -163,7 +205,9 @@ def response_cb(err, rv): debug('yielding from greenlet %s to wait for response', gr) return parent.switch() - def _blocking_request(self, method, args): + def _blocking_request( + self, method: AnyStr, args: Sequence[Any] + ) -> Tuple[Tuple[int, str], Any]: result = [] def response_cb(err, rv): @@ -173,21 +217,23 @@ def response_cb(err, rv): self._async_session.request(method, args, response_cb) self._async_session.run(self._enqueue_request, self._enqueue_notification) - return result + return cast(Tuple[Tuple[int, str], Any], tuple(result)) - def _enqueue_request_and_stop(self, name, args, response): + def _enqueue_request_and_stop( + self, name: str, args: List[Any], response: Any + ) -> None: self._enqueue_request(name, args, response) self.stop() - def _enqueue_notification_and_stop(self, name, args): + def _enqueue_notification_and_stop(self, name: str, args: List[Any]) -> None: self._enqueue_notification(name, args) self.stop() - def _enqueue_request(self, name, args, response): - self._pending_messages.append(('request', name, args, response,)) + def _enqueue_request(self, name: str, args: List[Any], response: Any) -> None: + self._pending_messages.append(Request('request', name, args, response,)) - def _enqueue_notification(self, name, args): - self._pending_messages.append(('notification', name, args,)) + def _enqueue_notification(self, name: str, args: List[Any]) -> None: + self._pending_messages.append(Notification('notification', name, args,)) def _on_request(self, name, args, response): def handler(): diff --git a/pynvim/plugin/__init__.py b/pynvim/plugin/__init__.py index 9365438b..cb4ba41e 100644 --- a/pynvim/plugin/__init__.py +++ b/pynvim/plugin/__init__.py @@ -2,7 +2,7 @@ from pynvim.plugin.decorators import (autocmd, command, decode, encoding, function, plugin, rpc_export, shutdown_hook) -from pynvim.plugin.host import Host +from pynvim.plugin.host import Host # type: ignore[attr-defined] __all__ = ('Host', 'plugin', 'rpc_export', 'command', 'autocmd', diff --git a/pynvim/plugin/decorators.py b/pynvim/plugin/decorators.py index d0c45bc5..f4200293 100644 --- a/pynvim/plugin/decorators.py +++ b/pynvim/plugin/decorators.py @@ -2,22 +2,32 @@ import inspect import logging +import sys +from typing import Any, Callable, Dict, TypeVar, Union from pynvim.compat import unicode_errors_default +if sys.version_info < (3, 8): + from typing_extensions import Literal +else: + from typing import Literal + logger = logging.getLogger(__name__) debug, info, warn = (logger.debug, logger.info, logger.warning,) __all__ = ('plugin', 'rpc_export', 'command', 'autocmd', 'function', 'encoding', 'decode', 'shutdown_hook') +T = TypeVar('T') +F = TypeVar('F', bound=Callable[..., Any]) + -def plugin(cls): +def plugin(cls: T) -> T: """Tag a class as a plugin. This decorator is required to make the class methods discoverable by the plugin_load method of the host. """ - cls._nvim_plugin = True + cls._nvim_plugin = True # type: ignore[attr-defined] # the _nvim_bind attribute is set to True by default, meaning that # decorated functions have a bound Nvim instance as first argument. # For methods in a plugin-decorated class this is not required, because @@ -28,27 +38,39 @@ def plugin(cls): return cls -def rpc_export(rpc_method_name, sync=False): +def rpc_export(rpc_method_name: str, sync: bool = False) -> Callable[[F], F]: """Export a function or plugin method as a msgpack-rpc request handler.""" - def dec(f): - f._nvim_rpc_method_name = rpc_method_name - f._nvim_rpc_sync = sync - f._nvim_bind = True - f._nvim_prefix_plugin_path = False + def dec(f: F) -> F: + f._nvim_rpc_method_name = rpc_method_name # type: ignore[attr-defined] + f._nvim_rpc_sync = sync # type: ignore[attr-defined] + f._nvim_bind = True # type: ignore[attr-defined] + f._nvim_prefix_plugin_path = False # type: ignore[attr-defined] return f return dec -def command(name, nargs=0, complete=None, range=None, count=None, bang=False, - register=False, sync=False, allow_nested=False, eval=None): +def command( + name: str, + nargs: Union[str, int] = 0, + complete: str = None, + range: Union[str, int] = None, + count: int = None, + bang: bool = False, + register: bool = False, + sync: bool = False, + allow_nested: bool = False, + eval: str = None +) -> Callable[[F], F]: """Tag a function or plugin method as a Nvim command handler.""" - def dec(f): - f._nvim_rpc_method_name = 'command:{}'.format(name) - f._nvim_rpc_sync = sync - f._nvim_bind = True - f._nvim_prefix_plugin_path = True + def dec(f: F) -> F: + f._nvim_rpc_method_name = ( # type: ignore[attr-defined] + 'command:{}'.format(name) + ) + f._nvim_rpc_sync = sync # type: ignore[attr-defined] + f._nvim_bind = True # type: ignore[attr-defined] + f._nvim_prefix_plugin_path = True # type: ignore[attr-defined] - opts = {} + opts: Dict[str, Any] = {} if range is not None: opts['range'] = '' if range is True else str(range) @@ -71,11 +93,11 @@ def dec(f): opts['eval'] = eval if not sync and allow_nested: - rpc_sync = "urgent" + rpc_sync: Union[bool, Literal['urgent']] = "urgent" else: rpc_sync = sync - f._nvim_rpc_spec = { + f._nvim_rpc_spec = { # type: ignore[attr-defined] 'type': 'command', 'name': name, 'sync': rpc_sync, @@ -85,13 +107,21 @@ def dec(f): return dec -def autocmd(name, pattern='*', sync=False, allow_nested=False, eval=None): +def autocmd( + name: str, + pattern: str = '*', + sync: bool = False, + allow_nested: bool = False, + eval: str = None +) -> Callable[[F], F]: """Tag a function or plugin method as a Nvim autocommand handler.""" - def dec(f): - f._nvim_rpc_method_name = 'autocmd:{}:{}'.format(name, pattern) - f._nvim_rpc_sync = sync - f._nvim_bind = True - f._nvim_prefix_plugin_path = True + def dec(f: F) -> F: + f._nvim_rpc_method_name = ( # type: ignore[attr-defined] + 'autocmd:{}:{}'.format(name, pattern) + ) + f._nvim_rpc_sync = sync # type: ignore[attr-defined] + f._nvim_bind = True # type: ignore[attr-defined] + f._nvim_prefix_plugin_path = True # type: ignore[attr-defined] opts = { 'pattern': pattern @@ -101,11 +131,11 @@ def dec(f): opts['eval'] = eval if not sync and allow_nested: - rpc_sync = "urgent" + rpc_sync: Union[bool, Literal['urgent']] = "urgent" else: rpc_sync = sync - f._nvim_rpc_spec = { + f._nvim_rpc_spec = { # type: ignore[attr-defined] 'type': 'autocmd', 'name': name, 'sync': rpc_sync, @@ -115,13 +145,21 @@ def dec(f): return dec -def function(name, range=False, sync=False, allow_nested=False, eval=None): +def function( + name: str, + range: Union[bool, str, int] = False, + sync: bool = False, + allow_nested: bool = False, + eval: str = None +) -> Callable[[F], F]: """Tag a function or plugin method as a Nvim function handler.""" - def dec(f): - f._nvim_rpc_method_name = 'function:{}'.format(name) - f._nvim_rpc_sync = sync - f._nvim_bind = True - f._nvim_prefix_plugin_path = True + def dec(f: F) -> F: + f._nvim_rpc_method_name = ( # type: ignore[attr-defined] + 'function:{}'.format(name) + ) + f._nvim_rpc_sync = sync # type: ignore[attr-defined] + f._nvim_bind = True # type: ignore[attr-defined] + f._nvim_prefix_plugin_path = True # type: ignore[attr-defined] opts = {} @@ -132,11 +170,11 @@ def dec(f): opts['eval'] = eval if not sync and allow_nested: - rpc_sync = "urgent" + rpc_sync: Union[bool, Literal['urgent']] = "urgent" else: rpc_sync = sync - f._nvim_rpc_spec = { + f._nvim_rpc_spec = { # type: ignore[attr-defined] 'type': 'function', 'name': name, 'sync': rpc_sync, @@ -146,27 +184,27 @@ def dec(f): return dec -def shutdown_hook(f): +def shutdown_hook(f: F) -> F: """Tag a function or method as a shutdown hook.""" - f._nvim_shutdown_hook = True - f._nvim_bind = True + f._nvim_shutdown_hook = True # type: ignore[attr-defined] + f._nvim_bind = True # type: ignore[attr-defined] return f -def decode(mode=unicode_errors_default): +def decode(mode: str = unicode_errors_default) -> Callable[[F], F]: """Configure automatic encoding/decoding of strings.""" - def dec(f): - f._nvim_decode = mode + def dec(f: F) -> F: + f._nvim_decode = mode # type: ignore[attr-defined] return f return dec -def encoding(encoding=True): +def encoding(encoding: Union[bool, str] = True) -> Callable[[F], F]: """DEPRECATED: use pynvim.decode().""" if isinstance(encoding, str): encoding = True - def dec(f): - f._nvim_decode = encoding + def dec(f: F) -> F: + f._nvim_decode = encoding # type: ignore[attr-defined] return f return dec diff --git a/pynvim/plugin/host.py b/pynvim/plugin/host.py index a80224be..a4220e1c 100644 --- a/pynvim/plugin/host.py +++ b/pynvim/plugin/host.py @@ -1,3 +1,4 @@ +# type: ignore[no-untyped-def] """Implements a Nvim host for python plugins.""" import imp import inspect @@ -7,14 +8,15 @@ import re from functools import partial from traceback import format_exc +from typing import Any, Sequence -from pynvim.api import decode_if_bytes, walk +from pynvim.api import Nvim, decode_if_bytes, walk from pynvim.compat import find_module from pynvim.msgpack_rpc import ErrorResponse from pynvim.plugin import script_host from pynvim.util import format_exc_skip, get_client_info -__all__ = ('Host') +__all__ = ('Host',) logger = logging.getLogger(__name__) error, debug, info, warn = (logger.error, logger.debug, logger.info, @@ -31,7 +33,7 @@ class Host(object): requests/notifications to the appropriate handlers. """ - def __init__(self, nvim): + def __init__(self, nvim: Nvim): """Set handlers for plugin_load/plugin_unload.""" self.nvim = nvim self._specs = {} @@ -48,11 +50,11 @@ def __init__(self, nvim): self._decode_default = True - def _on_async_err(self, msg): + def _on_async_err(self, msg: str) -> None: # uncaught python exception self.nvim.err_write(msg, async_=True) - def _on_error_event(self, kind, msg): + def _on_error_event(self, kind: Any, msg: str) -> None: # error from nvim due to async request # like nvim.command(..., async_=True) errmsg = "{}: Async request caused an error:\n{}\n".format( @@ -67,7 +69,7 @@ def start(self, plugins): lambda: self._load(plugins), err_cb=self._on_async_err) - def shutdown(self): + def shutdown(self) -> None: """Shutdown the host.""" self._unload() self.nvim.stop_loop() @@ -109,7 +111,7 @@ def _wrap_function(self, fn, sync, decode, nvim_bind, name, *args): .format(name, args, format_exc_skip(1))) self._on_async_err(msg + "\n") - def _on_request(self, name, args): + def _on_request(self, name: str, args: Sequence[Any]) -> None: """Handle a msgpack-rpc request.""" name = decode_if_bytes(name) handler = self._request_handlers.get(name, None) @@ -123,7 +125,7 @@ def _on_request(self, name, args): debug("request handler for '%s %s' returns: %s", name, args, rv) return rv - def _on_notification(self, name, args): + def _on_notification(self, name: str, args: Sequence[Any]) -> None: """Handle a msgpack-rpc notification.""" name = decode_if_bytes(name) handler = self._notification_handlers.get(name, None) @@ -145,7 +147,7 @@ def _missing_handler_error(self, name, kind): msg = msg + "\n" + loader_error return msg - def _load(self, plugins): + def _load(self, plugins: Sequence[str]) -> None: has_script = False for path in plugins: err = None @@ -179,7 +181,7 @@ def _load(self, plugins): self.name = info[0] self.nvim.api.set_client_info(*info, async_=True) - def _unload(self): + def _unload(self) -> None: for path, plugin in self._loaded.items(): handlers = plugin['handlers'] for handler in handlers: diff --git a/pynvim/plugin/script_host.py b/pynvim/plugin/script_host.py index c0f033f3..9eb6803b 100644 --- a/pynvim/plugin/script_host.py +++ b/pynvim/plugin/script_host.py @@ -1,3 +1,4 @@ +# type: ignore """Legacy python/python3-vim emulation.""" import imp import io diff --git a/pynvim/util.py b/pynvim/util.py index da86f716..61072590 100644 --- a/pynvim/util.py +++ b/pynvim/util.py @@ -3,22 +3,29 @@ import sys from traceback import format_exception from types import SimpleNamespace +from typing import Any, Dict, Tuple, TypeVar -def format_exc_skip(skip, limit=None): +def format_exc_skip(skip: int, limit: int = None) -> str: """Like traceback.format_exc but allow skipping the first frames.""" etype, val, tb = sys.exc_info() - for i in range(skip): - tb = tb.tb_next - return (''.join(format_exception(etype, val, tb, limit))).rstrip() + for _ in range(skip): + if tb is not None: + tb = tb.tb_next + return ("".join(format_exception(etype, val, tb, limit))).rstrip() -def get_client_info(kind, type_, method_spec): +T1 = TypeVar("T1") +T2 = TypeVar("T2") + + +def get_client_info( + kind: str, type_: T1, method_spec: T2 +) -> Tuple[str, Dict[str, Any], T1, T2, Dict[str, str]]: """Returns a tuple describing the client.""" name = "python{}-{}".format(sys.version_info[0], kind) - attributes = {"license": "Apache v2", - "website": "github.com/neovim/pynvim"} + attributes = {"license": "Apache v2", "website": "github.com/neovim/pynvim"} return (name, VERSION.__dict__, type_, method_spec, attributes) -VERSION = SimpleNamespace(major=0, minor=4, patch=3, prerelease='') +VERSION = SimpleNamespace(major=0, minor=4, patch=3, prerelease="") diff --git a/setup.cfg b/setup.cfg index bda52dc7..0d680197 100644 --- a/setup.cfg +++ b/setup.cfg @@ -14,3 +14,18 @@ known_first_party = pynvim [tool:pytest] testpaths = test timeout = 10 + +[mypy] +disallow_untyped_calls = true +disallow_untyped_defs = true +disallow_incomplete_defs = true +disallow_untyped_decorators = true +ignore_missing_imports = true +warn_redundant_casts = true +warn_unused_ignores = true +warn_unreachable = true +strict_equality = true + +[mypy-pynvim.msgpack_rpc.*] +disallow_untyped_calls = false +disallow_untyped_defs = false diff --git a/setup.py b/setup.py index 49aa0d81..f52b56cb 100644 --- a/setup.py +++ b/setup.py @@ -26,6 +26,9 @@ # pypy already includes an implementation of the greenlet module install_requires.append('greenlet') +if sys.version_info < (3, 8): + install_requires.append('typing-extensions') + setup(name='pynvim', version='0.4.3', description='Python client to neovim', diff --git a/test/__init__.py b/test/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/test/conftest.py b/test/conftest.py index 85ef33f6..b63c53c2 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -9,7 +9,7 @@ @pytest.fixture -def vim(): +def vim() -> pynvim.Nvim: child_argv = os.environ.get('NVIM_CHILD_ARGV') listen_address = os.environ.get('NVIM_LISTEN_ADDRESS') if child_argv is None and listen_address is None: @@ -18,7 +18,7 @@ def vim(): if child_argv is not None: editor = pynvim.attach('child', argv=json.loads(child_argv)) else: - assert listen_address is None or listen_address != '' + assert listen_address is not None and listen_address != '' editor = pynvim.attach('socket', path=listen_address) return editor diff --git a/test/test_buffer.py b/test/test_buffer.py index c26cebba..eeb4f8f5 100644 --- a/test/test_buffer.py +++ b/test/test_buffer.py @@ -2,14 +2,14 @@ import pytest -from pynvim.api import NvimError +from pynvim.api import Nvim, NvimError -def test_repr(vim): +def test_repr(vim: Nvim) -> None: assert repr(vim.current.buffer) == "" -def test_get_length(vim): +def test_get_length(vim: Nvim) -> None: assert len(vim.current.buffer) == 1 vim.current.buffer.append('line') assert len(vim.current.buffer) == 2 @@ -23,7 +23,7 @@ def test_get_length(vim): assert len(vim.current.buffer) == 1 -def test_get_set_del_line(vim): +def test_get_set_del_line(vim: Nvim) -> None: assert vim.current.buffer[0] == '' vim.current.buffer[0] = 'line1' assert vim.current.buffer[0] == 'line1' @@ -42,7 +42,7 @@ def test_get_set_del_line(vim): assert len(vim.current.buffer) == 1 -def test_get_set_del_slice(vim): +def test_get_set_del_slice(vim: Nvim) -> None: assert vim.current.buffer[:] == [''] # Replace buffer vim.current.buffer[:] = ['a', 'b', 'c'] @@ -72,7 +72,7 @@ def test_get_set_del_slice(vim): assert vim.current.buffer[:] == ['c'] -def test_vars(vim): +def test_vars(vim: Nvim) -> None: vim.current.buffer.vars['python'] = [1, 2, {'3': 1}] assert vim.current.buffer.vars['python'] == [1, 2, {'3': 1}] assert vim.eval('b:python') == [1, 2, {'3': 1}] @@ -89,7 +89,7 @@ def test_vars(vim): assert vim.current.buffer.vars.get('python', 'default') == 'default' -def test_api(vim): +def test_api(vim: Nvim) -> None: vim.current.buffer.api.set_var('myvar', 'thetext') assert vim.current.buffer.api.get_var('myvar') == 'thetext' assert vim.eval('b:myvar') == 'thetext' @@ -98,7 +98,7 @@ def test_api(vim): assert vim.current.buffer[:] == ['alpha', 'beta'] -def test_options(vim): +def test_options(vim: Nvim) -> None: assert vim.current.buffer.options['shiftwidth'] == 8 vim.current.buffer.options['shiftwidth'] = 4 assert vim.current.buffer.options['shiftwidth'] == 4 @@ -113,7 +113,7 @@ def test_options(vim): assert excinfo.value.args == ("Invalid option name: 'doesnotexist'",) -def test_number(vim): +def test_number(vim: Nvim) -> None: curnum = vim.current.buffer.number vim.command('new') assert vim.current.buffer.number == curnum + 1 @@ -121,7 +121,7 @@ def test_number(vim): assert vim.current.buffer.number == curnum + 2 -def test_name(vim): +def test_name(vim: Nvim) -> None: vim.command('new') assert vim.current.buffer.name == '' new_name = vim.eval('resolve(tempname())') @@ -132,7 +132,7 @@ def test_name(vim): os.unlink(new_name) -def test_valid(vim): +def test_valid(vim: Nvim) -> None: vim.command('new') buffer = vim.current.buffer assert buffer.valid @@ -140,7 +140,7 @@ def test_valid(vim): assert not buffer.valid -def test_append(vim): +def test_append(vim: Nvim) -> None: vim.current.buffer.append('a') assert vim.current.buffer[:] == ['', 'a'] vim.current.buffer.append('b', 0) @@ -153,14 +153,14 @@ def test_append(vim): assert vim.current.buffer[:] == ['b', '', 'c', 'd', 'a', 'c', 'd', 'bytes'] -def test_mark(vim): +def test_mark(vim: Nvim) -> None: vim.current.buffer.append(['a', 'bit of', 'text']) - vim.current.window.cursor = [3, 4] + vim.current.window.cursor = (3, 4) vim.command('mark V') - assert vim.current.buffer.mark('V') == [3, 0] + assert vim.current.buffer.mark('V') == (3, 0) -def test_invalid_utf8(vim): +def test_invalid_utf8(vim: Nvim) -> None: vim.command('normal "=printf("%c", 0xFF)\np') assert vim.eval("char2nr(getline(1))") == 0xFF assert vim.current.buffer[:] == ['\udcff'] @@ -170,7 +170,7 @@ def test_invalid_utf8(vim): assert vim.current.buffer[:] == ['\udcffx'] -def test_get_exceptions(vim): +def test_get_exceptions(vim: Nvim) -> None: with pytest.raises(KeyError) as excinfo: vim.current.buffer.options['invalid-option'] @@ -178,7 +178,7 @@ def test_get_exceptions(vim): assert excinfo.value.args == ("Invalid option name: 'invalid-option'",) -def test_set_items_for_range(vim): +def test_set_items_for_range(vim: Nvim) -> None: vim.current.buffer[:] = ['a', 'b', 'c', 'd', 'e'] r = vim.current.buffer.range(1, 3) r[1:3] = ['foo'] * 3 @@ -187,14 +187,14 @@ def test_set_items_for_range(vim): # NB: we can't easily test the effect of this. But at least run the lua # function sync, so we know it runs without runtime error with simple args. -def test_update_highlights(vim): +def test_update_highlights(vim: Nvim) -> None: vim.current.buffer[:] = ['a', 'b', 'c'] src_id = vim.new_highlight_source() vim.current.buffer.update_highlights( - src_id, [["Comment", 0, 0, -1], ("String", 1, 0, 1)], clear=True, async_=False + src_id, [("Comment", 0, 0, -1), ("String", 1, 0, 1)], clear=True, async_=False ) -def test_buffer_inequality(vim): +def test_buffer_inequality(vim: Nvim) -> None: b = vim.current.buffer assert not (b != vim.current.buffer) diff --git a/test/test_client_rpc.py b/test/test_client_rpc.py index 3dcb6b56..9fb19b25 100644 --- a/test/test_client_rpc.py +++ b/test/test_client_rpc.py @@ -1,17 +1,20 @@ # -*- coding: utf-8 -*- import time +from typing import List +from pynvim.api import Nvim -def test_call_and_reply(vim): + +def test_call_and_reply(vim: Nvim) -> None: cid = vim.channel_id - def setup_cb(): + def setup_cb() -> None: cmd = 'let g:result = rpcrequest(%d, "client-call", 1, 2, 3)' % cid vim.command(cmd) assert vim.vars['result'] == [4, 5, 6] vim.stop_loop() - def request_cb(name, args): + def request_cb(name: str, args: List[int]) -> List[int]: assert name == 'client-call' assert args == [1, 2, 3] return [4, 5, 6] @@ -19,25 +22,25 @@ def request_cb(name, args): vim.run_loop(request_cb, None, setup_cb) -def test_call_api_before_reply(vim): +def test_call_api_before_reply(vim: Nvim) -> None: cid = vim.channel_id - def setup_cb(): + def setup_cb() -> None: cmd = 'let g:result = rpcrequest(%d, "client-call2", 1, 2, 3)' % cid vim.command(cmd) assert vim.vars['result'] == [7, 8, 9] vim.stop_loop() - def request_cb(name, args): + def request_cb(name: str, args: List[int]) -> List[int]: vim.command('let g:result2 = [7, 8, 9]') return vim.vars['result2'] vim.run_loop(request_cb, None, setup_cb) -def test_async_call(vim): +def test_async_call(vim: Nvim) -> None: - def request_cb(name, args): + def request_cb(name: str, args: List[int]) -> None: if name == "test-event": vim.vars['result'] = 17 vim.stop_loop() @@ -53,10 +56,10 @@ def request_cb(name, args): assert vim.vars['result'] == 17 -def test_recursion(vim): +def test_recursion(vim: Nvim) -> None: cid = vim.channel_id - def setup_cb(): + def setup_cb() -> None: vim.vars['result1'] = 0 vim.vars['result2'] = 0 vim.vars['result3'] = 0 @@ -69,7 +72,7 @@ def setup_cb(): assert vim.vars['result4'] == 32 vim.stop_loop() - def request_cb(name, args): + def request_cb(name: str, args: List[int]) -> int: n = args[0] n *= 2 if n <= 16: diff --git a/test/test_concurrency.py b/test/test_concurrency.py index d74f482a..a9e711db 100644 --- a/test/test_concurrency.py +++ b/test/test_concurrency.py @@ -1,16 +1,24 @@ from threading import Timer +from typing import List +from pynvim.api import Nvim -def test_interrupt_from_another_thread(vim): + +def test_interrupt_from_another_thread(vim: Nvim) -> None: timer = Timer(0.5, lambda: vim.async_call(lambda: vim.stop_loop())) timer.start() assert vim.next_message() is None -def test_exception_in_threadsafe_call(vim): +def test_exception_in_threadsafe_call(vim: Nvim) -> None: # an exception in a threadsafe_call shouldn't crash the entire host - msgs = [] - vim.async_call(lambda: [vim.eval("3"), undefined_variable]) # noqa: F821 + msgs: List[str] = [] + vim.async_call( + lambda: [ + vim.eval("3"), + undefined_variable # type: ignore[name-defined] # noqa: F821 + ] + ) timer = Timer(0.5, lambda: vim.async_call(lambda: vim.stop_loop())) timer.start() vim.run_loop(None, None, err_cb=msgs.append) diff --git a/test/test_decorators.py b/test/test_decorators.py index 6e578a1d..d9725cc8 100644 --- a/test/test_decorators.py +++ b/test/test_decorators.py @@ -1,8 +1,9 @@ +# type: ignore[attr-defined] from pynvim.plugin.decorators import command -def test_command_count(): - def function(): +def test_command_count() -> None: + def function() -> None: """A dummy function to decorate.""" return diff --git a/test/test_events.py b/test/test_events.py index 3d403315..2a27a89b 100644 --- a/test/test_events.py +++ b/test/test_events.py @@ -1,7 +1,8 @@ # -*- coding: utf-8 -*- +from pynvim.api import Nvim -def test_receiving_events(vim): +def test_receiving_events(vim: Nvim) -> None: vim.command('call rpcnotify(%d, "test-event", 1, 2, 3)' % vim.channel_id) event = vim.next_message() assert event[1] == 'test-event' @@ -14,7 +15,7 @@ def test_receiving_events(vim): assert event[2] == [vim.current.buffer.number] -def test_sending_notify(vim): +def test_sending_notify(vim: Nvim) -> None: # notify after notify vim.command("let g:test = 3", async_=True) cmd = 'call rpcnotify(%d, "test-event", g:test)' % vim.channel_id @@ -28,14 +29,14 @@ def test_sending_notify(vim): assert vim.eval('g:data') == 'xyz' -def test_async_error(vim): +def test_async_error(vim: Nvim) -> None: # Invoke a bogus Ex command via notify (async). vim.command("lolwut", async_=True) event = vim.next_message() assert event[1] == 'nvim_error_event' -def test_broadcast(vim): +def test_broadcast(vim: Nvim) -> None: vim.subscribe('event2') vim.command('call rpcnotify(0, "event1", 1, 2, 3)') vim.command('call rpcnotify(0, "event2", 4, 5, 6)') diff --git a/test/test_host.py b/test/test_host.py index b21a398b..e4cb389c 100644 --- a/test/test_host.py +++ b/test/test_host.py @@ -1,4 +1,5 @@ # -*- coding: utf-8 -*- +# type: ignore from pynvim.plugin.host import Host, host_method_spec from pynvim.plugin.script_host import ScriptHost diff --git a/test/test_logging.py b/test/test_logging.py index 4df938be..4becf306 100644 --- a/test/test_logging.py +++ b/test/test_logging.py @@ -1,8 +1,9 @@ import os import sys +from typing import Any -def test_setup_logging(monkeypatch, tmpdir, caplog): +def test_setup_logging(monkeypatch: Any, tmpdir: str, caplog: Any) -> None: from pynvim import setup_logging major_version = sys.version_info[0] @@ -10,7 +11,7 @@ def test_setup_logging(monkeypatch, tmpdir, caplog): setup_logging('name1') assert caplog.messages == [] - def get_expected_logfile(prefix, name): + def get_expected_logfile(prefix: str, name: str) -> str: return '{}_py{}_{}'.format(prefix, major_version, name) prefix = tmpdir.join('testlog1') diff --git a/test/test_tabpage.py b/test/test_tabpage.py index 51c64a5b..f700510d 100644 --- a/test/test_tabpage.py +++ b/test/test_tabpage.py @@ -1,7 +1,9 @@ import pytest +from pynvim.api import Nvim -def test_windows(vim): + +def test_windows(vim: Nvim) -> None: vim.command('tabnew') vim.command('vsplit') assert list(vim.tabpages[0].windows) == [vim.windows[0]] @@ -11,7 +13,7 @@ def test_windows(vim): assert vim.tabpages[1].window == vim.windows[2] -def test_vars(vim): +def test_vars(vim: Nvim) -> None: vim.current.tabpage.vars['python'] = [1, 2, {'3': 1}] assert vim.current.tabpage.vars['python'] == [1, 2, {'3': 1}] assert vim.eval('t:python') == [1, 2, {'3': 1}] @@ -28,7 +30,7 @@ def test_vars(vim): assert vim.current.tabpage.vars.get('python', 'default') == 'default' -def test_valid(vim): +def test_valid(vim: Nvim) -> None: vim.command('tabnew') tabpage = vim.tabpages[1] assert tabpage.valid @@ -36,7 +38,7 @@ def test_valid(vim): assert not tabpage.valid -def test_number(vim): +def test_number(vim: Nvim) -> None: curnum = vim.current.tabpage.number vim.command('tabnew') assert vim.current.tabpage.number == curnum + 1 @@ -44,5 +46,5 @@ def test_number(vim): assert vim.current.tabpage.number == curnum + 2 -def test_repr(vim): +def test_repr(vim: Nvim) -> None: assert repr(vim.current.tabpage) == "" diff --git a/test/test_vim.py b/test/test_vim.py index 6eb9877c..c3095928 100644 --- a/test/test_vim.py +++ b/test/test_vim.py @@ -1,11 +1,14 @@ # -*- coding: utf-8 -*- import os import tempfile +from typing import Any import pytest +from pynvim.api import Nvim -def source(vim, code): + +def source(vim: Nvim, code: str) -> None: fd, fname = tempfile.mkstemp() with os.fdopen(fd, 'w') as f: f.write(code) @@ -13,11 +16,11 @@ def source(vim, code): os.unlink(fname) -def test_clientinfo(vim): +def test_clientinfo(vim: Nvim) -> None: assert 'remote' == vim.api.get_chan_info(vim.channel_id)['client']['type'] -def test_command(vim): +def test_command(vim: Nvim) -> None: fname = tempfile.mkstemp()[1] vim.command('new') vim.command('edit {}'.format(fname)) @@ -31,17 +34,17 @@ def test_command(vim): os.unlink(fname) -def test_command_output(vim): +def test_command_output(vim: Nvim) -> None: assert vim.command_output('echo "test"') == 'test' -def test_command_error(vim): +def test_command_error(vim: Nvim) -> None: with pytest.raises(vim.error) as excinfo: vim.current.window.cursor = -1, -1 assert excinfo.value.args == ('Cursor position outside buffer',) -def test_eval(vim): +def test_eval(vim: Nvim) -> None: vim.command('let g:v1 = "a"') vim.command('let g:v2 = [1, 2, {"v3": 3}]') g = vim.eval('g:') @@ -49,7 +52,7 @@ def test_eval(vim): assert g['v2'] == [1, 2, {'v3': 3}] -def test_call(vim): +def test_call(vim: Nvim) -> None: assert vim.funcs.join(['first', 'last'], ', ') == 'first, last' source(vim, """ function! Testfun(a,b) @@ -59,19 +62,19 @@ def test_call(vim): assert vim.funcs.Testfun(3, 'alpha') == '3:alpha' -def test_api(vim): +def test_api(vim: Nvim) -> None: vim.api.command('let g:var = 3') assert vim.api.eval('g:var') == 3 -def test_strwidth(vim): +def test_strwidth(vim: Nvim) -> None: assert vim.strwidth('abc') == 3 # 6 + (neovim) # 19 * 2 (each japanese character occupies two cells) assert vim.strwidth('neovimのデザインかなりまともなのになってる。') == 44 -def test_chdir(vim): +def test_chdir(vim: Nvim) -> None: pwd = vim.eval('getcwd()') root = os.path.abspath(os.sep) # We can chdir to '/' on Windows, but then the pwd will be the root drive @@ -81,13 +84,13 @@ def test_chdir(vim): assert vim.eval('getcwd()') == pwd -def test_current_line(vim): +def test_current_line(vim: Nvim) -> None: assert vim.current.line == '' vim.current.line = 'abc' assert vim.current.line == 'abc' -def test_current_line_delete(vim): +def test_current_line_delete(vim: Nvim) -> None: vim.current.buffer[:] = ['one', 'two'] assert len(vim.current.buffer[:]) == 2 del vim.current.line @@ -96,10 +99,10 @@ def test_current_line_delete(vim): assert len(vim.current.buffer[:]) == 1 and not vim.current.buffer[0] -def test_vars(vim): +def test_vars(vim: Nvim) -> None: vim.vars['python'] = [1, 2, {'3': 1}] - assert vim.vars['python'], [1, 2 == {'3': 1}] - assert vim.eval('g:python'), [1, 2 == {'3': 1}] + assert vim.vars['python'] == [1, 2, {'3': 1}] + assert vim.eval('g:python') == [1, 2, {'3': 1}] assert vim.vars.get('python') == [1, 2, {'3': 1}] del vim.vars['python'] @@ -113,19 +116,19 @@ def test_vars(vim): assert vim.vars.get('python', 'default') == 'default' -def test_options(vim): +def test_options(vim: Nvim) -> None: assert vim.options['background'] == 'dark' vim.options['background'] = 'light' assert vim.options['background'] == 'light' -def test_local_options(vim): +def test_local_options(vim: Nvim) -> None: assert vim.windows[0].options['foldmethod'] == 'manual' vim.windows[0].options['foldmethod'] = 'syntax' assert vim.windows[0].options['foldmethod'] == 'syntax' -def test_buffers(vim): +def test_buffers(vim: Nvim) -> None: buffers = [] # Number of elements @@ -145,13 +148,13 @@ def test_buffers(vim): # Membership test assert buffers[0] in vim.buffers assert buffers[1] in vim.buffers - assert {} not in vim.buffers + assert {} not in vim.buffers # type: ignore[operator] # Iteration assert buffers == list(vim.buffers) -def test_windows(vim): +def test_windows(vim: Nvim) -> None: assert len(vim.windows) == 1 assert vim.windows[0] == vim.current.window vim.command('vsplit') @@ -162,7 +165,7 @@ def test_windows(vim): assert vim.windows[1] == vim.current.window -def test_tabpages(vim): +def test_tabpages(vim: Nvim) -> None: assert len(vim.tabpages) == 1 assert vim.tabpages[0] == vim.current.tabpage vim.command('tabnew') @@ -181,7 +184,7 @@ def test_tabpages(vim): assert vim.windows[1] == vim.current.window -def test_hash(vim): +def test_hash(vim: Nvim) -> None: d = {} d[vim.current.buffer] = "alpha" assert d[vim.current.buffer] == 'alpha' @@ -194,7 +197,7 @@ def test_hash(vim): assert d[vim.current.buffer] == 'beta' -def test_cwd(vim, tmpdir): +def test_cwd(vim: Nvim, tmpdir: Any) -> None: vim.command('python3 import os') cwd_before = vim.command_output('python3 print(os.getcwd())') @@ -227,7 +230,7 @@ def test_cwd(vim, tmpdir): """ -def test_lua(vim): +def test_lua(vim: Nvim) -> None: assert vim.exec_lua(lua_code, 7) == "eggspam" assert vim.lua.pynvimtest_func(3) == 10 lua_module = vim.lua.pynvimtest diff --git a/test/test_window.py b/test/test_window.py index bc48aeab..7a36d9e0 100644 --- a/test/test_window.py +++ b/test/test_window.py @@ -1,7 +1,9 @@ import pytest +from pynvim.api import Nvim -def test_buffer(vim): + +def test_buffer(vim: Nvim) -> None: assert vim.current.buffer == vim.windows[0].buffer vim.command('new') vim.current.window = vim.windows[1] @@ -9,17 +11,17 @@ def test_buffer(vim): assert vim.windows[0].buffer != vim.windows[1].buffer -def test_cursor(vim): - assert vim.current.window.cursor == [1, 0] +def test_cursor(vim: Nvim) -> None: + assert vim.current.window.cursor == (1, 0) vim.command('normal ityping\033o some text') assert vim.current.buffer[:] == ['typing', ' some text'] - assert vim.current.window.cursor == [2, 10] - vim.current.window.cursor = [2, 6] + assert vim.current.window.cursor == (2, 10) + vim.current.window.cursor = (2, 6) vim.command('normal i dumb') assert vim.current.buffer[:] == ['typing', ' some dumb text'] -def test_height(vim): +def test_height(vim: Nvim) -> None: vim.command('vsplit') assert vim.windows[1].height == vim.windows[0].height vim.current.window = vim.windows[1] @@ -29,7 +31,7 @@ def test_height(vim): assert vim.windows[1].height == 2 -def test_width(vim): +def test_width(vim: Nvim) -> None: vim.command('split') assert vim.windows[1].width == vim.windows[0].width vim.current.window = vim.windows[1] @@ -39,7 +41,7 @@ def test_width(vim): assert vim.windows[1].width == 2 -def test_vars(vim): +def test_vars(vim: Nvim) -> None: vim.current.window.vars['python'] = [1, 2, {'3': 1}] assert vim.current.window.vars['python'] == [1, 2, {'3': 1}] assert vim.eval('w:python') == [1, 2, {'3': 1}] @@ -56,7 +58,7 @@ def test_vars(vim): assert vim.current.window.vars.get('python', 'default') == 'default' -def test_options(vim): +def test_options(vim: Nvim) -> None: vim.current.window.options['colorcolumn'] = '4,3' assert vim.current.window.options['colorcolumn'] == '4,3' # global-local option @@ -69,7 +71,7 @@ def test_options(vim): assert excinfo.value.args == ("Invalid option name: 'doesnotexist'",) -def test_position(vim): +def test_position(vim: Nvim) -> None: height = vim.windows[0].height width = vim.windows[0].width vim.command('split') @@ -83,7 +85,7 @@ def test_position(vim): assert vim.windows[2].col == 0 -def test_tabpage(vim): +def test_tabpage(vim: Nvim) -> None: vim.command('tabnew') vim.command('vsplit') assert vim.windows[0].tabpage == vim.tabpages[0] @@ -91,7 +93,7 @@ def test_tabpage(vim): assert vim.windows[2].tabpage == vim.tabpages[1] -def test_valid(vim): +def test_valid(vim: Nvim) -> None: vim.command('split') window = vim.windows[1] vim.current.window = window @@ -100,7 +102,7 @@ def test_valid(vim): assert not window.valid -def test_number(vim): +def test_number(vim: Nvim) -> None: curnum = vim.current.window.number vim.command('bot split') assert vim.current.window.number == curnum + 1 @@ -108,7 +110,7 @@ def test_number(vim): assert vim.current.window.number == curnum + 2 -def test_handle(vim): +def test_handle(vim: Nvim) -> None: hnd1 = vim.current.window.handle vim.command('bot split') hnd2 = vim.current.window.handle @@ -120,5 +122,5 @@ def test_handle(vim): assert vim.current.window.handle == hnd1 -def test_repr(vim): +def test_repr(vim: Nvim) -> None: assert repr(vim.current.window) == "" diff --git a/tox.ini b/tox.ini index c7ca1ee8..16d4fc27 100644 --- a/tox.ini +++ b/tox.ini @@ -17,11 +17,16 @@ commands = [testenv:checkqa] deps = + mypy flake8 flake8-import-order flake8-docstrings pep8-naming -commands = flake8 {posargs:pynvim test} + msgpack-types +ignore_errors = true +commands = + flake8 {posargs:pynvim test} + mypy --show-error-codes {posargs:pynvim test} [testenv:docs] deps =