@@ -5,157 +5,169 @@ import (
55 "errors"
66 "fmt"
77 "io"
8- "net/http"
98 "os"
9+ "os/signal"
1010 "path"
1111 "path/filepath"
12+ "syscall"
1213 "time"
1314
1415 "github.com/sirupsen/logrus"
15- "github.com/urfave/cli/v2 "
16+ "github.com/urfave/cli/v3 "
1617
1718 "github.com/dstackai/dstack/runner/consts"
1819 "github.com/dstackai/dstack/runner/internal/common"
1920 "github.com/dstackai/dstack/runner/internal/log"
2021 "github.com/dstackai/dstack/runner/internal/shim"
2122 "github.com/dstackai/dstack/runner/internal/shim/api"
23+ "github.com/dstackai/dstack/runner/internal/shim/components"
2224 "github.com/dstackai/dstack/runner/internal/shim/dcgm"
2325)
2426
2527// Version is a build-time variable. The value is overridden by ldflags.
2628var Version string
2729
2830func main () {
31+ os .Exit (mainInner ())
32+ }
33+
34+ func mainInner () int {
2935 var args shim.CLIArgs
3036 var serviceMode bool
3137
3238 const defaultLogLevel = int (logrus .InfoLevel )
3339
34- ctx := context .Background ()
35-
3640 log .DefaultEntry .Logger .SetLevel (logrus .Level (defaultLogLevel ))
3741 log .DefaultEntry .Logger .SetOutput (os .Stderr )
3842
39- app := & cli.App {
43+ cmd := & cli.Command {
4044 Name : "dstack-shim" ,
4145 Usage : "Starts dstack-runner or docker container." ,
4246 Version : Version ,
4347 Flags : []cli.Flag {
4448 /* Shim Parameters */
45- & cli.PathFlag {
49+ & cli.StringFlag {
4650 Name : "shim-home" ,
4751 Usage : "Set shim's home directory" ,
4852 Destination : & args .Shim .HomeDir ,
53+ TakesFile : true ,
4954 DefaultText : path .Join ("~" , consts .DstackDirPath ),
50- EnvVars : [] string { "DSTACK_SHIM_HOME" } ,
55+ Sources : cli . EnvVars ( "DSTACK_SHIM_HOME" ) ,
5156 },
5257 & cli.IntFlag {
5358 Name : "shim-http-port" ,
5459 Usage : "Set shim's http port" ,
5560 Value : 10998 ,
5661 Destination : & args .Shim .HTTPPort ,
57- EnvVars : [] string { "DSTACK_SHIM_HTTP_PORT" } ,
62+ Sources : cli . EnvVars ( "DSTACK_SHIM_HTTP_PORT" ) ,
5863 },
5964 & cli.IntFlag {
6065 Name : "shim-log-level" ,
6166 Usage : "Set shim's log level" ,
6267 Value : defaultLogLevel ,
6368 Destination : & args .Shim .LogLevel ,
64- EnvVars : [] string { "DSTACK_SHIM_LOG_LEVEL" } ,
69+ Sources : cli . EnvVars ( "DSTACK_SHIM_LOG_LEVEL" ) ,
6570 },
6671 /* Runner Parameters */
6772 & cli.StringFlag {
6873 Name : "runner-download-url" ,
6974 Usage : "Set runner's download URL" ,
7075 Destination : & args .Runner .DownloadURL ,
71- EnvVars : [] string { "DSTACK_RUNNER_DOWNLOAD_URL" } ,
76+ Sources : cli . EnvVars ( "DSTACK_RUNNER_DOWNLOAD_URL" ) ,
7277 },
73- & cli.PathFlag {
78+ & cli.StringFlag {
7479 Name : "runner-binary-path" ,
7580 Usage : "Path to runner's binary" ,
7681 Value : consts .RunnerBinaryPath ,
7782 Destination : & args .Runner .BinaryPath ,
78- EnvVars : []string {"DSTACK_RUNNER_BINARY_PATH" },
83+ TakesFile : true ,
84+ Sources : cli .EnvVars ("DSTACK_RUNNER_BINARY_PATH" ),
7985 },
8086 & cli.IntFlag {
8187 Name : "runner-http-port" ,
8288 Usage : "Set runner's http port" ,
8389 Value : consts .RunnerHTTPPort ,
8490 Destination : & args .Runner .HTTPPort ,
85- EnvVars : [] string { "DSTACK_RUNNER_HTTP_PORT" } ,
91+ Sources : cli . EnvVars ( "DSTACK_RUNNER_HTTP_PORT" ) ,
8692 },
8793 & cli.IntFlag {
8894 Name : "runner-ssh-port" ,
8995 Usage : "Set runner's ssh port" ,
9096 Value : consts .RunnerSSHPort ,
9197 Destination : & args .Runner .SSHPort ,
92- EnvVars : [] string { "DSTACK_RUNNER_SSH_PORT" } ,
98+ Sources : cli . EnvVars ( "DSTACK_RUNNER_SSH_PORT" ) ,
9399 },
94100 & cli.IntFlag {
95101 Name : "runner-log-level" ,
96102 Usage : "Set runner's log level" ,
97103 Value : defaultLogLevel ,
98104 Destination : & args .Runner .LogLevel ,
99- EnvVars : [] string { "DSTACK_RUNNER_LOG_LEVEL" } ,
105+ Sources : cli . EnvVars ( "DSTACK_RUNNER_LOG_LEVEL" ) ,
100106 },
101107 /* DCGM Exporter Parameters */
102108 & cli.IntFlag {
103109 Name : "dcgm-exporter-http-port" ,
104110 Usage : "DCGM Exporter http port" ,
105111 Value : 10997 ,
106112 Destination : & args .DCGMExporter .HTTPPort ,
107- EnvVars : [] string { "DSTACK_DCGM_EXPORTER_HTTP_PORT" } ,
113+ Sources : cli . EnvVars ( "DSTACK_DCGM_EXPORTER_HTTP_PORT" ) ,
108114 },
109115 & cli.IntFlag {
110116 Name : "dcgm-exporter-interval" ,
111117 Usage : "DCGM Exporter collect interval, milliseconds" ,
112118 Value : 5000 ,
113119 Destination : & args .DCGMExporter .Interval ,
114- EnvVars : [] string { "DSTACK_DCGM_EXPORTER_INTERVAL" } ,
120+ Sources : cli . EnvVars ( "DSTACK_DCGM_EXPORTER_INTERVAL" ) ,
115121 },
116122 /* DCGM Parameters */
117123 & cli.StringFlag {
118124 Name : "dcgm-address" ,
119125 Usage : "nv-hostengine `hostname`, e.g., `localhost`" ,
120126 DefaultText : "start libdcgm in embedded mode" ,
121127 Destination : & args .DCGM .Address ,
122- EnvVars : [] string { "DSTACK_DCGM_ADDRESS" } ,
128+ Sources : cli . EnvVars ( "DSTACK_DCGM_ADDRESS" ) ,
123129 },
124130 /* Docker Parameters */
125131 & cli.BoolFlag {
126132 Name : "privileged" ,
127133 Usage : "Give extended privileges to the container" ,
128134 Destination : & args .Docker .Privileged ,
129- EnvVars : [] string { "DSTACK_DOCKER_PRIVILEGED" } ,
135+ Sources : cli . EnvVars ( "DSTACK_DOCKER_PRIVILEGED" ) ,
130136 },
131137 & cli.StringFlag {
132138 Name : "ssh-key" ,
133139 Usage : "Public SSH key" ,
134140 Destination : & args .Docker .ConcatinatedPublicSSHKeys ,
135- EnvVars : [] string { "DSTACK_PUBLIC_SSH_KEY" } ,
141+ Sources : cli . EnvVars ( "DSTACK_PUBLIC_SSH_KEY" ) ,
136142 },
137143 & cli.StringFlag {
138144 Name : "pjrt-device" ,
139145 Usage : "Set the PJRT_DEVICE environment variable (e.g., TPU, GPU)" ,
140146 Destination : & args .Docker .PJRTDevice ,
141- EnvVars : [] string { "PJRT_DEVICE" } ,
147+ Sources : cli . EnvVars ( "PJRT_DEVICE" ) ,
142148 },
143149 /* Misc Parameters */
144150 & cli.BoolFlag {
145151 Name : "service" ,
146152 Usage : "Start as a service" ,
147153 Destination : & serviceMode ,
148- EnvVars : [] string { "DSTACK_SERVICE_MODE" } ,
154+ Sources : cli . EnvVars ( "DSTACK_SERVICE_MODE" ) ,
149155 },
150156 },
151- Action : func (c * cli.Context ) error {
157+ Action : func (ctx context. Context , cmd * cli.Command ) error {
152158 return start (ctx , args , serviceMode )
153159 },
154160 }
155161
156- if err := app .Run (os .Args ); err != nil {
157- log .Fatal (ctx , err .Error ())
162+ ctx , stop := signal .NotifyContext (context .Background (), os .Interrupt , syscall .SIGTERM )
163+ defer stop ()
164+
165+ if err := cmd .Run (ctx , os .Args ); err != nil {
166+ log .Error (ctx , err .Error ())
167+ return 1
158168 }
169+
170+ return 0
159171}
160172
161173func start (ctx context.Context , args shim.CLIArgs , serviceMode bool ) (err error ) {
@@ -191,8 +203,13 @@ func start(ctx context.Context, args shim.CLIArgs, serviceMode bool) (err error)
191203 }
192204 }()
193205
194- if err := args .DownloadRunner (ctx ); err != nil {
195- return err
206+ runnerManager , runnerErr := components .NewRunnerManager (ctx , args .Runner .BinaryPath )
207+ if args .Runner .DownloadURL != "" {
208+ if err := runnerManager .Install (ctx , args .Runner .DownloadURL , false ); err != nil {
209+ return err
210+ }
211+ } else if runnerErr != nil {
212+ return runnerErr
196213 }
197214
198215 log .Debug (ctx , "Shim" , "args" , args .Shim )
@@ -242,13 +259,7 @@ func start(ctx context.Context, args shim.CLIArgs, serviceMode bool) (err error)
242259 }
243260
244261 address := fmt .Sprintf ("localhost:%d" , args .Shim .HTTPPort )
245- shimServer := api .NewShimServer (ctx , address , Version , dockerRunner , dcgmExporter , dcgmWrapper )
246-
247- defer func () {
248- shutdownCtx , cancelShutdown := context .WithTimeout (ctx , 5 * time .Second )
249- defer cancelShutdown ()
250- _ = shimServer .HttpServer .Shutdown (shutdownCtx )
251- }()
262+ shimServer := api .NewShimServer (ctx , address , Version , dockerRunner , dcgmExporter , dcgmWrapper , runnerManager )
252263
253264 if serviceMode {
254265 if err := shim .WriteHostInfo (shimHomeDir , dockerRunner .Resources (ctx )); err != nil {
@@ -260,9 +271,25 @@ func start(ctx context.Context, args shim.CLIArgs, serviceMode bool) (err error)
260271 }
261272 }
262273
263- if err := shimServer .HttpServer .ListenAndServe (); err != nil && ! errors .Is (err , http .ErrServerClosed ) {
264- return err
274+ var serveErr error
275+ serveErrCh := make (chan error )
276+
277+ go func () {
278+ if err := shimServer .Serve (); err != nil {
279+ serveErrCh <- err
280+ }
281+ }()
282+
283+ select {
284+ case serveErr = <- serveErrCh :
285+ case <- ctx .Done ():
265286 }
266287
267- return nil
288+ shutdownCtx , cancelShutdown := context .WithTimeout (ctx , 5 * time .Second )
289+ defer cancelShutdown ()
290+ shutdownErr := shimServer .Shutdown (shutdownCtx )
291+ if serveErr != nil {
292+ return serveErr
293+ }
294+ return shutdownErr
268295}
0 commit comments