Skip to content

Commit 59de9cd

Browse files
Merge pull request #275 from gabriel-samfira/add-event-stream
Add event stream
2 parents ca7f20b + 2554f70 commit 59de9cd

File tree

111 files changed

+10229
-4862
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

111 files changed

+10229
-4862
lines changed

apiserver/controllers/controllers.go

Lines changed: 48 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ import (
2828

2929
gErrors "github.com/cloudbase/garm-provider-common/errors"
3030
"github.com/cloudbase/garm-provider-common/util"
31+
"github.com/cloudbase/garm/apiserver/events"
3132
"github.com/cloudbase/garm/apiserver/params"
3233
"github.com/cloudbase/garm/auth"
3334
"github.com/cloudbase/garm/metrics"
@@ -163,6 +164,43 @@ func (a *APIController) WebhookHandler(w http.ResponseWriter, r *http.Request) {
163164
}
164165
}
165166

167+
func (a *APIController) EventsHandler(w http.ResponseWriter, r *http.Request) {
168+
ctx := r.Context()
169+
if !auth.IsAdmin(ctx) {
170+
w.WriteHeader(http.StatusForbidden)
171+
if _, err := w.Write([]byte("events are available to admin users")); err != nil {
172+
slog.With(slog.Any("error", err)).ErrorContext(ctx, "failed to encode response")
173+
}
174+
return
175+
}
176+
177+
conn, err := a.upgrader.Upgrade(w, r, nil)
178+
if err != nil {
179+
slog.With(slog.Any("error", err)).ErrorContext(ctx, "error upgrading to websockets")
180+
return
181+
}
182+
defer conn.Close()
183+
184+
wsClient, err := wsWriter.NewClient(ctx, conn)
185+
if err != nil {
186+
slog.With(slog.Any("error", err)).ErrorContext(ctx, "failed to create new client")
187+
return
188+
}
189+
defer wsClient.Stop()
190+
191+
eventHandler, err := events.NewHandler(ctx, wsClient)
192+
if err != nil {
193+
slog.With(slog.Any("error", err)).ErrorContext(ctx, "failed to create new event handler")
194+
return
195+
}
196+
197+
if err := eventHandler.Start(); err != nil {
198+
slog.With(slog.Any("error", err)).ErrorContext(ctx, "failed to start event handler")
199+
return
200+
}
201+
<-eventHandler.Done()
202+
}
203+
166204
func (a *APIController) WSHandler(writer http.ResponseWriter, req *http.Request) {
167205
ctx := req.Context()
168206
if !auth.IsAdmin(ctx) {
@@ -183,14 +221,9 @@ func (a *APIController) WSHandler(writer http.ResponseWriter, req *http.Request)
183221
slog.With(slog.Any("error", err)).ErrorContext(ctx, "error upgrading to websockets")
184222
return
185223
}
224+
defer conn.Close()
186225

187-
// nolint:golangci-lint,godox
188-
// TODO (gsamfira): Handle ExpiresAt. Right now, if a client uses
189-
// a valid token to authenticate, and keeps the websocket connection
190-
// open, it will allow that client to stream logs via websockets
191-
// until the connection is broken. We need to forcefully disconnect
192-
// the client once the token expires.
193-
client, err := wsWriter.NewClient(conn, a.hub)
226+
client, err := wsWriter.NewClient(ctx, conn)
194227
if err != nil {
195228
slog.With(slog.Any("error", err)).ErrorContext(ctx, "failed to create new client")
196229
return
@@ -199,7 +232,14 @@ func (a *APIController) WSHandler(writer http.ResponseWriter, req *http.Request)
199232
slog.With(slog.Any("error", err)).ErrorContext(ctx, "failed to register new client")
200233
return
201234
}
202-
client.Go()
235+
defer a.hub.Unregister(client)
236+
237+
if err := client.Start(); err != nil {
238+
slog.With(slog.Any("error", err)).ErrorContext(ctx, "failed to start client")
239+
return
240+
}
241+
<-client.Done()
242+
slog.Info("client disconnected", "client_id", client.ID())
203243
}
204244

205245
// NotFoundHandler is returned when an invalid URL is acccessed

apiserver/events/events.go

