Skip to content

Commit cab7a0b

Browse files
authored
Concurrent LSP with cancelation (microsoft#869)
1 parent abd526f commit cab7a0b

35 files changed

+1341
-613
lines changed

go.mod

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,14 @@ require (
77
github.com/go-json-experiment/json v0.0.0-20250223041408-d3c622f1b874
88
github.com/google/go-cmp v0.7.0
99
github.com/pkg/diff v0.0.0-20241224192749-4e6772a4315c
10+
golang.org/x/sync v0.11.0
1011
golang.org/x/sys v0.31.0
1112
gotest.tools/v3 v3.5.2
1213
)
1314

1415
require (
1516
github.com/matryer/moq v0.5.3 // indirect
1617
golang.org/x/mod v0.23.0 // indirect
17-
golang.org/x/sync v0.11.0 // indirect
1818
golang.org/x/tools v0.30.0 // indirect
1919
)
2020

internal/api/api.go

Lines changed: 20 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package api
22

33
import (
4+
"context"
45
"encoding/json"
56
"errors"
67
"fmt"
@@ -129,7 +130,7 @@ func (api *API) IsWatchEnabled() bool {
129130
return false
130131
}
131132

132-
func (api *API) HandleRequest(id int, method string, payload []byte) ([]byte, error) {
133+
func (api *API) HandleRequest(ctx context.Context, method string, payload []byte) ([]byte, error) {
133134
params, err := unmarshalPayload(method, payload)
134135
if err != nil {
135136
return nil, err
@@ -155,27 +156,27 @@ func (api *API) HandleRequest(id int, method string, payload []byte) ([]byte, er
155156
return encodeJSON(api.LoadProject(params.(*LoadProjectParams).ConfigFileName))
156157
case MethodGetSymbolAtPosition:
157158
params := params.(*GetSymbolAtPositionParams)
158-
return encodeJSON(api.GetSymbolAtPosition(params.Project, params.FileName, int(params.Position)))
159+
return encodeJSON(api.GetSymbolAtPosition(ctx, params.Project, params.FileName, int(params.Position)))
159160
case MethodGetSymbolsAtPositions:
160161
params := params.(*GetSymbolsAtPositionsParams)
161162
return encodeJSON(core.TryMap(params.Positions, func(position uint32) (any, error) {
162-
return api.GetSymbolAtPosition(params.Project, params.FileName, int(position))
163+
return api.GetSymbolAtPosition(ctx, params.Project, params.FileName, int(position))
163164
}))
164165
case MethodGetSymbolAtLocation:
165166
params := params.(*GetSymbolAtLocationParams)
166-
return encodeJSON(api.GetSymbolAtLocation(params.Project, params.Location))
167+
return encodeJSON(api.GetSymbolAtLocation(ctx, params.Project, params.Location))
167168
case MethodGetSymbolsAtLocations:
168169
params := params.(*GetSymbolsAtLocationsParams)
169170
return encodeJSON(core.TryMap(params.Locations, func(location Handle[ast.Node]) (any, error) {
170-
return api.GetSymbolAtLocation(params.Project, location)
171+
return api.GetSymbolAtLocation(ctx, params.Project, location)
171172
}))
172173
case MethodGetTypeOfSymbol:
173174
params := params.(*GetTypeOfSymbolParams)
174-
return encodeJSON(api.GetTypeOfSymbol(params.Project, params.Symbol))
175+
return encodeJSON(api.GetTypeOfSymbol(ctx, params.Project, params.Symbol))
175176
case MethodGetTypesOfSymbols:
176177
params := params.(*GetTypesOfSymbolsParams)
177178
return encodeJSON(core.TryMap(params.Symbols, func(symbol Handle[ast.Symbol]) (any, error) {
178-
return api.GetTypeOfSymbol(params.Project, symbol)
179+
return api.GetTypeOfSymbol(ctx, params.Project, symbol)
179180
}))
180181
default:
181182
return nil, fmt.Errorf("unhandled API method %q", method)
@@ -223,12 +224,14 @@ func (api *API) LoadProject(configFileName string) (*ProjectResponse, error) {
223224
return data, nil
224225
}
225226

226-
func (api *API) GetSymbolAtPosition(projectId Handle[project.Project], fileName string, position int) (*SymbolResponse, error) {
227+
func (api *API) GetSymbolAtPosition(ctx context.Context, projectId Handle[project.Project], fileName string, position int) (*SymbolResponse, error) {
227228
project, ok := api.projects[projectId]
228229
if !ok {
229230
return nil, errors.New("project not found")
230231
}
231-
symbol, err := project.LanguageService().GetSymbolAtPosition(fileName, position)
232+
languageService, done := project.GetLanguageServiceForRequest(ctx)
233+
defer done()
234+
symbol, err := languageService.GetSymbolAtPosition(ctx, fileName, position)
232235
if err != nil || symbol == nil {
233236
return nil, err
234237
}
@@ -239,7 +242,7 @@ func (api *API) GetSymbolAtPosition(projectId Handle[project.Project], fileName
239242
return data, nil
240243
}
241244

242-
func (api *API) GetSymbolAtLocation(projectId Handle[project.Project], location Handle[ast.Node]) (*SymbolResponse, error) {
245+
func (api *API) GetSymbolAtLocation(ctx context.Context, projectId Handle[project.Project], location Handle[ast.Node]) (*SymbolResponse, error) {
243246
project, ok := api.projects[projectId]
244247
if !ok {
245248
return nil, errors.New("project not found")
@@ -262,7 +265,9 @@ func (api *API) GetSymbolAtLocation(projectId Handle[project.Project], location
262265
if node == nil {
263266
return nil, fmt.Errorf("node of kind %s not found at position %d in file %q", kind.String(), pos, sourceFile.FileName())
264267
}
265-
symbol := project.LanguageService().GetSymbolAtLocation(node)
268+
languageService, done := project.GetLanguageServiceForRequest(ctx)
269+
defer done()
270+
symbol := languageService.GetSymbolAtLocation(ctx, node)
266271
if symbol == nil {
267272
return nil, nil
268273
}
@@ -273,7 +278,7 @@ func (api *API) GetSymbolAtLocation(projectId Handle[project.Project], location
273278
return data, nil
274279
}
275280

276-
func (api *API) GetTypeOfSymbol(projectId Handle[project.Project], symbolHandle Handle[ast.Symbol]) (*TypeResponse, error) {
281+
func (api *API) GetTypeOfSymbol(ctx context.Context, projectId Handle[project.Project], symbolHandle Handle[ast.Symbol]) (*TypeResponse, error) {
277282
project, ok := api.projects[projectId]
278283
if !ok {
279284
return nil, errors.New("project not found")
@@ -284,7 +289,9 @@ func (api *API) GetTypeOfSymbol(projectId Handle[project.Project], symbolHandle
284289
if !ok {
285290
return nil, fmt.Errorf("symbol %q not found", symbolHandle)
286291
}
287-
t := project.LanguageService().GetTypeOfSymbol(symbol)
292+
languageService, done := project.GetLanguageServiceForRequest(ctx)
293+
defer done()
294+
t := languageService.GetTypeOfSymbol(ctx, symbol)
288295
if t == nil {
289296
return nil, nil
290297
}

internal/api/server.go

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,16 @@ package api
22

33
import (
44
"bufio"
5+
"context"
56
"encoding/binary"
67
"encoding/json"
78
"fmt"
89
"io"
10+
"strconv"
911
"sync"
1012

1113
"github.com/microsoft/typescript-go/internal/bundled"
14+
"github.com/microsoft/typescript-go/internal/core"
1215
"github.com/microsoft/typescript-go/internal/project"
1316
"github.com/microsoft/typescript-go/internal/vfs"
1417
"github.com/microsoft/typescript-go/internal/vfs/osvfs"
@@ -254,7 +257,7 @@ func (s *Server) handleRequest(method string, payload []byte) ([]byte, error) {
254257
case "echo":
255258
return payload, nil
256259
default:
257-
return s.api.HandleRequest(s.requestId, method, payload)
260+
return s.api.HandleRequest(core.WithRequestID(context.Background(), strconv.Itoa(s.requestId)), method, payload)
258261
}
259262
}
260263

internal/checker/checker_test.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,8 @@ foo.bar;`
3939
}
4040
p := compiler.NewProgram(opts)
4141
p.BindSourceFiles()
42-
c := p.GetTypeChecker()
42+
c, done := p.GetTypeChecker(t.Context())
43+
defer done()
4344
file := p.GetSourceFile("/foo.ts")
4445
interfaceId := file.Statements.Nodes[0].Name()
4546
varId := file.Statements.Nodes[1].AsVariableStatement().DeclarationList.AsVariableDeclarationList().Declarations.Nodes[0].Name()

internal/checker/exports.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,3 +48,7 @@ func (c *Checker) GetTypeOfPropertyOfContextualType(t *Type, name string) *Type
4848
func GetDeclarationModifierFlagsFromSymbol(s *ast.Symbol) ast.ModifierFlags {
4949
return getDeclarationModifierFlagsFromSymbol(s)
5050
}
51+
52+
func (c *Checker) WasCanceled() bool {
53+
return c.wasCanceled
54+
}

internal/compiler/checkerpool.go

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
package compiler
2+
3+
import (
4+
"context"
5+
"iter"
6+
"slices"
7+
"sync"
8+
9+
"github.com/microsoft/typescript-go/internal/ast"
10+
"github.com/microsoft/typescript-go/internal/checker"
11+
"github.com/microsoft/typescript-go/internal/core"
12+
)
13+
14+
type CheckerPool interface {
15+
GetChecker(ctx context.Context) (*checker.Checker, func())
16+
GetCheckerForFile(ctx context.Context, file *ast.SourceFile) (*checker.Checker, func())
17+
GetAllCheckers(ctx context.Context) ([]*checker.Checker, func())
18+
Files(checker *checker.Checker) iter.Seq[*ast.SourceFile]
19+
}
20+
21+
type checkerPool struct {
22+
checkerCount int
23+
program *Program
24+
25+
createCheckersOnce sync.Once
26+
checkers []*checker.Checker
27+
fileAssociations map[*ast.SourceFile]*checker.Checker
28+
}
29+
30+
var _ CheckerPool = (*checkerPool)(nil)
31+
32+
func newCheckerPool(checkerCount int, program *Program) *checkerPool {
33+
pool := &checkerPool{
34+
program: program,
35+
checkerCount: checkerCount,
36+
checkers: make([]*checker.Checker, checkerCount),
37+
}
38+
39+
return pool
40+
}
41+
42+
func (p *checkerPool) GetCheckerForFile(ctx context.Context, file *ast.SourceFile) (*checker.Checker, func()) {
43+
p.createCheckers()
44+
checker := p.fileAssociations[file]
45+
return checker, noop
46+
}
47+
48+
func (p *checkerPool) GetChecker(ctx context.Context) (*checker.Checker, func()) {
49+
p.createCheckers()
50+
checker := p.checkers[0]
51+
return checker, noop
52+
}
53+
54+
func (p *checkerPool) createCheckers() {
55+
p.createCheckersOnce.Do(func() {
56+
wg := core.NewWorkGroup(p.program.singleThreaded())
57+
for i := range p.checkerCount {
58+
wg.Queue(func() {
59+
p.checkers[i] = checker.NewChecker(p.program)
60+
})
61+
}
62+
63+
wg.RunAndWait()
64+
65+
p.fileAssociations = make(map[*ast.SourceFile]*checker.Checker, len(p.program.files))
66+
for i, file := range p.program.files {
67+
p.fileAssociations[file] = p.checkers[i%p.checkerCount]
68+
}
69+
})
70+
}
71+
72+
func (p *checkerPool) GetAllCheckers(ctx context.Context) ([]*checker.Checker, func()) {
73+
p.createCheckers()
74+
return p.checkers, noop
75+
}
76+
77+
func (p *checkerPool) Files(checker *checker.Checker) iter.Seq[*ast.SourceFile] {
78+
checkerIndex := slices.Index(p.checkers, checker)
79+
return func(yield func(*ast.SourceFile) bool) {
80+
for i, file := range p.program.files {
81+
if i%p.checkerCount == checkerIndex {
82+
if !yield(file) {
83+
return
84+
}
85+
}
86+
}
87+
}
88+
}
89+
90+
func noop() {}

internal/compiler/emitHost.go

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
package compiler
22

33
import (
4+
"context"
5+
46
"github.com/microsoft/typescript-go/internal/ast"
57
"github.com/microsoft/typescript-go/internal/core"
68
"github.com/microsoft/typescript-go/internal/printer"
@@ -32,7 +34,11 @@ func (host *emitHost) WriteFile(fileName string, text string, writeByteOrderMark
3234
}
3335

3436
func (host *emitHost) GetEmitResolver(file *ast.SourceFile, skipDiagnostics bool) printer.EmitResolver {
35-
checker := host.program.GetTypeCheckerForFile(file)
37+
// The context and done function don't matter in tsc, currently the only caller of this function.
38+
// But if this ever gets used by LSP code, we'll need to thread the context properly and pass the
39+
// done function to the caller to ensure resources are cleaned up at the end of the request.
40+
checker, done := host.program.GetTypeCheckerForFile(context.TODO(), file)
41+
defer done()
3642
return checker.GetEmitResolver(file, skipDiagnostics)
3743
}
3844

0 commit comments

Comments
 (0)