Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions internal/api/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,10 @@ type API struct {
func NewAPI(init *APIInit) *API {
api := &API{
session: project.NewSession(&project.SessionInit{
Logger: init.Logger,
FS: init.FS,
Options: init.SessionOptions,
BackgroundCtx: context.Background(),
Logger: init.Logger,
FS: init.FS,
Options: init.SessionOptions,
}),
projects: make(map[Handle[project.Project]]tspath.Path),
files: make(handleMap[ast.SourceFile]),
Expand Down
3 changes: 1 addition & 2 deletions internal/fourslash/fourslash.go
Original file line number Diff line number Diff line change
Expand Up @@ -206,11 +206,10 @@ func NewFourslash(t *testing.T, capabilities *lsproto.ClientCapabilities, conten
fsFromMap := vfstest.FromMap(testfs, true /*useCaseSensitiveFileNames*/)
fs := bundled.WrapFS(fsFromMap)

var err strings.Builder
server := lsp.NewServer(&lsp.ServerOptions{
In: inputReader,
Out: outputWriter,
Err: &err,
Err: io.Discard,

Cwd: "/",
FS: fs,
Expand Down
8 changes: 7 additions & 1 deletion internal/lsp/logger.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,13 @@ func (l *logger) sendLogMessage(msgType lsproto.MessageType, message string) {
Type: msgType,
Message: message,
})
l.server.outgoingQueue <- notification.Message()

select {
case l.server.outgoingQueue <- notification.Message():
// sent
case <-l.server.backgroundCtx.Done():
fmt.Fprintln(l.server.stderr, message)
}
}

func (l *logger) Log(msg ...any) {
Expand Down
91 changes: 55 additions & 36 deletions internal/lsp/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -130,8 +130,9 @@ var (
)

type Server struct {
r Reader
w Writer
r Reader
w Writer
backgroundCtx context.Context

stderr io.Writer

Expand Down Expand Up @@ -246,8 +247,7 @@ func (s *Server) RefreshDiagnostics(ctx context.Context) error {
// PublishDiagnostics implements project.Client.
func (s *Server) PublishDiagnostics(ctx context.Context, params *lsproto.PublishDiagnosticsParams) error {
notification := lsproto.TextDocumentPublishDiagnosticsInfo.NewNotificationMessage(params)
s.outgoingQueue <- notification.Message()
return nil
return s.send(ctx, notification.Message())
}

func (s *Server) RefreshInlayHints(ctx context.Context) error {
Expand Down Expand Up @@ -300,6 +300,7 @@ func (s *Server) RequestConfiguration(ctx context.Context) (*lsutil.UserPreferen

func (s *Server) Run(ctx context.Context) error {
g, ctx := errgroup.WithContext(ctx)
s.backgroundCtx = ctx
g.Go(func() error { return s.dispatchLoop(ctx) })
g.Go(func() error { return s.writeLoop(ctx) })

Expand Down Expand Up @@ -329,7 +330,9 @@ func (s *Server) readLoop(ctx context.Context) error {
msg, err := s.read()
if err != nil {
if errors.Is(err, lsproto.ErrorCodeInvalidRequest) {
s.sendError(nil, err)
if err := s.sendError(ctx, nil, err); err != nil {
return err
}
continue
}
return err
Expand All @@ -342,9 +345,13 @@ func (s *Server) readLoop(ctx context.Context) error {
if err != nil {
return err
}
s.sendResult(req.ID, resp)
if err := s.sendResult(ctx, req.ID, resp); err != nil {
return err
}
} else {
s.sendError(req.ID, lsproto.ErrorCodeServerNotInitialized)
if err := s.sendError(ctx, req.ID, lsproto.ErrorCodeServerNotInitialized); err != nil {
return err
}
}
continue
}
Expand Down Expand Up @@ -406,11 +413,11 @@ func (s *Server) dispatchLoop(ctx context.Context) error {
handle := func() {
if err := s.handleRequestOrNotification(requestCtx, req); err != nil {
if errors.Is(err, context.Canceled) {
s.sendError(req.ID, lsproto.ErrorCodeRequestCancelled)
_ = s.sendError(requestCtx, req.ID, lsproto.ErrorCodeRequestCancelled)
} else if errors.Is(err, io.EOF) {
lspExit()
} else {
s.sendError(req.ID, err)
_ = s.sendError(requestCtx, req.ID, err)
}
}

Expand Down Expand Up @@ -452,16 +459,21 @@ func sendClientRequest[Req, Resp any](ctx context.Context, s *Server, info lspro
s.pendingServerRequests[*id] = responseChan
s.pendingServerRequestsMu.Unlock()

s.outgoingQueue <- req.Message()

select {
case <-ctx.Done():
defer func() {
s.pendingServerRequestsMu.Lock()
defer s.pendingServerRequestsMu.Unlock()
if respChan, ok := s.pendingServerRequests[*id]; ok {
close(respChan)
delete(s.pendingServerRequests, *id)
}
}()

if err := s.send(ctx, req.Message()); err != nil {
return *new(Resp), err
}

select {
case <-ctx.Done():
return *new(Resp), ctx.Err()
case resp := <-responseChan:
if resp.Error != nil {
Expand All @@ -471,20 +483,20 @@ func sendClientRequest[Req, Resp any](ctx context.Context, s *Server, info lspro
}
}

func (s *Server) sendResult(id *lsproto.ID, result any) {
s.sendResponse(&lsproto.ResponseMessage{
func (s *Server) sendResult(ctx context.Context, id *lsproto.ID, result any) error {
return s.sendResponse(ctx, &lsproto.ResponseMessage{
ID: id,
Result: result,
})
}

func (s *Server) sendError(id *lsproto.ID, err error) {
func (s *Server) sendError(ctx context.Context, id *lsproto.ID, err error) error {
code := lsproto.ErrorCodeInternalError
if errCode := lsproto.ErrorCode(0); errors.As(err, &errCode) {
code = errCode
}
// TODO(jakebailey): error data
s.sendResponse(&lsproto.ResponseMessage{
return s.sendResponse(ctx, &lsproto.ResponseMessage{
ID: id,
Error: &lsproto.ResponseError{
Code: int32(code),
Expand All @@ -493,8 +505,18 @@ func (s *Server) sendError(id *lsproto.ID, err error) {
})
}

func (s *Server) sendResponse(resp *lsproto.ResponseMessage) {
s.outgoingQueue <- resp.Message()
func (s *Server) sendResponse(ctx context.Context, resp *lsproto.ResponseMessage) error {
return s.send(ctx, resp.Message())
}

// send writes a message to the outgoing queue, respecting context cancellation.
func (s *Server) send(ctx context.Context, msg *lsproto.Message) error {
select {
case s.outgoingQueue <- msg:
return nil
case <-ctx.Done():
return ctx.Err()
}
}

func (s *Server) handleRequestOrNotification(ctx context.Context, req *lsproto.RequestMessage) error {
Expand All @@ -508,7 +530,7 @@ func (s *Server) handleRequestOrNotification(ctx context.Context, req *lsproto.R
}
s.logger.Warn("unknown method '", req.Method, "'")
if req.ID != nil {
s.sendError(req.ID, lsproto.ErrorCodeInvalidRequest)
return s.sendError(ctx, req.ID, lsproto.ErrorCodeInvalidRequest)
}
return nil
}
Expand Down Expand Up @@ -614,8 +636,7 @@ func registerRequestHandler[Req, Resp any](
if ctx.Err() != nil {
return ctx.Err()
}
s.sendResult(req.ID, resp)
return nil
return s.sendResult(ctx, req.ID, resp)
}
}

Expand All @@ -630,16 +651,15 @@ func registerLanguageServiceDocumentRequestHandler[Req lsproto.HasTextDocumentUR
if err != nil {
return err
}
defer s.recover(req)
defer s.recover(ctx, req)
resp, err := fn(s, ctx, ls, params)
if err != nil {
return err
}
if ctx.Err() != nil {
return ctx.Err()
}
s.sendResult(req.ID, resp)
return nil
return s.sendResult(ctx, req.ID, resp)
}
}

Expand All @@ -654,7 +674,7 @@ func registerLanguageServiceWithAutoImportsRequestHandler[Req lsproto.HasTextDoc
if err != nil {
return err
}
defer s.recover(req)
defer s.recover(ctx, req)
resp, err := fn(s, ctx, languageService, params)
if errors.Is(err, ls.ErrNeedsAutoImports) {
languageService, err = s.session.GetLanguageServiceWithAutoImports(ctx, params.TextDocumentURI())
Expand All @@ -675,8 +695,7 @@ func registerLanguageServiceWithAutoImportsRequestHandler[Req lsproto.HasTextDoc
if ctx.Err() != nil {
return ctx.Err()
}
s.sendResult(req.ID, resp)
return nil
return s.sendResult(ctx, req.ID, resp)
}
}

Expand All @@ -696,13 +715,12 @@ func registerMultiProjectReferenceRequestHandler[Req lsproto.HasTextDocumentPosi
if err != nil {
return err
}
defer s.recover(req)
defer s.recover(ctx, req)
resp, err := fn(defaultLs, ctx, params, orchestrator)
if err != nil {
return err
}
s.sendResult(req.ID, resp)
return nil
return s.sendResult(ctx, req.ID, resp)
}
}

Expand Down Expand Up @@ -750,12 +768,12 @@ func (s *Server) getLanguageServiceAndCrossProjectOrchestrator(ctx context.Conte
return defaultLs, orchestrator, err
}

func (s *Server) recover(req *lsproto.RequestMessage) {
func (s *Server) recover(ctx context.Context, req *lsproto.RequestMessage) {
if r := recover(); r != nil {
stack := debug.Stack()
s.logger.Errorf("panic handling request %s: %v\n%s", req.Method, r, string(stack))
if req.ID != nil {
s.sendError(req.ID, fmt.Errorf("%w: panic handling request %s: %v", lsproto.ErrorCodeInternalError, req.Method, r))
_ = s.sendError(ctx, req.ID, fmt.Errorf("%w: panic handling request %s: %v", lsproto.ErrorCodeInternalError, req.Method, r))
} else {
s.logger.Error("unhandled panic in notification", req.Method, r)
}
Expand Down Expand Up @@ -913,6 +931,7 @@ func (s *Server) handleInitialized(ctx context.Context, params *lsproto.Initiali
}

s.session = project.NewSession(&project.SessionInit{
BackgroundCtx: s.backgroundCtx,
Options: &project.SessionOptions{
CurrentDirectory: cwd,
DefaultLibraryPath: s.defaultLibraryPath,
Expand Down Expand Up @@ -1080,7 +1099,7 @@ func (s *Server) handleCompletionItemResolve(ctx context.Context, params *lsprot
if err != nil {
return nil, err
}
defer s.recover(reqMsg)
defer s.recover(ctx, reqMsg)
return languageService.ResolveCompletionItem(
ctx,
params,
Expand Down Expand Up @@ -1117,7 +1136,7 @@ func (s *Server) handleDocumentOnTypeFormat(ctx context.Context, ls *ls.Language

func (s *Server) handleWorkspaceSymbol(ctx context.Context, params *lsproto.WorkspaceSymbolParams, reqMsg *lsproto.RequestMessage) (lsproto.WorkspaceSymbolResponse, error) {
snapshot := s.session.GetSnapshotLoadingProjectTree(ctx, nil)
defer s.recover(reqMsg)
defer s.recover(ctx, reqMsg)

programs := core.Map(snapshot.ProjectCollection.Projects(), (*project.Project).GetProgram)
return ls.ProvideWorkspaceSymbols(
Expand Down Expand Up @@ -1171,7 +1190,7 @@ func (s *Server) handleCodeLensResolve(ctx context.Context, codeLens *lsproto.Co
// based on non-existent files and line maps from shortened files.
return codeLens, lsproto.ErrorCodeContentModified
}
defer s.recover(reqMsg)
defer s.recover(ctx, reqMsg)
return defaultLs.ResolveCodeLens(
ctx,
codeLens,
Expand Down
Loading