Lines changed: 181 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,181 @@
1+
package events
2+
3+
import (
4+
"context"
5+
"encoding/json"
6+
"errors"
7+
"fmt"
8+
"log/slog"
9+
"sync"
10+
11+
runnerErrors "github.com/cloudbase/garm-provider-common/errors"
12+
commonUtil "github.com/cloudbase/garm-provider-common/util"
13+
"github.com/cloudbase/garm/auth"
14+
"github.com/cloudbase/garm/database/common"
15+
"github.com/cloudbase/garm/database/watcher"
16+
"github.com/cloudbase/garm/websocket"
17+
)
18+
19+
func NewHandler(ctx context.Context, client *websocket.Client) (*EventHandler, error) {
20+
if client == nil {
21+
return nil, runnerErrors.ErrUnauthorized
22+
}
23+
24+
newID := commonUtil.NewID()
25+
userID := auth.UserID(ctx)
26+
if userID == "" {
27+
return nil, runnerErrors.ErrUnauthorized
28+
}
29+
consumerID := fmt.Sprintf("ws-event-watcher-%s-%s", userID, newID)
30+
consumer, err := watcher.RegisterConsumer(
31+
// Filter everything by default. Users should set up filters
32+
// after registration.
33+
ctx, consumerID, watcher.WithNone())
34+
if err != nil {
35+
return nil, err
36+
}
37+
38+
handler := &EventHandler{
39+
client: client,
40+
ctx: ctx,
41+
consumer: consumer,
42+
done: make(chan struct{}),
43+
}
44+
client.SetMessageHandler(handler.HandleClientMessages)
45+
46+
return handler, nil
47+
}
48+
49+
type EventHandler struct {
50+
client *websocket.Client
51+
consumer common.Consumer
52+
53+
ctx context.Context
54+
done chan struct{}
55+
running bool
56+
57+
mux sync.Mutex
58+
}
59+
60+
func (e *EventHandler) loop() {
61+
defer e.Stop()
62+
63+
for {
64+
select {
65+
case <-e.ctx.Done():
66+
slog.DebugContext(e.ctx, "context done, stopping event handler")
67+
return
68+
case <-e.client.Done():
69+
slog.DebugContext(e.ctx, "client done, stopping event handler")
70+
return
71+
case <-e.Done():
72+
slog.DebugContext(e.ctx, "done channel closed, stopping event handler")
73+
case event, ok := <-e.consumer.Watch():
74+
if !ok {
75+
slog.DebugContext(e.ctx, "watcher closed, stopping event handler")
76+
return
77+
}
78+
asJs, err := json.Marshal(event)
79+
if err != nil {
80+
slog.ErrorContext(e.ctx, "failed to marshal event", "error", err)
81+
continue
82+
}
83+
if _, err := e.client.Write(asJs); err != nil {
84+
slog.ErrorContext(e.ctx, "failed to write event", "error", err)
85+
}
86+
}
87+
}
88+
}
89+
90+
func (e *EventHandler) Start() error {
91+
e.mux.Lock()
92+
defer e.mux.Unlock()
93+
94+
if e.running {
95+
return nil
96+
}
97+
98+
if err := e.client.Start(); err != nil {
99+
return err
100+
}
101+
e.running = true
102+
go e.loop()
103+
return nil
104+
}
105+
106+
func (e *EventHandler) Stop() {
107+
e.mux.Lock()
108+
defer e.mux.Unlock()
109+
110+
if !e.running {
111+
return
112+
}
113+
e.running = false
114+
e.consumer.Close()
115+
e.client.Stop()
116+
close(e.done)
117+
}
118+
119+
func (e *EventHandler) Done() <-chan struct{} {
120+
return e.done
121+
}
122+
123+
// optionsToWatcherFilters converts the Options struct to a PayloadFilterFunc.
124+
// The client will send an array of filters that indicates which entities and which
125+
// operations the client is interested in. The behavior is that of "any" filter.
126+
// Which means that if any of the elements in the array match an event, it will be
127+
// sent to the websocket.
128+
// Alternatively, clients can choose to get everything.
129+
func (e *EventHandler) optionsToWatcherFilters(opt Options) common.PayloadFilterFunc {
130+
if opt.SendEverything {
131+
return watcher.WithEverything()
132+
}
133+
134+
var funcs []common.PayloadFilterFunc
135+
for _, filter := range opt.Filters {
136+
var filterFunc []common.PayloadFilterFunc
137+
if filter.EntityType == "" {
138+
return watcher.WithNone()
139+
}
140+
filterFunc = append(filterFunc, watcher.WithEntityTypeFilter(filter.EntityType))
141+
if len(filter.Operations) > 0 {
142+
var opFunc []common.PayloadFilterFunc
143+
for _, op := range filter.Operations {
144+
opFunc = append(opFunc, watcher.WithOperationTypeFilter(op))
145+
}
146+
filterFunc = append(filterFunc, watcher.WithAny(opFunc...))
147+
}
148+
funcs = append(funcs, watcher.WithAll(filterFunc...))
149+
}
150+
return watcher.WithAny(funcs...)
151+
}
152+
153+
func (e *EventHandler) HandleClientMessages(message []byte) error {
154+
if e.consumer == nil {
155+
return fmt.Errorf("consumer not initialized")
156+
}
157+
158+
var opt Options
159+
if err := json.Unmarshal(message, &opt); err != nil {
160+
slog.ErrorContext(e.ctx, "failed to unmarshal message from client", "error", err, "message", string(message))
161+
// Client is in error. Disconnect.
162+
e.client.Write([]byte("failed to unmarshal filter"))
163+
e.Stop()
164+
return nil
165+
}
166+
167+
if err := opt.Validate(); err != nil {
168+
if errors.Is(err, common.ErrNoFiltersProvided) {
169+
slog.DebugContext(e.ctx, "no filters provided; ignoring")
170+
return nil
171+
}
172+
slog.ErrorContext(e.ctx, "invalid filter", "error", err)
173+
e.client.Write([]byte("invalid filter"))
174+
e.Stop()
175+
return nil
176+
}
177+
178+
watcherFilters := e.optionsToWatcherFilters(opt)
179+
e.consumer.SetFilters(watcherFilters)
180+
return nil
181+
}

