@@ -5,156 +5,168 @@ import (
55 "errors"
66 "fmt"
77 "io"
8- "net/http"
98 "os"
9+ "os/signal"
1010 "path"
1111 "path/filepath"
12+ "syscall"
1213 "time"
1314
1415 "github.com/dstackai/dstack/runner/consts"
1516 "github.com/dstackai/dstack/runner/internal/common"
1617 "github.com/dstackai/dstack/runner/internal/log"
1718 "github.com/dstackai/dstack/runner/internal/shim"
1819 "github.com/dstackai/dstack/runner/internal/shim/api"
20+ "github.com/dstackai/dstack/runner/internal/shim/components"
1921 "github.com/dstackai/dstack/runner/internal/shim/dcgm"
2022 "github.com/sirupsen/logrus"
21- "github.com/urfave/cli/v2 "
23+ "github.com/urfave/cli/v3 "
2224)
2325
2426// Version is a build-time variable. The value is overridden by ldflags.
2527var Version string
2628
2729func main () {
30+ os .Exit (mainInner ())
31+ }
32+
33+ func mainInner () int {
2834 var args shim.CLIArgs
2935 var serviceMode bool
3036
3137 const defaultLogLevel = int (logrus .InfoLevel )
3238
33- ctx := context .Background ()
34-
3539 log .DefaultEntry .Logger .SetLevel (logrus .Level (defaultLogLevel ))
3640 log .DefaultEntry .Logger .SetOutput (os .Stderr )
3741
38- app := & cli.App {
42+ cmd := & cli.Command {
3943 Name : "dstack-shim" ,
4044 Usage : "Starts dstack-runner or docker container." ,
4145 Version : Version ,
4246 Flags : []cli.Flag {
4347 /* Shim Parameters */
44- & cli.PathFlag {
48+ & cli.StringFlag {
4549 Name : "shim-home" ,
4650 Usage : "Set shim's home directory" ,
4751 Destination : & args .Shim .HomeDir ,
52+ TakesFile : true ,
4853 DefaultText : path .Join ("~" , consts .DstackDirPath ),
49- EnvVars : [] string { "DSTACK_SHIM_HOME" } ,
54+ Sources : cli . EnvVars ( "DSTACK_SHIM_HOME" ) ,
5055 },
5156 & cli.IntFlag {
5257 Name : "shim-http-port" ,
5358 Usage : "Set shim's http port" ,
5459 Value : 10998 ,
5560 Destination : & args .Shim .HTTPPort ,
56- EnvVars : [] string { "DSTACK_SHIM_HTTP_PORT" } ,
61+ Sources : cli . EnvVars ( "DSTACK_SHIM_HTTP_PORT" ) ,
5762 },
5863 & cli.IntFlag {
5964 Name : "shim-log-level" ,
6065 Usage : "Set shim's log level" ,
6166 Value : defaultLogLevel ,
6267 Destination : & args .Shim .LogLevel ,
63- EnvVars : [] string { "DSTACK_SHIM_LOG_LEVEL" } ,
68+ Sources : cli . EnvVars ( "DSTACK_SHIM_LOG_LEVEL" ) ,
6469 },
6570 /* Runner Parameters */
6671 & cli.StringFlag {
6772 Name : "runner-download-url" ,
6873 Usage : "Set runner's download URL" ,
6974 Destination : & args .Runner .DownloadURL ,
70- EnvVars : [] string { "DSTACK_RUNNER_DOWNLOAD_URL" } ,
75+ Sources : cli . EnvVars ( "DSTACK_RUNNER_DOWNLOAD_URL" ) ,
7176 },
72- & cli.PathFlag {
77+ & cli.StringFlag {
7378 Name : "runner-binary-path" ,
7479 Usage : "Path to runner's binary" ,
7580 Value : consts .RunnerBinaryPath ,
7681 Destination : & args .Runner .BinaryPath ,
77- EnvVars : []string {"DSTACK_RUNNER_BINARY_PATH" },
82+ TakesFile : true ,
83+ Sources : cli .EnvVars ("DSTACK_RUNNER_BINARY_PATH" ),
7884 },
7985 & cli.IntFlag {
8086 Name : "runner-http-port" ,
8187 Usage : "Set runner's http port" ,
8288 Value : consts .RunnerHTTPPort ,
8389 Destination : & args .Runner .HTTPPort ,
84- EnvVars : [] string { "DSTACK_RUNNER_HTTP_PORT" } ,
90+ Sources : cli . EnvVars ( "DSTACK_RUNNER_HTTP_PORT" ) ,
8591 },
8692 & cli.IntFlag {
8793 Name : "runner-ssh-port" ,
8894 Usage : "Set runner's ssh port" ,
8995 Value : consts .RunnerSSHPort ,
9096 Destination : & args .Runner .SSHPort ,
91- EnvVars : [] string { "DSTACK_RUNNER_SSH_PORT" } ,
97+ Sources : cli . EnvVars ( "DSTACK_RUNNER_SSH_PORT" ) ,
9298 },
9399 & cli.IntFlag {
94100 Name : "runner-log-level" ,
95101 Usage : "Set runner's log level" ,
96102 Value : defaultLogLevel ,
97103 Destination : & args .Runner .LogLevel ,
98- EnvVars : [] string { "DSTACK_RUNNER_LOG_LEVEL" } ,
104+ Sources : cli . EnvVars ( "DSTACK_RUNNER_LOG_LEVEL" ) ,
99105 },
100106 /* DCGM Exporter Parameters */
101107 & cli.IntFlag {
102108 Name : "dcgm-exporter-http-port" ,
103109 Usage : "DCGM Exporter http port" ,
104110 Value : 10997 ,
105111 Destination : & args .DCGMExporter .HTTPPort ,
106- EnvVars : [] string { "DSTACK_DCGM_EXPORTER_HTTP_PORT" } ,
112+ Sources : cli . EnvVars ( "DSTACK_DCGM_EXPORTER_HTTP_PORT" ) ,
107113 },
108114 & cli.IntFlag {
109115 Name : "dcgm-exporter-interval" ,
110116 Usage : "DCGM Exporter collect interval, milliseconds" ,
111117 Value : 5000 ,
112118 Destination : & args .DCGMExporter .Interval ,
113- EnvVars : [] string { "DSTACK_DCGM_EXPORTER_INTERVAL" } ,
119+ Sources : cli . EnvVars ( "DSTACK_DCGM_EXPORTER_INTERVAL" ) ,
114120 },
115121 /* DCGM Parameters */
116122 & cli.StringFlag {
117123 Name : "dcgm-address" ,
118124 Usage : "nv-hostengine `hostname`, e.g., `localhost`" ,
119125 DefaultText : "start libdcgm in embedded mode" ,
120126 Destination : & args .DCGM .Address ,
121- EnvVars : [] string { "DSTACK_DCGM_ADDRESS" } ,
127+ Sources : cli . EnvVars ( "DSTACK_DCGM_ADDRESS" ) ,
122128 },
123129 /* Docker Parameters */
124130 & cli.BoolFlag {
125131 Name : "privileged" ,
126132 Usage : "Give extended privileges to the container" ,
127133 Destination : & args .Docker .Privileged ,
128- EnvVars : [] string { "DSTACK_DOCKER_PRIVILEGED" } ,
134+ Sources : cli . EnvVars ( "DSTACK_DOCKER_PRIVILEGED" ) ,
129135 },
130136 & cli.StringFlag {
131137 Name : "ssh-key" ,
132138 Usage : "Public SSH key" ,
133139 Destination : & args .Docker .ConcatinatedPublicSSHKeys ,
134- EnvVars : [] string { "DSTACK_PUBLIC_SSH_KEY" } ,
140+ Sources : cli . EnvVars ( "DSTACK_PUBLIC_SSH_KEY" ) ,
135141 },
136142 & cli.StringFlag {
137143 Name : "pjrt-device" ,
138144 Usage : "Set the PJRT_DEVICE environment variable (e.g., TPU, GPU)" ,
139145 Destination : & args .Docker .PJRTDevice ,
140- EnvVars : [] string { "PJRT_DEVICE" } ,
146+ Sources : cli . EnvVars ( "PJRT_DEVICE" ) ,
141147 },
142148 /* Misc Parameters */
143149 & cli.BoolFlag {
144150 Name : "service" ,
145151 Usage : "Start as a service" ,
146152 Destination : & serviceMode ,
147- EnvVars : [] string { "DSTACK_SERVICE_MODE" } ,
153+ Sources : cli . EnvVars ( "DSTACK_SERVICE_MODE" ) ,
148154 },
149155 },
150- Action : func (c * cli.Context ) error {
156+ Action : func (ctx context. Context , cmd * cli.Command ) error {
151157 return start (ctx , args , serviceMode )
152158 },
153159 }
154160
155- if err := app .Run (os .Args ); err != nil {
156- log .Fatal (ctx , err .Error ())
161+ ctx , stop := signal .NotifyContext (context .Background (), os .Interrupt , syscall .SIGTERM )
162+ defer stop ()
163+
164+ if err := cmd .Run (ctx , os .Args ); err != nil {
165+ log .Error (ctx , err .Error ())
166+ return 1
157167 }
168+
169+ return 0
158170}
159171
160172func start (ctx context.Context , args shim.CLIArgs , serviceMode bool ) (err error ) {
@@ -190,8 +202,13 @@ func start(ctx context.Context, args shim.CLIArgs, serviceMode bool) (err error)
190202 }
191203 }()
192204
193- if err := args .DownloadRunner (ctx ); err != nil {
194- return err
205+ runnerManager , runnerErr := components .NewRunnerManager (ctx , args .Runner .BinaryPath )
206+ if args .Runner .DownloadURL != "" {
207+ if err := runnerManager .Install (ctx , args .Runner .DownloadURL , false ); err != nil {
208+ return err
209+ }
210+ } else if runnerErr != nil {
211+ return runnerErr
195212 }
196213
197214 log .Debug (ctx , "Shim" , "args" , args .Shim )
@@ -241,13 +258,7 @@ func start(ctx context.Context, args shim.CLIArgs, serviceMode bool) (err error)
241258 }
242259
243260 address := fmt .Sprintf ("localhost:%d" , args .Shim .HTTPPort )
244- shimServer := api .NewShimServer (ctx , address , Version , dockerRunner , dcgmExporter , dcgmWrapper )
245-
246- defer func () {
247- shutdownCtx , cancelShutdown := context .WithTimeout (ctx , 5 * time .Second )
248- defer cancelShutdown ()
249- _ = shimServer .HttpServer .Shutdown (shutdownCtx )
250- }()
261+ shimServer := api .NewShimServer (ctx , address , Version , dockerRunner , dcgmExporter , dcgmWrapper , runnerManager )
251262
252263 if serviceMode {
253264 if err := shim .WriteHostInfo (shimHomeDir , dockerRunner .Resources (ctx )); err != nil {
@@ -259,9 +270,14 @@ func start(ctx context.Context, args shim.CLIArgs, serviceMode bool) (err error)
259270 }
260271 }
261272
262- if err := shimServer .HttpServer .ListenAndServe (); err != nil && ! errors .Is (err , http .ErrServerClosed ) {
263- return err
264- }
273+ go func () {
274+ if err := shimServer .Serve (); err != nil {
275+ log .Error (ctx , "serve" , "err" , err )
276+ }
277+ }()
265278
266- return nil
279+ <- ctx .Done ()
280+ shutdownCtx , cancelShutdown := context .WithTimeout (ctx , 5 * time .Second )
281+ defer cancelShutdown ()
282+ return shimServer .Shutdown (shutdownCtx )
267283}
0 commit comments