From 25150a1dd8544fb8064d4f6fc18432cc2a22561b Mon Sep 17 00:00:00 2001 From: Thomas Buckley-Houston Date: Thu, 14 Sep 2023 19:05:44 -0500 Subject: [PATCH] chore: fix mypy types --- pygls/client.py | 5 ++++- pygls/protocol.py | 24 +++++++++++++++++------- pygls/server.py | 20 ++++++++++++++------ pygls/uris.py | 14 +++++++------- pygls/workspace/document.py | 2 +- pygls/workspace/position.py | 4 ++-- tests/conftest.py | 1 - 7 files changed, 45 insertions(+), 25 deletions(-) diff --git a/pygls/client.py b/pygls/client.py index b2ea96bd..577f05e0 100644 --- a/pygls/client.py +++ b/pygls/client.py @@ -79,7 +79,10 @@ def __init__( protocol_cls: Type[JsonRPCProtocol] = JsonRPCProtocol, converter_factory: Callable[[], Converter] = default_converter, ): - self.protocol = protocol_cls(self, converter_factory()) + # Strictly speaking `JsonRPCProtocol` wants a `LanguageServer`, not a + # `JsonRPCClient`. However there similar enough for our purposes, which is + # that this client will mostly be used in testing contexts. + self.protocol = protocol_cls(self, converter_factory()) # type: ignore self._server: Optional[asyncio.subprocess.Process] = None self._stop_event = Event() diff --git a/pygls/protocol.py b/pygls/protocol.py index 8cafa4ca..1733c2fc 100644 --- a/pygls/protocol.py +++ b/pygls/protocol.py @@ -28,10 +28,20 @@ from concurrent.futures import Future from functools import lru_cache, partial from itertools import zip_longest -from typing import Any, Callable, List, Optional, Type, TypeVar, Union, TYPE_CHECKING +from typing import ( + Any, + Callable, + Dict, + List, + Optional, + Type, + TypeVar, + Union, + TYPE_CHECKING, +) if TYPE_CHECKING: - from pygls.server import Server + from pygls.server import LanguageServer import attrs @@ -243,19 +253,19 @@ class JsonRPCProtocol(asyncio.Protocol): VERSION = "2.0" - def __init__(self, server: Server, converter): + def __init__(self, server: LanguageServer, converter): self._server = server self._converter = converter self._shutdown = False # Book keeping for in-flight requests - self._request_futures = {} - self._result_types = {} + self._request_futures: Dict[str, Future[Any]] = {} + self._result_types: Dict[str, Any] = {} self.fm = FeatureManager(server) - self.transport = None - self._message_buf = [] + self.transport: Optional[asyncio.BaseTransport] = None + self._message_buf: List[bytes] = [] self._send_only_body = False diff --git a/pygls/server.py b/pygls/server.py index be661fde..c0bb0a97 100644 --- a/pygls/server.py +++ b/pygls/server.py @@ -21,11 +21,16 @@ import sys from concurrent.futures import Future, ThreadPoolExecutor from threading import Event -from typing import Any, Callable, List, Optional, TextIO, TypeVar, Union +from typing import Any, Callable, List, Optional, TextIO, Type, TypeVar, Union from pygls import IS_PYODIDE from pygls.lsp import ConfigCallbackType, ShowDocumentCallbackType -from pygls.exceptions import PyglsError, JsonRpcException, FeatureRequestError +from pygls.exceptions import ( + JsonRpcInternalError, + PyglsError, + JsonRpcException, + FeatureRequestError, +) from lsprotocol.types import ( ClientCapabilities, Diagnostic, @@ -51,6 +56,8 @@ F = TypeVar("F", bound=Callable) +ServerErrors = Union[PyglsError, JsonRpcException, Type[JsonRpcInternalError]] + async def aio_readline(loop, executor, stop_event, rfile, proxy): """Reads data from stdin in separate thread (asynchronously).""" @@ -370,6 +377,7 @@ def __init__( self.name = name self.version = version + self.process_id: Optional[Union[int, None]] = None super().__init__(protocol_cls, converter_factory, loop, max_workers) def apply_edit( @@ -501,7 +509,9 @@ def show_message_log(self, message, msg_type=MessageType.Log) -> None: self.lsp.show_message_log(message, msg_type) def _report_server_error( - self, error: Exception, source: Union[PyglsError, JsonRpcException] + self, + error: Exception, + source: ServerErrors, ): # Prevent recursive error reporting try: @@ -509,9 +519,7 @@ def _report_server_error( except Exception: logger.warning("Failed to report error to client") - def report_server_error( - self, error: Exception, source: Union[PyglsError, JsonRpcException] - ): + def report_server_error(self, error: Exception, source: ServerErrors): """ Sends error to the client for displaying. diff --git a/pygls/uris.py b/pygls/uris.py index 69a6ce3e..8c40f70b 100644 --- a/pygls/uris.py +++ b/pygls/uris.py @@ -21,7 +21,7 @@ https://github.com/Microsoft/vscode-uri/blob/e59cab84f5df6265aed18ae5f43552d3eef13bb9/lib/index.ts """ -from typing import Tuple +from typing import Optional, Tuple import re from urllib import parse @@ -118,12 +118,12 @@ def uri_scheme(uri: str): # TODO: Use `URLParts` type def uri_with( uri: str, - scheme: str | None = None, - netloc: str | None = None, - path: str | None = None, - params: str | None = None, - query: str | None = None, - fragment: str | None = None, + scheme: Optional[str] = None, + netloc: Optional[str] = None, + path: Optional[str] = None, + params: Optional[str] = None, + query: Optional[str] = None, + fragment: Optional[str] = None, ): """ Return a URI with the given part(s) replaced. diff --git a/pygls/workspace/document.py b/pygls/workspace/document.py index 96093f61..c0a16c93 100644 --- a/pygls/workspace/document.py +++ b/pygls/workspace/document.py @@ -54,7 +54,7 @@ def __init__( raise Exception("`path` cannot be None") self.path = path self.language_id = language_id - self.filename: str | None = os.path.basename(self.path) + self.filename: Optional[str] = os.path.basename(self.path) self._local = local self._source = source diff --git a/pygls/workspace/position.py b/pygls/workspace/position.py index f54071a4..0cb40e9b 100644 --- a/pygls/workspace/position.py +++ b/pygls/workspace/position.py @@ -17,7 +17,7 @@ # limitations under the License. # ############################################################################ import logging -from typing import List, Optional +from typing import List, Optional, Union from lsprotocol import types @@ -29,7 +29,7 @@ class Position: def __init__( self, encoding: Optional[ - types.PositionEncodingKind | str + Union[types.PositionEncodingKind, str] ] = types.PositionEncodingKind.Utf16, ): self.encoding = encoding diff --git a/tests/conftest.py b/tests/conftest.py index 7ea5d434..7d0b50ae 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -24,7 +24,6 @@ from pygls import uris, IS_PYODIDE, IS_WIN from pygls.feature_manager import FeatureManager -from pygls.workspace.workspace import Document from pygls.workspace.workspace import Workspace from .ls_setup import (