Skip to content

Commit 77c66ac

Browse files
authored
Implement runner auto-update (#3333)
Part-of: #3288
1 parent ce40af5 commit 77c66ac

File tree

20 files changed

+821
-186
lines changed

20 files changed

+821
-186
lines changed

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -11,7 +11,7 @@ repos:
1111
hooks:
1212
- id: golangci-lint-full
1313
language_version: 1.25.0 # Should match runner/go.mod
14-
entry: bash -c 'cd runner && golangci-lint run --fix'
14+
entry: bash -c 'cd runner && golangci-lint run'
1515
stages: [manual]
1616
- repo: https://github.com/pre-commit/pre-commit-hooks
1717
rev: v5.0.0

runner/cmd/shim/main.go

Lines changed: 64 additions & 37 deletions
Original file line number | Diff line number | Diff line change
@@ -5,157 +5,169 @@ import (
55
"errors"
66
"fmt"
77
"io"
8-
"net/http"
98
"os"
9+
"os/signal"
1010
"path"
1111
"path/filepath"
12+
"syscall"
1213
"time"
1314

1415
"github.com/sirupsen/logrus"
15-
"github.com/urfave/cli/v2"
16+
"github.com/urfave/cli/v3"
1617

1718
"github.com/dstackai/dstack/runner/consts"
1819
"github.com/dstackai/dstack/runner/internal/common"
1920
"github.com/dstackai/dstack/runner/internal/log"
2021
"github.com/dstackai/dstack/runner/internal/shim"
2122
"github.com/dstackai/dstack/runner/internal/shim/api"
23+
"github.com/dstackai/dstack/runner/internal/shim/components"
2224
"github.com/dstackai/dstack/runner/internal/shim/dcgm"
2325
)
2426

2527
// Version is a build-time variable. The value is overridden by ldflags.
2628
var Version string
2729

2830
func main() {
31+
os.Exit(mainInner())
32+
}
33+
34+
func mainInner() int {
2935
var args shim.CLIArgs
3036
var serviceMode bool
3137

3238
const defaultLogLevel = int(logrus.InfoLevel)
3339

34-
ctx := context.Background()
35-
3640
log.DefaultEntry.Logger.SetLevel(logrus.Level(defaultLogLevel))
3741
log.DefaultEntry.Logger.SetOutput(os.Stderr)
3842

39-
app := &cli.App{
43+
cmd := &cli.Command{
4044
Name: "dstack-shim",
4145
Usage: "Starts dstack-runner or docker container.",
4246
Version: Version,
4347
Flags: []cli.Flag{
4448
/* Shim Parameters */
45-
&cli.PathFlag{
49+
&cli.StringFlag{
4650
Name: "shim-home",
4751
Usage: "Set shim's home directory",
4852
Destination: &args.Shim.HomeDir,
53+
TakesFile: true,
4954
DefaultText: path.Join("~", consts.DstackDirPath),
50-
EnvVars: []string{"DSTACK_SHIM_HOME"},
55+
Sources: cli.EnvVars("DSTACK_SHIM_HOME"),
5156
},
5257
&cli.IntFlag{
5358
Name: "shim-http-port",
5459
Usage: "Set shim's http port",
5560
Value: 10998,
5661
Destination: &args.Shim.HTTPPort,
57-
EnvVars: []string{"DSTACK_SHIM_HTTP_PORT"},
62+
Sources: cli.EnvVars("DSTACK_SHIM_HTTP_PORT"),
5863
},
5964
&cli.IntFlag{
6065
Name: "shim-log-level",
6166
Usage: "Set shim's log level",
6267
Value: defaultLogLevel,
6368
Destination: &args.Shim.LogLevel,
64-
EnvVars: []string{"DSTACK_SHIM_LOG_LEVEL"},
69+
Sources: cli.EnvVars("DSTACK_SHIM_LOG_LEVEL"),
6570
},
6671
/* Runner Parameters */
6772
&cli.StringFlag{
6873
Name: "runner-download-url",
6974
Usage: "Set runner's download URL",
7075
Destination: &args.Runner.DownloadURL,
71-
EnvVars: []string{"DSTACK_RUNNER_DOWNLOAD_URL"},
76+
Sources: cli.EnvVars("DSTACK_RUNNER_DOWNLOAD_URL"),
7277
},
73-
&cli.PathFlag{
78+
&cli.StringFlag{
7479
Name: "runner-binary-path",
7580
Usage: "Path to runner's binary",
7681
Value: consts.RunnerBinaryPath,
7782
Destination: &args.Runner.BinaryPath,
78-
EnvVars: []string{"DSTACK_RUNNER_BINARY_PATH"},
83+
TakesFile: true,
84+
Sources: cli.EnvVars("DSTACK_RUNNER_BINARY_PATH"),
7985
},
8086
&cli.IntFlag{
8187
Name: "runner-http-port",
8288
Usage: "Set runner's http port",
8389
Value: consts.RunnerHTTPPort,
8490
Destination: &args.Runner.HTTPPort,
85-
EnvVars: []string{"DSTACK_RUNNER_HTTP_PORT"},
91+
Sources: cli.EnvVars("DSTACK_RUNNER_HTTP_PORT"),
8692
},
8793
&cli.IntFlag{
8894
Name: "runner-ssh-port",
8995
Usage: "Set runner's ssh port",
9096
Value: consts.RunnerSSHPort,
9197
Destination: &args.Runner.SSHPort,
92-
EnvVars: []string{"DSTACK_RUNNER_SSH_PORT"},
98+
Sources: cli.EnvVars("DSTACK_RUNNER_SSH_PORT"),
9399
},
94100
&cli.IntFlag{
95101
Name: "runner-log-level",
96102
Usage: "Set runner's log level",
97103
Value: defaultLogLevel,
98104
Destination: &args.Runner.LogLevel,
99-
EnvVars: []string{"DSTACK_RUNNER_LOG_LEVEL"},
105+
Sources: cli.EnvVars("DSTACK_RUNNER_LOG_LEVEL"),
100106
},
101107
/* DCGM Exporter Parameters */
102108
&cli.IntFlag{
103109
Name: "dcgm-exporter-http-port",
104110
Usage: "DCGM Exporter http port",
105111
Value: 10997,
106112
Destination: &args.DCGMExporter.HTTPPort,
107-
EnvVars: []string{"DSTACK_DCGM_EXPORTER_HTTP_PORT"},
113+
Sources: cli.EnvVars("DSTACK_DCGM_EXPORTER_HTTP_PORT"),
108114
},
109115
&cli.IntFlag{
110116
Name: "dcgm-exporter-interval",
111117
Usage: "DCGM Exporter collect interval, milliseconds",
112118
Value: 5000,
113119
Destination: &args.DCGMExporter.Interval,
114-
EnvVars: []string{"DSTACK_DCGM_EXPORTER_INTERVAL"},
120+
Sources: cli.EnvVars("DSTACK_DCGM_EXPORTER_INTERVAL"),
115121
},
116122
/* DCGM Parameters */
117123
&cli.StringFlag{
118124
Name: "dcgm-address",
119125
Usage: "nv-hostengine `hostname`, e.g., `localhost`",
120126
DefaultText: "start libdcgm in embedded mode",
121127
Destination: &args.DCGM.Address,
122-
EnvVars: []string{"DSTACK_DCGM_ADDRESS"},
128+
Sources: cli.EnvVars("DSTACK_DCGM_ADDRESS"),
123129
},
124130
/* Docker Parameters */
125131
&cli.BoolFlag{
126132
Name: "privileged",
127133
Usage: "Give extended privileges to the container",
128134
Destination: &args.Docker.Privileged,
129-
EnvVars: []string{"DSTACK_DOCKER_PRIVILEGED"},
135+
Sources: cli.EnvVars("DSTACK_DOCKER_PRIVILEGED"),
130136
},
131137
&cli.StringFlag{
132138
Name: "ssh-key",
133139
Usage: "Public SSH key",
134140
Destination: &args.Docker.ConcatinatedPublicSSHKeys,
135-
EnvVars: []string{"DSTACK_PUBLIC_SSH_KEY"},
141+
Sources: cli.EnvVars("DSTACK_PUBLIC_SSH_KEY"),
136142
},
137143
&cli.StringFlag{
138144
Name: "pjrt-device",
139145
Usage: "Set the PJRT_DEVICE environment variable (e.g., TPU, GPU)",
140146
Destination: &args.Docker.PJRTDevice,
141-
EnvVars: []string{"PJRT_DEVICE"},
147+
Sources: cli.EnvVars("PJRT_DEVICE"),
142148
},
143149
/* Misc Parameters */
144150
&cli.BoolFlag{
145151
Name: "service",
146152
Usage: "Start as a service",
147153
Destination: &serviceMode,
148-
EnvVars: []string{"DSTACK_SERVICE_MODE"},
154+
Sources: cli.EnvVars("DSTACK_SERVICE_MODE"),
149155
},
150156
},
151-
Action: func(c *cli.Context) error {
157+
Action: func(ctx context.Context, cmd *cli.Command) error {
152158
return start(ctx, args, serviceMode)
153159
},
154160
}
155161

156-
if err := app.Run(os.Args); err != nil {
157-
log.Fatal(ctx, err.Error())
162+
ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM)
163+
defer stop()
164+
165+
if err := cmd.Run(ctx, os.Args); err != nil {
166+
log.Error(ctx, err.Error())
167+
return 1
158168
}
169+
170+
return 0
159171
}
160172

161173
func start(ctx context.Context, args shim.CLIArgs, serviceMode bool) (err error) {
@@ -191,8 +203,13 @@ func start(ctx context.Context, args shim.CLIArgs, serviceMode bool) (err error)
191203
}
192204
}()
193205

194-
if err := args.DownloadRunner(ctx); err != nil {
195-
return err
206+
runnerManager, runnerErr := components.NewRunnerManager(ctx, args.Runner.BinaryPath)
207+
if args.Runner.DownloadURL != "" {
208+
if err := runnerManager.Install(ctx, args.Runner.DownloadURL, false); err != nil {
209+
return err
210+
}
211+
} else if runnerErr != nil {
212+
return runnerErr
196213
}
197214

198215
log.Debug(ctx, "Shim", "args", args.Shim)
@@ -242,13 +259,7 @@ func start(ctx context.Context, args shim.CLIArgs, serviceMode bool) (err error)
242259
}
243260

244261
address := fmt.Sprintf("localhost:%d", args.Shim.HTTPPort)
245-
shimServer := api.NewShimServer(ctx, address, Version, dockerRunner, dcgmExporter, dcgmWrapper)
246-
247-
defer func() {
248-
shutdownCtx, cancelShutdown := context.WithTimeout(ctx, 5*time.Second)
249-
defer cancelShutdown()
250-
_ = shimServer.HttpServer.Shutdown(shutdownCtx)
251-
}()
262+
shimServer := api.NewShimServer(ctx, address, Version, dockerRunner, dcgmExporter, dcgmWrapper, runnerManager)
252263

253264
if serviceMode {
254265
if err := shim.WriteHostInfo(shimHomeDir, dockerRunner.Resources(ctx)); err != nil {
@@ -260,9 +271,25 @@ func start(ctx context.Context, args shim.CLIArgs, serviceMode bool) (err error)
260271
}
261272
}
262273

263-
if err := shimServer.HttpServer.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
264-
return err
274+
var serveErr error
275+
serveErrCh := make(chan error)
276+
277+
go func() {
278+
if err := shimServer.Serve(); err != nil {
279+
serveErrCh <- err
280+
}
281+
}()
282+
283+
select {
284+
case serveErr = <-serveErrCh:
285+
case <-ctx.Done():
265286
}
266287

267-
return nil
288+
shutdownCtx, cancelShutdown := context.WithTimeout(ctx, 5*time.Second)
289+
defer cancelShutdown()
290+
shutdownErr := shimServer.Shutdown(shutdownCtx)
291+
if serveErr != nil {
292+
return serveErr
293+
}
294+
return shutdownErr
268295
}

0 commit comments

Comments
 (0)