Skip to content

Commit

Permalink
Create a dedicated base Client class
Browse files Browse the repository at this point in the history
  • Loading branch information
alcarney committed Apr 4, 2023
1 parent bc59090 commit baf0e70
Show file tree
Hide file tree
Showing 4 changed files with 299 additions and 97 deletions.
174 changes: 174 additions & 0 deletions pygls/client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,174 @@
############################################################################
# Copyright(c) Open Law Library. All rights reserved. #
# See ThirdPartyNotices.txt in the project root for additional notices. #
# #
# Licensed under the Apache License, Version 2.0 (the "License") #
# you may not use this file except in compliance with the License. #
# You may obtain a copy of the License at #
# #
# http: // www.apache.org/licenses/LICENSE-2.0 #
# #
# Unless required by applicable law or agreed to in writing, software #
# distributed under the License is distributed on an "AS IS" BASIS, #
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. #
# See the License for the specific language governing permissions and #
# limitations under the License. #
############################################################################
import asyncio
import logging
import re
import sys
import threading
from concurrent.futures import ThreadPoolExecutor
from typing import Any
from typing import Callable
from typing import Optional
from typing import TextIO
from typing import Type
from typing import TypeVar

from pygls.io import StdOutTransportAdapter
from pygls.protocol import JsonRPCProtocol

logger = logging.getLogger(__name__)

F = TypeVar("F", bound=Callable)


async def aio_readline(executor, stop_event, rfile, proxy):
"""Reads data from stdin in separate thread (asynchronously)."""

CONTENT_LENGTH_PATTERN = re.compile(rb"^Content-Length: (\d+)\r\n$")
loop = asyncio.get_running_loop()

# Initialize message buffer
message = []
content_length = 0

while not stop_event.is_set() and not rfile.closed:
# Read a header line
header = await loop.run_in_executor(executor, rfile.readline)
if not header:
break
message.append(header)

# Extract content length if possible
if not content_length:
match = CONTENT_LENGTH_PATTERN.fullmatch(header)
if match:
content_length = int(match.group(1))
logger.debug("Content length: %s", content_length)

# Check if all headers have been read (as indicated by an empty line \r\n)
if content_length and not header.strip():

# Read body
body = await loop.run_in_executor(executor, rfile.read, content_length)
if not body:
break
message.append(body)

# Pass message to language server protocol
proxy(b"".join(message))

# Reset the buffer
message = []
content_length = 0


class Client:
"""Base class for a JSON-RPC client.
Args:
protocol_cls(Protocol): Protocol implementation that must be derived
from `asyncio.Protocol`
converter_factory: Factory function to use when constructing a cattrs converter.
loop(AbstractEventLoop): asyncio event loop
max_workers(int, optional): Number of workers for `ThreadPool` and
`ThreadPoolExecutor`
Attributes:
_max_workers(int): Number of workers for thread pool executor
_server(Server): Server object which can be used to stop the process
_stop_event(Event): Event used for stopping `aio_readline`
_thread_pool(ThreadPool): Thread pool for executing methods decorated
with `@ls.thread()` - lazy instantiated
_thread_pool_executor(ThreadPoolExecutor): Thread pool executor
passed to `run_in_executor`
- lazy instantiated
"""

def __init__(
self,
protocol_cls: Type[JsonRPCProtocol],
converter_factory,
loop=None,
max_workers=2,
):
self._max_workers = max_workers
self._stop_event = threading.Event()
self._thread_pool_executor = None
self.loop = loop or asyncio.new_event_loop()

self.protocol = protocol_cls(self, converter_factory())

def close(self):
"""Shutdown server."""
logger.info("Closing the client")

self._stop_event.set()

if self._thread_pool_executor:
self._thread_pool_executor.shutdown()

def start_io(self, stdin: Optional[TextIO] = None, stdout: Optional[TextIO] = None):
"""Starts IO server."""
logger.info("Starting IO client")

