Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docker-compose.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ services:
- SHELLHUB_BILLING=${SHELLHUB_BILLING}
- ALLOW_PUBLIC_KEY_ACCESS_BELLOW_0_6_0=${SHELLHUB_ALLOW_PUBLIC_KEY_ACCESS_BELLOW_0_6_0}
- BILLING_URL=${SHELLHUB_BILLING_URL}
- SHELLHUB_TUNNELS=${SHELLHUB_TUNNELS}
- SHELLHUB_TUNNELS_DOMAIN=${SHELLHUB_TUNNELS_DOMAIN}
ports:
- "${SHELLHUB_SSH_PORT}:2222"
secrets:
Expand Down
10 changes: 8 additions & 2 deletions ssh/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,9 @@ type Envs struct {
// Allows SSH to connect with an agent via a public key when the agent version is less than 0.6.0.
// Agents 0.5.x or earlier do not validate the public key request and may panic.
// Please refer to: https://github.com/shellhub-io/shellhub/issues/3453
AllowPublickeyAccessBelow060 bool `env:"ALLOW_PUBLIC_KEY_ACCESS_BELLOW_0_6_0,default=false"`
AllowPublickeyAccessBelow060 bool `env:"ALLOW_PUBLIC_KEY_ACCESS_BELLOW_0_6_0,default=false"`
Tunnels bool `env:"SHELLHUB_TUNNELS,default=false"`
TunnelsDomain string `env:"SHELLHUB_TUNNELS_DOMAIN"`
}