apiserver/events/params.go

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
package events
2+
3+
import (
4+
"github.com/cloudbase/garm/database/common"
5+
)
6+
7+
type Filter struct {
8+
Operations []common.OperationType `json:"operations"`
9+
EntityType common.DatabaseEntityType `json:"entity_type"`
10+
}
11+
12+
func (f Filter) Validate() error {
13+
switch f.EntityType {
14+
case common.RepositoryEntityType, common.OrganizationEntityType, common.EnterpriseEntityType,
15+
common.PoolEntityType, common.UserEntityType, common.InstanceEntityType,
16+
common.JobEntityType, common.ControllerEntityType, common.GithubCredentialsEntityType,
17+
common.GithubEndpointEntityType:
18+
default:
19+
return common.ErrInvalidEntityType
20+
}
21+
22+
for _, op := range f.Operations {
23+
switch op {
24+
case common.CreateOperation, common.UpdateOperation, common.DeleteOperation:
25+
default:
26+
return common.ErrInvalidOperationType
27+
}
28+
}
29+
return nil
30+
}
31+
32+
type Options struct {
33+
SendEverything bool `json:"send_everything"`
34+
Filters []Filter `json:"filters"`
35+
}
36+
37+
func (o Options) Validate() error {
38+
if o.SendEverything {
39+
return nil
40+
}
41+
if len(o.Filters) == 0 {
42+
return common.ErrNoFiltersProvided
43+
}
44+
for _, f := range o.Filters {
45+
if err := f.Validate(); err != nil {
46+
return err
47+
}
48+
}
49+
return nil
50+
}

apiserver/routers/routers.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -413,6 +413,7 @@ func NewAPIRouter(han *controllers.APIController, authMiddleware, initMiddleware
413413

414414
// Websocket log writer
415415
apiRouter.Handle("/{ws:ws\\/?}", http.HandlerFunc(han.WSHandler)).Methods("GET")
416+
apiRouter.Handle("/{events:events\\/?}", http.HandlerFunc(han.EventsHandler)).Methods("GET")
416417

417418
// NotFound handler
418419
apiRouter.PathPrefix("/").HandlerFunc(han.NotFoundHandler).Methods("GET", "POST", "PUT", "DELETE", "OPTIONS")

auth/auth.go

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -55,17 +55,19 @@ func (a *Authenticator) GetJWTToken(ctx context.Context) (string, error) {
5555
expires := &jwt.NumericDate{
5656
Time: expireToken,
5757
}
58+
generation := PasswordGeneration(ctx)
5859
claims := JWTClaims{
5960
RegisteredClaims: jwt.RegisteredClaims{
6061
ExpiresAt: expires,
6162
// nolint:golangci-lint,godox
6263
// TODO: make this configurable
6364
Issuer: "garm",
6465
},
65-
UserID: UserID(ctx),
66-
TokenID: tokenID,
67-
IsAdmin: IsAdmin(ctx),
68-
FullName: FullName(ctx),
66+
UserID: UserID(ctx),
67+
TokenID: tokenID,
68+
IsAdmin: IsAdmin(ctx),
69+
FullName: FullName(ctx),
70+
Generation: generation,
6971
}
7072
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
7173
tokenString, err := token.SignedString([]byte(a.cfg.Secret))
@@ -182,5 +184,5 @@ func (a *Authenticator) AuthenticateUser(ctx context.Context, info params.Passwo
182184
return ctx, runnerErrors.ErrUnauthorized
183185
}
184186

185-
return PopulateContext(ctx, user), nil
187+
return PopulateContext(ctx, user, nil), nil
186188
}

0 commit comments

Comments
 (0)