transport = StdOutTransportAdapter(
stdin or sys.stdin.buffer, stdout or sys.stdout.buffer
)
self.protocol.connection_made(transport)

try:
asyncio.run(
aio_readline(
self.thread_pool_executor,
self._stop_event,
stdin or sys.stdin.buffer,
self.protocol.data_received,
)
)
except BrokenPipeError:
logger.error("Connection to the server lost! Closing the client.")
except (KeyboardInterrupt, SystemExit):
pass
finally:
self.close()

@property
def thread_pool_executor(self) -> ThreadPoolExecutor:
"""Returns thread pool instance (lazy initialization)."""
if not self._thread_pool_executor:
self._thread_pool_executor = ThreadPoolExecutor(
max_workers=self._max_workers
)

return self._thread_pool_executor

def feature(
self,
feature_name: str,
options: Optional[Any] = None,
) -> Callable[[F], F]:
"""Decorator used to register LSP features.
Example:
@client.feature('textDocument/publishDiagnostics')
def diagnostics(ls, params: PublishDiagnosticParams):
self.diagnostics[params.uri] = params.diagnostics
"""
return self.protocol.fm.feature(feature_name, options)
117 changes: 117 additions & 0 deletions pygls/io.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
############################################################################
# Copyright(c) Open Law Library. All rights reserved. #
# See ThirdPartyNotices.txt in the project root for additional notices. #
# #
# Licensed under the Apache License, Version 2.0 (the "License") #
# you may not use this file except in compliance with the License. #
# You may obtain a copy of the License at #
# #
# http: // www.apache.org/licenses/LICENSE-2.0 #
# #
# Unless required by applicable law or agreed to in writing, software #
# distributed under the License is distributed on an "AS IS" BASIS, #
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. #
# See the License for the specific language governing permissions and #
# limitations under the License. #
############################################################################
import asyncio
import logging
import re
from typing import Any

logger = logging.getLogger(__name__)


async def aio_readline(loop, executor, stop_event, rfile, proxy):
"""Reads data from stdin in separate thread (asynchronously)."""

CONTENT_LENGTH_PATTERN = re.compile(rb"^Content-Length: (\d+)\r\n$")

# Initialize message buffer
message = []
content_length = 0

while not stop_event.is_set() and not rfile.closed:
# Read a header line
header = await loop.run_in_executor(executor, rfile.readline)
if not header:
break
message.append(header)

# Extract content length if possible
if not content_length:
match = CONTENT_LENGTH_PATTERN.fullmatch(header)
if match:
content_length = int(match.group(1))
logger.debug("Content length: %s", content_length)

# Check if all headers have been read (as indicated by an empty line \r\n)
if content_length and not header.strip():

# Read body
body = await loop.run_in_executor(executor, rfile.read, content_length)
if not body:
break
message.append(body)

# Pass message to language server protocol
proxy(b"".join(message))

# Reset the buffer
message = []
content_length = 0


class StdOutTransportAdapter:
"""Protocol adapter which overrides write method.
Write method sends data to stdout.
"""

def __init__(self, rfile, wfile):
self.rfile = rfile
self.wfile = wfile

def close(self):
self.rfile.close()
self.wfile.close()

def write(self, data):
self.wfile.write(data)
self.wfile.flush()


class PyodideTransportAdapter:
"""Protocol adapter which overrides write method.
Write method sends data to stdout.
"""

def __init__(self, wfile):
self.wfile = wfile

def close(self):
self.wfile.close()

def write(self, data):
self.wfile.write(data)
self.wfile.flush()


class WebSocketTransportAdapter:
"""Protocol adapter which calls write method.
Write method sends data via the WebSocket interface.
"""

def __init__(self, ws, loop):
self._ws = ws
self._loop = loop

def close(self) -> None:
"""Stop the WebSocket server."""
self._ws.close()

def write(self, data: Any) -> None:
"""Create a task to write specified data into a WebSocket."""
asyncio.ensure_future(self._ws.send(data))
Loading

0 comments on commit baf0e70

Please sign in to comment.