Skip to content

Commit 8a4c57e

Browse files
committed
Implement runner auto-update
Part-of: #3288
1 parent 41fcee2 commit 8a4c57e

File tree

20 files changed

+803
-179
lines changed

20 files changed

+803
-179
lines changed

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ repos:
1010
rev: v1.62.0 # Should match .github/workflows/build-artifacts.yml
1111
hooks:
1212
- id: golangci-lint-full
13-
language_version: 1.23.8 # Should match runner/go.mod
13+
language_version: 1.25.0 # Should match runner/go.mod
1414
entry: bash -c 'cd runner && golangci-lint run'
1515
stages: [manual]
1616
- repo: https://github.com/pre-commit/pre-commit-hooks

runner/cmd/shim/main.go

Lines changed: 54 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -5,156 +5,168 @@ import (
55
"errors"
66
"fmt"
77
"io"
8-
"net/http"
98
"os"
9+
"os/signal"
1010
"path"
1111
"path/filepath"
12+
"syscall"
1213
"time"
1314

1415
"github.com/dstackai/dstack/runner/consts"
1516
"github.com/dstackai/dstack/runner/internal/common"
1617
"github.com/dstackai/dstack/runner/internal/log"
1718
"github.com/dstackai/dstack/runner/internal/shim"
1819
"github.com/dstackai/dstack/runner/internal/shim/api"
20+
"github.com/dstackai/dstack/runner/internal/shim/components"
1921
"github.com/dstackai/dstack/runner/internal/shim/dcgm"
2022
"github.com/sirupsen/logrus"
21-
"github.com/urfave/cli/v2"
23+
"github.com/urfave/cli/v3"
2224
)
2325

2426
// Version is a build-time variable. The value is overridden by ldflags.
2527
var Version string
2628

