From 639c2106619ac3f9c2d7d2faaad152e03f84652d Mon Sep 17 00:00:00 2001 From: mattn Date: Wed, 17 May 2023 19:54:56 +0900 Subject: [PATCH] support NIP-45 (#58) --- handlers.go | 64 ++++++++++++ interface.go | 4 + start_test.go | 10 +- storage/postgresql/query.go | 36 +++++-- storage/postgresql/query_test.go | 161 ++++++++++++++++++++++++++++++- storage/sqlite3/query.go | 105 +++++++++++++------- util_test.go | 8 ++ 7 files changed, 340 insertions(+), 48 deletions(-) diff --git a/handlers.go b/handlers.go index fff3362..8fddd8c 100644 --- a/handlers.go +++ b/handlers.go @@ -179,6 +179,65 @@ func (s *Server) HandleWebsocket(w http.ResponseWriter, r *http.Request) { ok, message := AddEvent(ctx, s.relay, evt) ws.WriteJSON([]interface{}{"OK", evt.ID, ok, message}) + case "COUNT": + counter, ok := store.(EventCounter) + if !ok { + notice = "restricted: this relay does not support NIP-45" + return + } + + var id string + json.Unmarshal(request[1], &id) + if id == "" { + notice = "COUNT has no " + return + } + + total := int64(0) + filters := make(nostr.Filters, len(request)-2) + for i, filterReq := range request[2:] { + if err := json.Unmarshal(filterReq, &filters[i]); err != nil { + notice = "failed to decode filter" + return + } + + filter := &filters[i] + + // prevent kind-4 events from being returned to unauthed users, + // only when authentication is a thing + if _, ok := s.relay.(Auther); ok { + if slices.Contains(filter.Kinds, 4) { + senders := filter.Authors + receivers, _ := filter.Tags["p"] + switch { + case ws.authed == "": + // not authenticated + notice = "restricted: this relay does not serve kind-4 to unauthenticated users, does your client implement NIP-42?" + return + case len(senders) == 1 && len(receivers) < 2 && (senders[0] == ws.authed): + // allowed filter: ws.authed is sole sender (filter specifies one or all receivers) + case len(receivers) == 1 && len(senders) < 2 && (receivers[0] == ws.authed): + // allowed filter: ws.authed is sole receiver (filter specifies one or all senders) + default: + // restricted filter: do not return any events, + // even if other elements in filters array were not restricted). + // client should know better. + notice = "restricted: authenticated user does not have authorization for requested filters." + return + } + } + } + + count, err := counter.CountEvents(ctx, filter) + if err != nil { + s.Log.Errorf("store: %v", err) + continue + } + total += count + } + + ws.WriteJSON([]interface{}{"COUNT", id, map[string]int64{"count": total}}) + setListener(id, ws, filters) case "REQ": var id string json.Unmarshal(request[1], &id) @@ -312,6 +371,11 @@ func (s *Server) HandleNIP11(w http.ResponseWriter, r *http.Request) { if _, ok := s.relay.(Auther); ok { supportedNIPs = append(supportedNIPs, 42) } + if storage, ok := s.relay.(Storage); ok && storage != nil { + if _, ok = storage.(EventCounter); ok { + supportedNIPs = append(supportedNIPs, 45) + } + } info := nip11.RelayInformationDocument{ Name: s.relay.Name(), diff --git a/interface.go b/interface.go index dbe65fd..4ae6ceb 100644 --- a/interface.go +++ b/interface.go @@ -90,3 +90,7 @@ type AdvancedSaver interface { BeforeSave(context.Context, *nostr.Event) AfterSave(*nostr.Event) } + +type EventCounter interface { + CountEvents(ctx context.Context, filter *nostr.Filter) (int64, error) +} diff --git a/start_test.go b/start_test.go index 3a23f18..ca4f08d 100644 --- a/start_test.go +++ b/start_test.go @@ -2,11 +2,12 @@ package relayer import ( "context" + "errors" "net/http" "testing" "time" - "github.com/gorilla/websocket" + "github.com/gobwas/ws/wsutil" "github.com/nbd-wtf/go-nostr" ) @@ -83,7 +84,10 @@ func TestServerShutdownWebsocket(t *testing.T) { // wait for the client to receive a "connection close" time.Sleep(1 * time.Second) err = client.ConnectionError - if _, ok := err.(*websocket.CloseError); !ok { - t.Errorf("client.ConnextionError: %v (%T); want websocket.CloseError", err, err) + if e := errors.Unwrap(err); e != nil { + err = e + } + if _, ok := err.(wsutil.ClosedError); !ok { + t.Errorf("client.ConnextionError: %v (%T); want wsutil.ClosedError", err, err) } } diff --git a/storage/postgresql/query.go b/storage/postgresql/query.go index ff3e508..4798076 100644 --- a/storage/postgresql/query.go +++ b/storage/postgresql/query.go @@ -15,7 +15,7 @@ import ( func (b PostgresBackend) QueryEvents(ctx context.Context, filter *nostr.Filter) (ch chan *nostr.Event, err error) { ch = make(chan *nostr.Event) - query, params, err := queryEventsSql(filter) + query, params, err := queryEventsSql(filter, false) if err != nil { return nil, err } @@ -44,7 +44,20 @@ func (b PostgresBackend) QueryEvents(ctx context.Context, filter *nostr.Filter) return ch, nil } -func queryEventsSql(filter *nostr.Filter) (string, []any, error) { +func (b PostgresBackend) CountEvents(ctx context.Context, filter *nostr.Filter) (int64, error) { + query, params, err := queryEventsSql(filter, true) + if err != nil { + return 0, err + } + + var count int64 + if err = b.DB.QueryRow(query, params...).Scan(&count); err != nil && err != sql.ErrNoRows { + return 0, fmt.Errorf("failed to fetch events using query %q: %w", query, err) + } + return count, nil +} + +func queryEventsSql(filter *nostr.Filter, doCount bool) (string, []any, error) { var conditions []string var params []any @@ -165,11 +178,20 @@ func queryEventsSql(filter *nostr.Filter) (string, []any, error) { params = append(params, filter.Limit) } - query := sqlx.Rebind(sqlx.BindType("postgres"), `SELECT - id, pubkey, created_at, kind, tags, content, sig - FROM event WHERE `+ - strings.Join(conditions, " AND ")+ - " ORDER BY created_at DESC LIMIT ?") + var query string + if doCount { + query = sqlx.Rebind(sqlx.BindType("postgres"), `SELECT + COUNT(*) + FROM event WHERE `+ + strings.Join(conditions, " AND ")+ + " ORDER BY created_at DESC LIMIT ?") + } else { + query = sqlx.Rebind(sqlx.BindType("postgres"), `SELECT + id, pubkey, created_at, kind, tags, content, sig + FROM event WHERE `+ + strings.Join(conditions, " AND ")+ + " ORDER BY created_at DESC LIMIT ?") + } return query, params, nil } diff --git a/storage/postgresql/query_test.go b/storage/postgresql/query_test.go index 310374a..226acbd 100644 --- a/storage/postgresql/query_test.go +++ b/storage/postgresql/query_test.go @@ -157,7 +157,7 @@ func TestQueryEventsSql(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - query, params, err := queryEventsSql(tt.filter) + query, params, err := queryEventsSql(tt.filter, false) assert.Equal(t, tt.err, err) if err != nil { return @@ -188,3 +188,162 @@ func strSlice(n int) []string { } return slice } + +func TestCountEventsSql(t *testing.T) { + var tests = []struct { + name string + filter *nostr.Filter + query string + params []any + err error + }{ + { + name: "empty filter", + filter: &nostr.Filter{}, + query: "SELECT COUNT(*) FROM event WHERE true ORDER BY created_at DESC LIMIT $1", + params: []any{100}, + err: nil, + }, + { + name: "ids filter", + filter: &nostr.Filter{ + IDs: []string{"083ec57f36a7b39ab98a57bedab4f85355b2ee89e4b205bed58d7c3ef9edd294"}, + }, + query: `SELECT COUNT(*) + FROM event + WHERE (id LIKE '083ec57f36a7b39ab98a57bedab4f85355b2ee89e4b205bed58d7c3ef9edd294%') + ORDER BY created_at DESC LIMIT $1`, + params: []any{100}, + err: nil, + }, + { + name: "kind filter", + filter: &nostr.Filter{ + Kinds: []int{1, 2, 3}, + }, + query: `SELECT COUNT(*) + FROM event + WHERE kind IN(1,2,3) + ORDER BY created_at DESC LIMIT $1`, + params: []any{100}, + err: nil, + }, + { + name: "authors filter", + filter: &nostr.Filter{ + Authors: []string{"7bdef7bdebb8721f77927d0e77c66059360fa62371fdf15f3add93923a613229"}, + }, + query: `SELECT COUNT(*) + FROM event + WHERE (pubkey LIKE '7bdef7bdebb8721f77927d0e77c66059360fa62371fdf15f3add93923a613229%') + ORDER BY created_at DESC LIMIT $1`, + params: []any{100}, + err: nil, + }, + // errors + { + name: "nil filter", + filter: nil, + query: "", + params: nil, + err: fmt.Errorf("filter cannot be null"), + }, + { + name: "too many ids", + filter: &nostr.Filter{ + IDs: strSlice(501), + }, + query: "", + params: nil, + // REVIEW: should return error + err: nil, + }, + { + name: "invalid ids", + filter: &nostr.Filter{ + IDs: []string{"stuff"}, + }, + query: "", + params: nil, + // REVIEW: should return error + err: nil, + }, + { + name: "too many authors", + filter: &nostr.Filter{ + Authors: strSlice(501), + }, + query: "", + params: nil, + // REVIEW: should return error + err: nil, + }, + { + name: "invalid authors", + filter: &nostr.Filter{ + Authors: []string{"stuff"}, + }, + query: "", + params: nil, + // REVIEW: should return error + err: nil, + }, + { + name: "too many kinds", + filter: &nostr.Filter{ + Kinds: intSlice(11), + }, + query: "", + params: nil, + // REVIEW: should return error + err: nil, + }, + { + name: "no kinds", + filter: &nostr.Filter{ + Kinds: []int{}, + }, + query: "", + params: nil, + // REVIEW: should return error + err: nil, + }, + { + name: "tags of empty array", + filter: &nostr.Filter{ + Tags: nostr.TagMap{ + "#e": []string{}, + }, + }, + query: "", + params: nil, + // REVIEW: should return error + err: nil, + }, + { + name: "too many tag values", + filter: &nostr.Filter{ + Tags: nostr.TagMap{ + "#e": strSlice(11), + }, + }, + query: "", + params: nil, + // REVIEW: should return error + err: nil, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + query, params, err := queryEventsSql(tt.filter, true) + assert.Equal(t, tt.err, err) + if err != nil { + return + } + + assert.Equal(t, clean(tt.query), clean(query)) + assert.Equal(t, tt.params, params) + }) + } +} diff --git a/storage/sqlite3/query.go b/storage/sqlite3/query.go index 1a84a35..67031c9 100644 --- a/storage/sqlite3/query.go +++ b/storage/sqlite3/query.go @@ -4,29 +4,72 @@ import ( "context" "database/sql" "encoding/hex" - "errors" "fmt" "strconv" "strings" + "github.com/jmoiron/sqlx" "github.com/nbd-wtf/go-nostr" ) func (b SQLite3Backend) QueryEvents(ctx context.Context, filter *nostr.Filter) (ch chan *nostr.Event, err error) { ch = make(chan *nostr.Event) + query, params, err := queryEventsSql(filter, false) + if err != nil { + return nil, err + } + + rows, err := b.DB.Query(query, params...) + if err != nil && err != sql.ErrNoRows { + return nil, fmt.Errorf("failed to fetch events using query %q: %w", query, err) + } + + go func() { + defer rows.Close() + defer close(ch) + for rows.Next() { + var evt nostr.Event + var timestamp int64 + err := rows.Scan(&evt.ID, &evt.PubKey, ×tamp, + &evt.Kind, &evt.Tags, &evt.Content, &evt.Sig) + if err != nil { + return + } + evt.CreatedAt = nostr.Timestamp(timestamp) + ch <- &evt + } + }() + + return ch, nil +} + +func (b SQLite3Backend) CountEvents(ctx context.Context, filter *nostr.Filter) (int64, error) { + query, params, err := queryEventsSql(filter, true) + if err != nil { + return 0, err + } + + var count int64 + err = b.DB.QueryRow(query, params...).Scan(&count) + if err != nil && err != sql.ErrNoRows { + return 0, fmt.Errorf("failed to fetch events using query %q: %w", query, err) + } + return count, nil +} + +func queryEventsSql(filter *nostr.Filter, doCount bool) (string, []any, error) { var conditions []string var params []any if filter == nil { - err = errors.New("filter cannot be null") - return + return "", nil, fmt.Errorf("filter cannot be null") } if filter.IDs != nil { if len(filter.IDs) > 500 { // too many ids, fail everything - return + return "", nil, nil } likeids := make([]string, 0, len(filter.IDs)) @@ -41,7 +84,7 @@ func (b SQLite3Backend) QueryEvents(ctx context.Context, filter *nostr.Filter) ( } if len(likeids) == 0 { // ids being [] mean you won't get anything - return + return "", nil, nil } conditions = append(conditions, "("+strings.Join(likeids, " OR ")+")") } @@ -49,7 +92,7 @@ func (b SQLite3Backend) QueryEvents(ctx context.Context, filter *nostr.Filter) ( if filter.Authors != nil { if len(filter.Authors) > 500 { // too many authors, fail everything - return + return "", nil, nil } likekeys := make([]string, 0, len(filter.Authors)) @@ -64,7 +107,7 @@ func (b SQLite3Backend) QueryEvents(ctx context.Context, filter *nostr.Filter) ( } if len(likekeys) == 0 { // authors being [] mean you won't get anything - return + return "", nil, nil } conditions = append(conditions, "("+strings.Join(likekeys, " OR ")+")") } @@ -72,12 +115,12 @@ func (b SQLite3Backend) QueryEvents(ctx context.Context, filter *nostr.Filter) ( if filter.Kinds != nil { if len(filter.Kinds) > 10 { // too many kinds, fail everything - return + return "", nil, nil } if len(filter.Kinds) == 0 { // kinds being [] mean you won't get anything - return + return "", nil, nil } // no sql injection issues since these are ints inkinds := make([]string, len(filter.Kinds)) @@ -91,7 +134,7 @@ func (b SQLite3Backend) QueryEvents(ctx context.Context, filter *nostr.Filter) ( for _, values := range filter.Tags { if len(values) == 0 { // any tag set to [] is wrong - return + return "", nil, nil } // add these tags to the query @@ -99,7 +142,7 @@ func (b SQLite3Backend) QueryEvents(ctx context.Context, filter *nostr.Filter) ( if len(tagQuery) > 10 { // too many tags, fail everything - return + return "", nil, nil } } @@ -134,32 +177,20 @@ func (b SQLite3Backend) QueryEvents(ctx context.Context, filter *nostr.Filter) ( params = append(params, filter.Limit) } - query := b.DB.Rebind(`SELECT - id, pubkey, created_at, kind, tags, content, sig - FROM event WHERE ` + - strings.Join(conditions, " AND ") + - " ORDER BY created_at DESC LIMIT ?") - - rows, err := b.DB.Query(query, params...) - if err != nil && err != sql.ErrNoRows { - return nil, fmt.Errorf("failed to fetch events using query %q: %w", query, err) + var query string + if doCount { + query = sqlx.Rebind(sqlx.BindType("sqlite3"), `SELECT + COUNT(*) + FROM event WHERE `+ + strings.Join(conditions, " AND ")+ + " ORDER BY created_at DESC LIMIT ?") + } else { + query = sqlx.Rebind(sqlx.BindType("sqlite3"), `SELECT + id, pubkey, created_at, kind, tags, content, sig + FROM event WHERE `+ + strings.Join(conditions, " AND ")+ + " ORDER BY created_at DESC LIMIT ?") } - go func() { - defer rows.Close() - defer close(ch) - for rows.Next() { - var evt nostr.Event - var timestamp int64 - err := rows.Scan(&evt.ID, &evt.PubKey, ×tamp, - &evt.Kind, &evt.Tags, &evt.Content, &evt.Sig) - if err != nil { - return - } - evt.CreatedAt = nostr.Timestamp(timestamp) - ch <- &evt - } - }() - - return ch, nil + return query, params, nil } diff --git a/util_test.go b/util_test.go index ec0c385..0a02610 100644 --- a/util_test.go +++ b/util_test.go @@ -52,6 +52,7 @@ type testStorage struct { queryEvents func(context.Context, *nostr.Filter) (chan *nostr.Event, error) deleteEvent func(ctx context.Context, id string, pubkey string) error saveEvent func(context.Context, *nostr.Event) error + countEvents func(context.Context, *nostr.Filter) (int64, error) } func (st *testStorage) Init() error { @@ -81,3 +82,10 @@ func (st *testStorage) SaveEvent(ctx context.Context, e *nostr.Event) error { } return nil } + +func (st *testStorage) CountEvents(ctx context.Context, f *nostr.Filter) (int64, error) { + if fn := st.countEvents; fn != nil { + return fn(ctx, f) + } + return 0, nil +}