Skip to content

Commit e24a7bb

Browse files
authored
[mlir-lsp] Support outgoing requests (llvm#90078)
Add support for outgoing requests to `lsp::MessageHandler`. Much like `MessageHandler::outgoingNotification`, this allows for the message handler to send outgoing messages via its JSON transport, but in this case, those messages are requests, not notifications. Requests receive responses (also referred to as "replies" in `MLIRLspServerSupportLib`). These were previously unsupported, and `lsp::MessageHandler` would log an error each time it processed a JSON message that appeared to be a response (something with an "id" field, but no "method" field). However, the `outgoingRequest` method now handles response callbacks: an outgoing request with a given ID is set up such that a callback function is invoked when a response with that ID is received.
1 parent d47c498 commit e24a7bb

File tree

3 files changed

+103
-15
lines changed

3 files changed

+103
-15
lines changed

mlir/include/mlir/Tools/lsp-server-support/Transport.h

+41
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#ifndef MLIR_TOOLS_LSPSERVERSUPPORT_TRANSPORT_H
1616
#define MLIR_TOOLS_LSPSERVERSUPPORT_TRANSPORT_H
1717

18+
#include "mlir/Support/DebugStringHelper.h"
1819
#include "mlir/Support/LLVM.h"
1920
#include "mlir/Support/LogicalResult.h"
2021
#include "mlir/Tools/lsp-server-support/Logging.h"
@@ -100,6 +101,18 @@ using Callback = llvm::unique_function<void(llvm::Expected<T>)>;
100101
template <typename T>
101102
using OutgoingNotification = llvm::unique_function<void(const T &)>;
102103

104+
/// An OutgoingRequest<T> is a function used for outgoing requests to send to
105+
/// the client.
106+
template <typename T>
107+
using OutgoingRequest =
108+
llvm::unique_function<void(const T &, llvm::json::Value id)>;
109+
110+
/// An `OutgoingRequestCallback` is invoked when an outgoing request to the
111+
/// client receives a response in turn. It is passed the original request's ID,
112+
/// as well as the result JSON.
113+
using OutgoingRequestCallback =
114+
std::function<void(llvm::json::Value, llvm::Expected<llvm::json::Value>)>;
115+
103116
/// A handler used to process the incoming transport messages.
104117
class MessageHandler {
105118
public:
@@ -170,6 +183,26 @@ class MessageHandler {
170183
};
171184
}
172185

186+
/// Create an OutgoingRequest function that, when called, sends a request with
187+
/// the given method via the transport. Should the outgoing request be
188+
/// met with a response, the response callback is invoked to handle that
189+
/// response.
190+
template <typename T>
191+
OutgoingRequest<T> outgoingRequest(llvm::StringLiteral method,
192+
OutgoingRequestCallback callback) {
193+
return [&, method, callback](const T &params, llvm::json::Value id) {
194+
{
195+
std::lock_guard<std::mutex> lock(responseHandlersMutex);
196+
responseHandlers.insert(
197+
{debugString(id), std::make_pair(method.str(), callback)});
198+
}
199+
200+
std::lock_guard<std::mutex> transportLock(transportOutputMutex);
201+
Logger::info("--> {0}({1})", method, id);
202+
transport.call(method, llvm::json::Value(params), id);
203+
};
204+
}
205+
173206
private:
174207
template <typename HandlerT>
175208
using HandlerMap = llvm::StringMap<llvm::unique_function<HandlerT>>;
@@ -178,6 +211,14 @@ class MessageHandler {
178211
HandlerMap<void(llvm::json::Value, Callback<llvm::json::Value>)>
179212
methodHandlers;
180213

214+
/// A pair of (1) the original request's method name, and (2) the callback
215+
/// function to be invoked for responses.
216+
using ResponseHandlerTy = std::pair<std::string, OutgoingRequestCallback>;
217+
/// A mapping from request/response ID to response handler.
218+
llvm::StringMap<ResponseHandlerTy> responseHandlers;
219+
/// Mutex to guard insertion into the response handler map.
220+
std::mutex responseHandlersMutex;
221+
181222
JSONTransport &transport;
182223

183224
/// Mutex to guard sending output messages to the transport.

mlir/lib/Tools/lsp-server-support/Transport.cpp

+23-15
Original file line numberDiff line numberDiff line change
@@ -117,21 +117,29 @@ bool MessageHandler::onCall(llvm::StringRef method, llvm::json::Value params,
117117

118118
bool MessageHandler::onReply(llvm::json::Value id,
119119
llvm::Expected<llvm::json::Value> result) {
120-
// TODO: Add support for reply callbacks when support for outgoing messages is
121-
// added. For now, we just log an error on any replies received.
122-
Callback<llvm::json::Value> replyHandler =
123-
[&id](llvm::Expected<llvm::json::Value> result) {
124-
Logger::error(
125-
"received a reply with ID {0}, but there was no such call", id);
126-
if (!result)
127-
llvm::consumeError(result.takeError());
128-
};
129-
130-
// Log and run the reply handler.
131-
if (result)
132-
replyHandler(std::move(result));
133-
else
134-
replyHandler(result.takeError());
120+
// Find the response handler in the mapping. If it exists, move it out of the
121+
// mapping and erase it.
122+
ResponseHandlerTy responseHandler;
123+
{
124+
std::lock_guard<std::mutex> responseHandlersLock(responseHandlersMutex);
125+
auto it = responseHandlers.find(debugString(id));
126+
if (it != responseHandlers.end()) {
127+
responseHandler = std::move(it->second);
128+
responseHandlers.erase(it);
129+
}
130+
}
131+
132+
// If we found a response handler, invoke it. Otherwise, log an error.
133+
if (responseHandler.second) {
134+
Logger::info("--> reply:{0}({1})", responseHandler.first, id);
135+
responseHandler.second(std::move(id), std::move(result));
136+
} else {
137+
Logger::error(
138+
"received a reply with ID {0}, but there was no such outgoing request",
139+
id);
140+
if (!result)
141+
llvm::consumeError(result.takeError());
142+
}
135143
return true;
136144
}
137145

mlir/unittests/Tools/lsp-server-support/Transport.cpp

+39
Original file line numberDiff line numberDiff line change
@@ -131,4 +131,43 @@ TEST_F(TransportInputTest, OutgoingNotification) {
131131
notifyFn(CompletionList{});
132132
EXPECT_THAT(getOutput(), HasSubstr("\"method\":\"outgoing-notification\""));
133133
}
134+
135+
TEST_F(TransportInputTest, ResponseHandlerNotFound) {
136+
// Unhandled responses are only reported via error logging. As a result, this
137+
// test can't make any expectations -- but it prints the output anyway, by way
138+
// of demonstration.
139+
Logger::setLogLevel(Logger::Level::Error);
140+
writeInput("{\"jsonrpc\":\"2.0\",\"id\":81,\"result\":null}\n");
141+
runTransport();
142+
}
143+
144+
TEST_F(TransportInputTest, OutgoingRequest) {
145+
// Make some outgoing requests.
146+
int responseCallbackInvoked = 0;
147+
auto callFn = getMessageHandler().outgoingRequest<CompletionList>(
148+
"outgoing-request",
149+
[&responseCallbackInvoked](llvm::json::Value id,
150+
llvm::Expected<llvm::json::Value> value) {
151+
// Make expectations on the expected response.
152+
EXPECT_EQ(id, 83);
153+
ASSERT_TRUE((bool)value);
154+
EXPECT_EQ(debugString(*value), "{\"foo\":6}");
155+
responseCallbackInvoked += 1;
156+
llvm::outs() << "here!!!\n";
157+
});
158+
callFn({}, 82);
159+
callFn({}, 83);
160+
callFn({}, 84);
161+
EXPECT_THAT(getOutput(), HasSubstr("\"method\":\"outgoing-request\""));
162+
EXPECT_EQ(responseCallbackInvoked, 0);
163+
164+
// One of the requests receives a response. The message handler handles this
165+
// response by invoking the callback from above. Subsequent responses with the
166+
// same ID are ignored.
167+
writeInput("{\"jsonrpc\":\"2.0\",\"id\":83,\"result\":{\"foo\":6}}\n"
168+
"// -----\n"
169+
"{\"jsonrpc\":\"2.0\",\"id\":83,\"result\":{\"bar\":8}}\n");
170+
runTransport();
171+
EXPECT_EQ(responseCallbackInvoked, 1);
172+
}
134173
} // namespace

0 commit comments

Comments
 (0)