2729
func main() {
30+
os.Exit(mainInner())
31+
}
32+
33+
func mainInner() int {
2834
var args shim.CLIArgs
2935
var serviceMode bool
3036

3137
const defaultLogLevel = int(logrus.InfoLevel)
3238

33-
ctx := context.Background()
34-
3539
log.DefaultEntry.Logger.SetLevel(logrus.Level(defaultLogLevel))
3640
log.DefaultEntry.Logger.SetOutput(os.Stderr)
3741

38-
app := &cli.App{
42+
cmd := &cli.Command{
3943
Name: "dstack-shim",
4044
Usage: "Starts dstack-runner or docker container.",
4145
Version: Version,
4246
Flags: []cli.Flag{
4347
/* Shim Parameters */
44-
&cli.PathFlag{
48+
&cli.StringFlag{
4549
Name: "shim-home",
4650
Usage: "Set shim's home directory",
4751
Destination: &args.Shim.HomeDir,
52+
TakesFile: true,
4853
DefaultText: path.Join("~", consts.DstackDirPath),
49-
EnvVars: []string{"DSTACK_SHIM_HOME"},
54+
Sources: cli.EnvVars("DSTACK_SHIM_HOME"),
5055
},
5156
&cli.IntFlag{
5257
Name: "shim-http-port",
5358
Usage: "Set shim's http port",
5459
Value: 10998,
5560
Destination: &args.Shim.HTTPPort,
56-
EnvVars: []string{"DSTACK_SHIM_HTTP_PORT"},
61+
Sources: cli.EnvVars("DSTACK_SHIM_HTTP_PORT"),
5762
},
5863
&cli.IntFlag{
5964
Name: "shim-log-level",
6065
Usage: "Set shim's log level",
6166
Value: defaultLogLevel,
6267
Destination: &args.Shim.LogLevel,
63-
EnvVars: []string{"DSTACK_SHIM_LOG_LEVEL"},
68+
Sources: cli.EnvVars("DSTACK_SHIM_LOG_LEVEL"),
6469
},
6570
/* Runner Parameters */
6671
&cli.StringFlag{
6772
Name: "runner-download-url",
6873
Usage: "Set runner's download URL",
6974
Destination: &args.Runner.DownloadURL,
70-
EnvVars: []string{"DSTACK_RUNNER_DOWNLOAD_URL"},
75+
Sources: cli.EnvVars("DSTACK_RUNNER_DOWNLOAD_URL"),
7176
},
72-
&cli.PathFlag{
77+
&cli.StringFlag{
7378
Name: "runner-binary-path",
7479
Usage: "Path to runner's binary",
7580
Value: consts.RunnerBinaryPath,
7681
Destination: &args.Runner.BinaryPath,
77-
EnvVars: []string{"DSTACK_RUNNER_BINARY_PATH"},
82+
TakesFile: true,
83+
Sources: cli.EnvVars("DSTACK_RUNNER_BINARY_PATH"),
7884
},
7985
&cli.IntFlag{
8086
Name: "runner-http-port",
8187
Usage: "Set runner's http port",
8288
Value: consts.RunnerHTTPPort,
8389
Destination: &args.Runner.HTTPPort,
84-
EnvVars: []string{"DSTACK_RUNNER_HTTP_PORT"},
90+
Sources: cli.EnvVars("DSTACK_RUNNER_HTTP_PORT"),
8591
},
8692
&cli.IntFlag{
8793
Name: "runner-ssh-port",
8894
Usage: "Set runner's ssh port",
8995
Value: consts.RunnerSSHPort,
9096
Destination: &args.Runner.SSHPort,
91-
EnvVars: []string{"DSTACK_RUNNER_SSH_PORT"},
97+
Sources: cli.EnvVars("DSTACK_RUNNER_SSH_PORT"),
9298
},
9399
&cli.IntFlag{
94100
Name: "runner-log-level",
95101
Usage: "Set runner's log level",
96102
Value: defaultLogLevel,
97103
Destination: &args.Runner.LogLevel,
98-
EnvVars: []string{"DSTACK_RUNNER_LOG_LEVEL"},
104+
Sources: cli.EnvVars("DSTACK_RUNNER_LOG_LEVEL"),
99105
},
100106
/* DCGM Exporter Parameters */
101107
&cli.IntFlag{
102108
Name: "dcgm-exporter-http-port",
103109
Usage: "DCGM Exporter http port",
104110
Value: 10997,
105111
Destination: &args.DCGMExporter.HTTPPort,
106-
EnvVars: []string{"DSTACK_DCGM_EXPORTER_HTTP_PORT"},
112+
Sources: cli.EnvVars("DSTACK_DCGM_EXPORTER_HTTP_PORT"),
107113
},
108114
&cli.IntFlag{
109115
Name: "dcgm-exporter-interval",
110116
Usage: "DCGM Exporter collect interval, milliseconds",
111117
Value: 5000,
112118
Destination: &args.DCGMExporter.Interval,
113-
EnvVars: []string{"DSTACK_DCGM_EXPORTER_INTERVAL"},
119+
Sources: cli.EnvVars("DSTACK_DCGM_EXPORTER_INTERVAL"),
114120
},
115121
/* DCGM Parameters */
116122
&cli.StringFlag{
117123
Name: "dcgm-address",
118124
Usage: "nv-hostengine `hostname`, e.g., `localhost`",
119125
DefaultText: "start libdcgm in embedded mode",
120126
Destination: &args.DCGM.Address,
121-
EnvVars: []string{"DSTACK_DCGM_ADDRESS"},
127+
Sources: cli.EnvVars("DSTACK_DCGM_ADDRESS"),
122128
},
123129
/* Docker Parameters */
124130
&cli.BoolFlag{
125131
Name: "privileged",
126132
Usage: "Give extended privileges to the container",
127133
Destination: &args.Docker.Privileged,
128-
EnvVars: []string{"DSTACK_DOCKER_PRIVILEGED"},
134+
Sources: cli.EnvVars("DSTACK_DOCKER_PRIVILEGED"),
129135
},
130136
&cli.StringFlag{
131137
Name: "ssh-key",
132138
Usage: "Public SSH key",
133139
Destination: &args.Docker.ConcatinatedPublicSSHKeys,
134-
EnvVars: []string{"DSTACK_PUBLIC_SSH_KEY"},
140+
Sources: cli.EnvVars("DSTACK_PUBLIC_SSH_KEY"),
135141
},
136142
&cli.StringFlag{
137143
Name: "pjrt-device",
138144
Usage: "Set the PJRT_DEVICE environment variable (e.g., TPU, GPU)",
139145
Destination: &args.Docker.PJRTDevice,
140-
EnvVars: []string{"PJRT_DEVICE"},
146+
Sources: cli.EnvVars("PJRT_DEVICE"),
141147
},
142148
/* Misc Parameters */
143149
&cli.BoolFlag{
144150
Name: "service",
145151
Usage: "Start as a service",
146152
Destination: &serviceMode,
147-
EnvVars: []string{"DSTACK_SERVICE_MODE"},
153+
Sources: cli.EnvVars("DSTACK_SERVICE_MODE"),
148154
},
149155
},
150-
Action: func(c *cli.Context) error {
156+
Action: func(ctx context.Context, cmd *cli.Command) error {
151157
return start(ctx, args, serviceMode)
152158
},
153159
}
154160

155-
if err := app.Run(os.Args); err != nil {
156-
log.Fatal(ctx, err.Error())
161+
ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM)
162+
defer stop()
163+
164+
if err := cmd.Run(ctx, os.Args); err != nil {
165+
log.Error(ctx, err.Error())
166+
return 1
157167
}
168+
169+
return 0
158170
}
159171

160172
func start(ctx context.Context, args shim.CLIArgs, serviceMode bool) (err error) {
@@ -190,8 +202,13 @@ func start(ctx context.Context, args shim.CLIArgs, serviceMode bool) (err error)
190202
}
191203
}()
192204

193-
if err := args.DownloadRunner(ctx); err != nil {
194-
return err
205+
runnerManager, runnerErr := components.NewRunnerManager(ctx, args.Runner.BinaryPath)
206+
if args.Runner.DownloadURL != "" {
207+
if err := runnerManager.Install(ctx, args.Runner.DownloadURL, false); err != nil {
208+
return err
209+
}
210+
} else if runnerErr != nil {
211+
return runnerErr
195212
}
196213

197214
log.Debug(ctx, "Shim", "args", args.Shim)
@@ -241,13 +258,7 @@ func start(ctx context.Context, args shim.CLIArgs, serviceMode bool) (err error)
241258
}
242259

243260
address := fmt.Sprintf("localhost:%d", args.Shim.HTTPPort)
244-
shimServer := api.NewShimServer(ctx, address, Version, dockerRunner, dcgmExporter, dcgmWrapper)
245-
246-
defer func() {
247-
shutdownCtx, cancelShutdown := context.WithTimeout(ctx, 5*time.Second)
248-
defer cancelShutdown()
249-
_ = shimServer.HttpServer.Shutdown(shutdownCtx)
250-
}()
261+
shimServer := api.NewShimServer(ctx, address, Version, dockerRunner, dcgmExporter, dcgmWrapper, runnerManager)
251262

252263
if serviceMode {
253264
if err := shim.WriteHostInfo(shimHomeDir, dockerRunner.Resources(ctx)); err != nil {
@@ -259,9 +270,14 @@ func start(ctx context.Context, args shim.CLIArgs, serviceMode bool) (err error)
259270
}
260271
}
261272

262-
if err := shimServer.HttpServer.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
263-
return err
264-
}
273+
go func() {
274+
if err := shimServer.Serve(); err != nil {
275+
log.Error(ctx, "serve", "err", err)
276+
}
277+
}()
265278

266-
return nil
279+
<-ctx.Done()
280+
shutdownCtx, cancelShutdown := context.WithTimeout(ctx, 5*time.Second)
281+
defer cancelShutdown()
282+
return shimServer.Shutdown(shutdownCtx)
267283
}

0 commit comments

Comments
 (0)