func main() {
Expand All @@ -45,7 +47,11 @@ func main() {
Fatal("failed to connect to redis cache")
}

tun, err := tunnel.NewTunnel("/ssh/connection", "/ssh/revdial", env.RedisURI)
tun, err := tunnel.NewTunnel("/ssh/connection", "/ssh/revdial", tunnel.Config{
Tunnels: env.Tunnels,
TunnelsDomain: env.TunnelsDomain,
RedisURI: env.RedisURI,
})
if err != nil {
log.WithError(err).
Fatal("failed to create the internalclient")
Expand Down
218 changes: 123 additions & 95 deletions ssh/pkg/tunnel/tunnel.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,14 +38,39 @@ func NewMessageFromError(err error) Message {
}
}

type Config struct {
// Tunnels defines if tunnel's feature is enabled.
Tunnels bool
// TunnelsDomain define the domain of tunnels feature when it's enabled.
TunnelsDomain string
// RedisURI is the redis URI connection.
RedisURI string
}

func (c Config) Validate() error {
if c.Tunnels && c.TunnelsDomain == "" {
return errors.New("tunnels feature is enabled, but tunnel's domain is empty")
}

if c.RedisURI == "" {
return errors.New("redis uri is empty")
}

return nil
}

type Tunnel struct {
Tunnel *httptunnel.Tunnel
API internalclient.Client
router *echo.Echo
}

func NewTunnel(connection, dial, redisURI string) (*Tunnel, error) {
api, err := internalclient.NewClient(internalclient.WithAsynqWorker(redisURI))
func NewTunnel(connection string, dial string, config Config) (*Tunnel, error) {
if err := config.Validate(); err != nil {
return nil, err
}

api, err := internalclient.NewClient(internalclient.WithAsynqWorker(config.RedisURI))
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -160,129 +185,132 @@ func NewTunnel(connection, dial, redisURI string) (*Tunnel, error) {
return c.NoContent(http.StatusOK)
})

// The `/http/proxy` endpoint is invoked by the NGINX gateway when a tunnel URL is accessed. It processes the
// `X-Address` and `X-Path` headers, which specify the tunnel's address and the target path on the server, returning
// an error related to the connection to device or what was returned from the server inside the tunnel.
tunnel.router.Any("/http/proxy", func(c echo.Context) error {
requestID := c.Request().Header.Get("X-Request-ID")

address := c.Request().Header.Get("X-Address")
log.WithFields(log.Fields{
"request-id": requestID,
"address": address,
}).Debug("address value")

path := c.Request().Header.Get("X-Path")
log.WithFields(log.Fields{
"request-id": requestID,
"address": address,
}).Debug("path")

tun, err := tunnel.API.LookupTunnel(address)
if err != nil {
log.WithError(err).Error("failed to get the tunnel")
if config.Tunnels {
// The `/http/proxy` endpoint is invoked by the NGINX gateway when a tunnel URL is accessed. It processes the
// `X-Address` and `X-Path` headers, which specify the tunnel's address and the target path on the server, returning
// an error related to the connection to device or what was returned from the server inside the tunnel.
tunnel.router.Any("/http/proxy", func(c echo.Context) error {
requestID := c.Request().Header.Get("X-Request-ID")

address := c.Request().Header.Get("X-Address")
log.WithFields(log.Fields{
"request-id": requestID,
"address": address,
}).Debug("address value")

path := c.Request().Header.Get("X-Path")
log.WithFields(log.Fields{
"request-id": requestID,
"address": address,
}).Debug("path")

tun, err := tunnel.API.LookupTunnel(address)
if err != nil {
log.WithError(err).Error("failed to get the tunnel")

return c.JSON(http.StatusForbidden, NewMessageFromError(ErrDeviceTunnelForbidden))
}
return c.JSON(http.StatusForbidden, NewMessageFromError(ErrDeviceTunnelForbidden))
}

logger := log.WithFields(log.Fields{
"request-id": requestID,
"namespace": tun.Namespace,
"device": tun.Device,
})
logger := log.WithFields(log.Fields{
"request-id": requestID,
"namespace": tun.Namespace,
"device": tun.Device,
})

in, err := tunnel.Dial(c.Request().Context(), fmt.Sprintf("%s:%s", tun.Namespace, tun.Device))
if err != nil {
logger.WithError(err).Error("failed to dial to device")
in, err := tunnel.Dial(c.Request().Context(), fmt.Sprintf("%s:%s", tun.Namespace, tun.Device))
if err != nil {
logger.WithError(err).Error("failed to dial to device")

return c.JSON(http.StatusForbidden, NewMessageFromError(ErrDeviceTunnelDial))
}
return c.JSON(http.StatusForbidden, NewMessageFromError(ErrDeviceTunnelDial))
}

defer in.Close()
defer in.Close()

logger.Trace("new tunnel connection initialized")
defer logger.Trace("tunnel connection doned")
logger.Trace("new tunnel connection initialized")
defer logger.Trace("tunnel connection doned")

// NOTE: Connects to the HTTP proxy before doing the actual request. In this case, we are connecting to all
// hosts on the agent because we aren't specifying any host, on the port specified. The proxy route accepts
// connections for any port, but this route should only connect to the HTTP server.
req, _ := http.NewRequest(http.MethodConnect, fmt.Sprintf("/http/proxy/%s:%d", tun.Host, tun.Port), nil)
// NOTE: Connects to the HTTP proxy before doing the actual request. In this case, we are connecting to all
// hosts on the agent because we aren't specifying any host, on the port specified. The proxy route accepts
// connections for any port, but this route should only connect to the HTTP server.
req, _ := http.NewRequest(http.MethodConnect, fmt.Sprintf("/http/proxy/%s:%d", tun.Host, tun.Port), nil)

if err := req.Write(in); err != nil {
logger.WithError(err).Error("failed to write the request to the agent")
if err := req.Write(in); err != nil {
logger.WithError(err).Error("failed to write the request to the agent")

return c.JSON(http.StatusInternalServerError, NewMessageFromError(ErrDeviceTunnelWriteRequest))
}
return c.JSON(http.StatusInternalServerError, NewMessageFromError(ErrDeviceTunnelWriteRequest))
}

if resp, err := http.ReadResponse(bufio.NewReader(in), req); err != nil || resp.StatusCode != http.StatusOK {
logger.WithError(err).Error("failed to connect to HTTP port on device")
if resp, err := http.ReadResponse(bufio.NewReader(in), req); err != nil || resp.StatusCode != http.StatusOK {
logger.WithError(err).Error("failed to connect to HTTP port on device")

return c.JSON(http.StatusInternalServerError, NewMessageFromError(ErrDeviceTunnelConnect))
}
return c.JSON(http.StatusInternalServerError, NewMessageFromError(ErrDeviceTunnelConnect))
}

req = c.Request()
req.URL, err = url.Parse(path)
if err != nil {
logger.WithError(err).Error("failed to parse the path")
req = c.Request()
req.Host = strings.Join([]string{address, config.TunnelsDomain}, ".")
req.URL, err = url.Parse(path)
if err != nil {
logger.WithError(err).Error("failed to parse the path")

return c.JSON(http.StatusInternalServerError, NewMessageFromError(ErrDeviceTunnelReadResponse))
}
return c.JSON(http.StatusInternalServerError, NewMessageFromError(ErrDeviceTunnelReadResponse))
}

if err := req.Write(in); err != nil {
logger.WithError(err).Error("failed to write the request to the agent")
if err := req.Write(in); err != nil {
logger.WithError(err).Error("failed to write the request to the agent")

return c.JSON(http.StatusInternalServerError, NewMessageFromError(ErrDeviceTunnelWriteRequest))
}
return c.JSON(http.StatusInternalServerError, NewMessageFromError(ErrDeviceTunnelWriteRequest))
}

ctr := http.NewResponseController(c.Response())
out, _, err := ctr.Hijack()
if err != nil {
logger.WithError(err).Error("failed to hijact the http request")
ctr := http.NewResponseController(c.Response())
out, _, err := ctr.Hijack()
if err != nil {
logger.WithError(err).Error("failed to hijact the http request")

return c.JSON(http.StatusInternalServerError, NewMessageFromError(ErrDeviceTunnelHijackRequest))
}
return c.JSON(http.StatusInternalServerError, NewMessageFromError(ErrDeviceTunnelHijackRequest))
}

defer out.Close()
defer out.Close()

// Bidirectional copy between the client and the device.
var wg sync.WaitGroup
wg.Add(2)
// Bidirectional copy between the client and the device.
var wg sync.WaitGroup
wg.Add(2)

done := sync.OnceFunc(func() {
defer in.Close()
defer out.Close()
done := sync.OnceFunc(func() {
defer in.Close()
defer out.Close()

logger.Trace("close called on in and out connections")
})
logger.Trace("close called on in and out connections")
})

go func() {
defer done()
defer wg.Done()
go func() {
defer done()
defer wg.Done()

if _, err := io.Copy(in, out); err != nil {
logger.WithError(err).Debug("in and out done returned a error")
}
if _, err := io.Copy(in, out); err != nil {
logger.WithError(err).Debug("in and out done returned a error")
}

logger.Trace("in and out done")
}()
logger.Trace("in and out done")
}()

go func() {
defer done()
defer wg.Done()
go func() {
defer done()
defer wg.Done()

if _, err := io.Copy(out, in); err != nil {
logger.WithError(err).Debug("out and in done returned a error")
}
if _, err := io.Copy(out, in); err != nil {
logger.WithError(err).Debug("out and in done returned a error")
}

logger.Trace("out and in done")
}()
logger.Trace("out and in done")
}()

wg.Wait()
wg.Wait()

logger.Debug("http proxy is done")
logger.Debug("http proxy is done")

return nil
})
return nil
})
}

tunnel.router.GET("/healthcheck", func(c echo.Context) error {
return c.String(http.StatusOK, "OK")
Expand Down
Loading