-
-
Notifications
You must be signed in to change notification settings - Fork 106
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Create a dedicated base
Client
class
- Loading branch information
Showing
4 changed files
with
299 additions
and
97 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,174 @@ | ||
############################################################################ | ||
# Copyright(c) Open Law Library. All rights reserved. # | ||
# See ThirdPartyNotices.txt in the project root for additional notices. # | ||
# # | ||
# Licensed under the Apache License, Version 2.0 (the "License") # | ||
# you may not use this file except in compliance with the License. # | ||
# You may obtain a copy of the License at # | ||
# # | ||
# http: // www.apache.org/licenses/LICENSE-2.0 # | ||
# # | ||
# Unless required by applicable law or agreed to in writing, software # | ||
# distributed under the License is distributed on an "AS IS" BASIS, # | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # | ||
# See the License for the specific language governing permissions and # | ||
# limitations under the License. # | ||
############################################################################ | ||
import asyncio | ||
import logging | ||
import re | ||
import sys | ||
import threading | ||
from concurrent.futures import ThreadPoolExecutor | ||
from typing import Any | ||
from typing import Callable | ||
from typing import Optional | ||
from typing import TextIO | ||
from typing import Type | ||
from typing import TypeVar | ||
|
||
from pygls.io import StdOutTransportAdapter | ||
from pygls.protocol import JsonRPCProtocol | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
F = TypeVar("F", bound=Callable) | ||
|
||
|
||
async def aio_readline(executor, stop_event, rfile, proxy): | ||
"""Reads data from stdin in separate thread (asynchronously).""" | ||
|
||
CONTENT_LENGTH_PATTERN = re.compile(rb"^Content-Length: (\d+)\r\n$") | ||
loop = asyncio.get_running_loop() | ||
|
||
# Initialize message buffer | ||
message = [] | ||
content_length = 0 | ||
|
||
while not stop_event.is_set() and not rfile.closed: | ||
# Read a header line | ||
header = await loop.run_in_executor(executor, rfile.readline) | ||
if not header: | ||
break | ||
message.append(header) | ||
|
||
# Extract content length if possible | ||
if not content_length: | ||
match = CONTENT_LENGTH_PATTERN.fullmatch(header) | ||
if match: | ||
content_length = int(match.group(1)) | ||
logger.debug("Content length: %s", content_length) | ||
|
||
# Check if all headers have been read (as indicated by an empty line \r\n) | ||
if content_length and not header.strip(): | ||
|
||
# Read body | ||
body = await loop.run_in_executor(executor, rfile.read, content_length) | ||
if not body: | ||
break | ||
message.append(body) | ||
|
||
# Pass message to language server protocol | ||
proxy(b"".join(message)) | ||
|
||
# Reset the buffer | ||
message = [] | ||
content_length = 0 | ||
|
||
|
||
class Client: | ||
"""Base class for a JSON-RPC client. | ||
Args: | ||
protocol_cls(Protocol): Protocol implementation that must be derived | ||
from `asyncio.Protocol` | ||
converter_factory: Factory function to use when constructing a cattrs converter. | ||
loop(AbstractEventLoop): asyncio event loop | ||
max_workers(int, optional): Number of workers for `ThreadPool` and | ||
`ThreadPoolExecutor` | ||
Attributes: | ||
_max_workers(int): Number of workers for thread pool executor | ||
_server(Server): Server object which can be used to stop the process | ||
_stop_event(Event): Event used for stopping `aio_readline` | ||
_thread_pool(ThreadPool): Thread pool for executing methods decorated | ||
with `@ls.thread()` - lazy instantiated | ||
_thread_pool_executor(ThreadPoolExecutor): Thread pool executor | ||
passed to `run_in_executor` | ||
- lazy instantiated | ||
""" | ||
|
||
def __init__( | ||
self, | ||
protocol_cls: Type[JsonRPCProtocol], | ||
converter_factory, | ||
loop=None, | ||
max_workers=2, | ||
): | ||
self._max_workers = max_workers | ||
self._stop_event = threading.Event() | ||
self._thread_pool_executor = None | ||
self.loop = loop or asyncio.new_event_loop() | ||
|
||
self.protocol = protocol_cls(self, converter_factory()) | ||
|
||
def close(self): | ||
"""Shutdown server.""" | ||
logger.info("Closing the client") | ||
|
||
self._stop_event.set() | ||
|
||
if self._thread_pool_executor: | ||
self._thread_pool_executor.shutdown() | ||
|
||
def start_io(self, stdin: Optional[TextIO] = None, stdout: Optional[TextIO] = None): | ||
"""Starts IO server.""" | ||
logger.info("Starting IO client") | ||
|
||
transport = StdOutTransportAdapter( | ||
stdin or sys.stdin.buffer, stdout or sys.stdout.buffer | ||
) | ||
self.protocol.connection_made(transport) | ||
|
||
try: | ||
asyncio.run( | ||
aio_readline( | ||
self.thread_pool_executor, | ||
self._stop_event, | ||
stdin or sys.stdin.buffer, | ||
self.protocol.data_received, | ||
) | ||
) | ||
except BrokenPipeError: | ||
logger.error("Connection to the server lost! Closing the client.") | ||
except (KeyboardInterrupt, SystemExit): | ||
pass | ||
finally: | ||
self.close() | ||
|
||
@property | ||
def thread_pool_executor(self) -> ThreadPoolExecutor: | ||
"""Returns thread pool instance (lazy initialization).""" | ||
if not self._thread_pool_executor: | ||
self._thread_pool_executor = ThreadPoolExecutor( | ||
max_workers=self._max_workers | ||
) | ||
|
||
return self._thread_pool_executor | ||
|
||
def feature( | ||
self, | ||
feature_name: str, | ||
options: Optional[Any] = None, | ||
) -> Callable[[F], F]: | ||
"""Decorator used to register LSP features. | ||
Example: | ||
@client.feature('textDocument/publishDiagnostics') | ||
def diagnostics(ls, params: PublishDiagnosticParams): | ||
self.diagnostics[params.uri] = params.diagnostics | ||
""" | ||
return self.protocol.fm.feature(feature_name, options) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,117 @@ | ||
############################################################################ | ||
# Copyright(c) Open Law Library. All rights reserved. # | ||
# See ThirdPartyNotices.txt in the project root for additional notices. # | ||
# # | ||
# Licensed under the Apache License, Version 2.0 (the "License") # | ||
# you may not use this file except in compliance with the License. # | ||
# You may obtain a copy of the License at # | ||
# # | ||
# http: // www.apache.org/licenses/LICENSE-2.0 # | ||
# # | ||
# Unless required by applicable law or agreed to in writing, software # | ||
# distributed under the License is distributed on an "AS IS" BASIS, # | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # | ||
# See the License for the specific language governing permissions and # | ||
# limitations under the License. # | ||
############################################################################ | ||
import asyncio | ||
import logging | ||
import re | ||
from typing import Any | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
async def aio_readline(loop, executor, stop_event, rfile, proxy): | ||
"""Reads data from stdin in separate thread (asynchronously).""" | ||
|
||
CONTENT_LENGTH_PATTERN = re.compile(rb"^Content-Length: (\d+)\r\n$") | ||
|
||
# Initialize message buffer | ||
message = [] | ||
content_length = 0 | ||
|
||
while not stop_event.is_set() and not rfile.closed: | ||
# Read a header line | ||
header = await loop.run_in_executor(executor, rfile.readline) | ||
if not header: | ||
break | ||
message.append(header) | ||
|
||
# Extract content length if possible | ||
if not content_length: | ||
match = CONTENT_LENGTH_PATTERN.fullmatch(header) | ||
if match: | ||
content_length = int(match.group(1)) | ||
logger.debug("Content length: %s", content_length) | ||
|
||
# Check if all headers have been read (as indicated by an empty line \r\n) | ||
if content_length and not header.strip(): | ||
|
||
# Read body | ||
body = await loop.run_in_executor(executor, rfile.read, content_length) | ||
if not body: | ||
break | ||
message.append(body) | ||
|
||
# Pass message to language server protocol | ||
proxy(b"".join(message)) | ||
|
||
# Reset the buffer | ||
message = [] | ||
content_length = 0 | ||
|
||
|
||
class StdOutTransportAdapter: | ||
"""Protocol adapter which overrides write method. | ||
Write method sends data to stdout. | ||
""" | ||
|
||
def __init__(self, rfile, wfile): | ||
self.rfile = rfile | ||
self.wfile = wfile | ||
|
||
def close(self): | ||
self.rfile.close() | ||
self.wfile.close() | ||
|
||
def write(self, data): | ||
self.wfile.write(data) | ||
self.wfile.flush() | ||
|
||
|
||
class PyodideTransportAdapter: | ||
"""Protocol adapter which overrides write method. | ||
Write method sends data to stdout. | ||
""" | ||
|
||
def __init__(self, wfile): | ||
self.wfile = wfile | ||
|
||
def close(self): | ||
self.wfile.close() | ||
|
||
def write(self, data): | ||
self.wfile.write(data) | ||
self.wfile.flush() | ||
|
||
|
||
class WebSocketTransportAdapter: | ||
"""Protocol adapter which calls write method. | ||
Write method sends data via the WebSocket interface. | ||
""" | ||
|
||
def __init__(self, ws, loop): | ||
self._ws = ws | ||
self._loop = loop | ||
|
||
def close(self) -> None: | ||
"""Stop the WebSocket server.""" | ||
self._ws.close() | ||
|
||
def write(self, data: Any) -> None: | ||
"""Create a task to write specified data into a WebSocket.""" | ||
asyncio.ensure_future(self._ws.send(data)) |
Oops, something went wrong.