Skip to content

Commit

Permalink
support NIP-45 (#58)
Browse files Browse the repository at this point in the history
  • Loading branch information
mattn committed May 17, 2023
1 parent c4a678d commit 639c210
Show file tree
Hide file tree
Showing 7 changed files with 340 additions and 48 deletions.
64 changes: 64 additions & 0 deletions handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,65 @@ func (s *Server) HandleWebsocket(w http.ResponseWriter, r *http.Request) {
ok, message := AddEvent(ctx, s.relay, evt)
ws.WriteJSON([]interface{}{"OK", evt.ID, ok, message})

case "COUNT":
counter, ok := store.(EventCounter)
if !ok {
notice = "restricted: this relay does not support NIP-45"
return
}

var id string
json.Unmarshal(request[1], &id)
if id == "" {
notice = "COUNT has no <id>"
return
}

total := int64(0)
filters := make(nostr.Filters, len(request)-2)
for i, filterReq := range request[2:] {
if err := json.Unmarshal(filterReq, &filters[i]); err != nil {
notice = "failed to decode filter"
return
}

filter := &filters[i]

// prevent kind-4 events from being returned to unauthed users,
// only when authentication is a thing
if _, ok := s.relay.(Auther); ok {
if slices.Contains(filter.Kinds, 4) {
senders := filter.Authors
receivers, _ := filter.Tags["p"]
switch {
case ws.authed == "":
// not authenticated
notice = "restricted: this relay does not serve kind-4 to unauthenticated users, does your client implement NIP-42?"
return
case len(senders) == 1 && len(receivers) < 2 && (senders[0] == ws.authed):
// allowed filter: ws.authed is sole sender (filter specifies one or all receivers)
case len(receivers) == 1 && len(senders) < 2 && (receivers[0] == ws.authed):
// allowed filter: ws.authed is sole receiver (filter specifies one or all senders)
default:
// restricted filter: do not return any events,
// even if other elements in filters array were not restricted).
// client should know better.
notice = "restricted: authenticated user does not have authorization for requested filters."
return
}
}
}

count, err := counter.CountEvents(ctx, filter)
if err != nil {
s.Log.Errorf("store: %v", err)
continue
}
total += count
}

ws.WriteJSON([]interface{}{"COUNT", id, map[string]int64{"count": total}})
setListener(id, ws, filters)
case "REQ":
var id string
json.Unmarshal(request[1], &id)
Expand Down Expand Up @@ -312,6 +371,11 @@ func (s *Server) HandleNIP11(w http.ResponseWriter, r *http.Request) {
if _, ok := s.relay.(Auther); ok {
supportedNIPs = append(supportedNIPs, 42)
}
if storage, ok := s.relay.(Storage); ok && storage != nil {
if _, ok = storage.(EventCounter); ok {
supportedNIPs = append(supportedNIPs, 45)
}
}

info := nip11.RelayInformationDocument{
Name: s.relay.Name(),
Expand Down
4 changes: 4 additions & 0 deletions interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,3 +90,7 @@ type AdvancedSaver interface {
BeforeSave(context.Context, *nostr.Event)
AfterSave(*nostr.Event)
}

type EventCounter interface {
CountEvents(ctx context.Context, filter *nostr.Filter) (int64, error)
}
10 changes: 7 additions & 3 deletions start_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,12 @@ package relayer

import (
"context"
"errors"
"net/http"
"testing"
"time"

"github.com/gorilla/websocket"
"github.com/gobwas/ws/wsutil"
"github.com/nbd-wtf/go-nostr"
)

Expand Down Expand Up @@ -83,7 +84,10 @@ func TestServerShutdownWebsocket(t *testing.T) {
// wait for the client to receive a "connection close"
time.Sleep(1 * time.Second)
err = client.ConnectionError
if _, ok := err.(*websocket.CloseError); !ok {
t.Errorf("client.ConnextionError: %v (%T); want websocket.CloseError", err, err)
if e := errors.Unwrap(err); e != nil {
err = e
}
if _, ok := err.(wsutil.ClosedError); !ok {
t.Errorf("client.ConnextionError: %v (%T); want wsutil.ClosedError", err, err)
}
}
36 changes: 29 additions & 7 deletions storage/postgresql/query.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ import (
func (b PostgresBackend) QueryEvents(ctx context.Context, filter *nostr.Filter) (ch chan *nostr.Event, err error) {
ch = make(chan *nostr.Event)

query, params, err := queryEventsSql(filter)
query, params, err := queryEventsSql(filter, false)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -44,7 +44,20 @@ func (b PostgresBackend) QueryEvents(ctx context.Context, filter *nostr.Filter)
return ch, nil
}

func queryEventsSql(filter *nostr.Filter) (string, []any, error) {
func (b PostgresBackend) CountEvents(ctx context.Context, filter *nostr.Filter) (int64, error) {
query, params, err := queryEventsSql(filter, true)
if err != nil {
return 0, err
}

var count int64
if err = b.DB.QueryRow(query, params...).Scan(&count); err != nil && err != sql.ErrNoRows {
return 0, fmt.Errorf("failed to fetch events using query %q: %w", query, err)
}
return count, nil
}

func queryEventsSql(filter *nostr.Filter, doCount bool) (string, []any, error) {
var conditions []string
var params []any

Expand Down Expand Up @@ -165,11 +178,20 @@ func queryEventsSql(filter *nostr.Filter) (string, []any, error) {
params = append(params, filter.Limit)
}

query := sqlx.Rebind(sqlx.BindType("postgres"), `SELECT
id, pubkey, created_at, kind, tags, content, sig
FROM event WHERE `+
strings.Join(conditions, " AND ")+
" ORDER BY created_at DESC LIMIT ?")
var query string
if doCount {
query = sqlx.Rebind(sqlx.BindType("postgres"), `SELECT
COUNT(*)
FROM event WHERE `+
strings.Join(conditions, " AND ")+
" ORDER BY created_at DESC LIMIT ?")
} else {
query = sqlx.Rebind(sqlx.BindType("postgres"), `SELECT
id, pubkey, created_at, kind, tags, content, sig
FROM event WHERE `+
strings.Join(conditions, " AND ")+
" ORDER BY created_at DESC LIMIT ?")
}

return query, params, nil
}
161 changes: 160 additions & 1 deletion storage/postgresql/query_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ func TestQueryEventsSql(t *testing.T) {

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
query, params, err := queryEventsSql(tt.filter)
query, params, err := queryEventsSql(tt.filter, false)
assert.Equal(t, tt.err, err)
if err != nil {
return
Expand Down Expand Up @@ -188,3 +188,162 @@ func strSlice(n int) []string {
}
return slice
}

func TestCountEventsSql(t *testing.T) {
var tests = []struct {
name string
filter *nostr.Filter
query string
params []any
err error
}{
{
name: "empty filter",
filter: &nostr.Filter{},
query: "SELECT COUNT(*) FROM event WHERE true ORDER BY created_at DESC LIMIT $1",
params: []any{100},
err: nil,
},
{
name: "ids filter",
filter: &nostr.Filter{
IDs: []string{"083ec57f36a7b39ab98a57bedab4f85355b2ee89e4b205bed58d7c3ef9edd294"},
},
query: `SELECT COUNT(*)
FROM event
WHERE (id LIKE '083ec57f36a7b39ab98a57bedab4f85355b2ee89e4b205bed58d7c3ef9edd294%')
ORDER BY created_at DESC LIMIT $1`,
params: []any{100},
err: nil,
},
{
name: "kind filter",
filter: &nostr.Filter{
Kinds: []int{1, 2, 3},
},
query: `SELECT COUNT(*)
FROM event
WHERE kind IN(1,2,3)
ORDER BY created_at DESC LIMIT $1`,
params: []any{100},
err: nil,
},
{
name: "authors filter",
filter: &nostr.Filter{
Authors: []string{"7bdef7bdebb8721f77927d0e77c66059360fa62371fdf15f3add93923a613229"},
},
query: `SELECT COUNT(*)
FROM event
WHERE (pubkey LIKE '7bdef7bdebb8721f77927d0e77c66059360fa62371fdf15f3add93923a613229%')
ORDER BY created_at DESC LIMIT $1`,
params: []any{100},
err: nil,
},
// errors
{
name: "nil filter",
filter: nil,
query: "",
params: nil,
err: fmt.Errorf("filter cannot be null"),
},
{
name: "too many ids",
filter: &nostr.Filter{
IDs: strSlice(501),
},
query: "",
params: nil,
// REVIEW: should return error
err: nil,
},
{
name: "invalid ids",
filter: &nostr.Filter{
IDs: []string{"stuff"},
},
query: "",
params: nil,
// REVIEW: should return error
err: nil,
},
{
name: "too many authors",
filter: &nostr.Filter{
Authors: strSlice(501),
},
query: "",
params: nil,
// REVIEW: should return error
err: nil,
},
{
name: "invalid authors",
filter: &nostr.Filter{
Authors: []string{"stuff"},
},
query: "",
params: nil,
// REVIEW: should return error
err: nil,
},
{
name: "too many kinds",
filter: &nostr.Filter{
Kinds: intSlice(11),
},
query: "",
params: nil,
// REVIEW: should return error
err: nil,
},
{
name: "no kinds",
filter: &nostr.Filter{
Kinds: []int{},
},
query: "",
params: nil,
// REVIEW: should return error
err: nil,
},
{
name: "tags of empty array",
filter: &nostr.Filter{
Tags: nostr.TagMap{
"#e": []string{},
},
},
query: "",
params: nil,
// REVIEW: should return error
err: nil,
},
{
name: "too many tag values",
filter: &nostr.Filter{
Tags: nostr.TagMap{
"#e": strSlice(11),
},
},
query: "",
params: nil,
// REVIEW: should return error
err: nil,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
query, params, err := queryEventsSql(tt.filter, true)
assert.Equal(t, tt.err, err)
if err != nil {
return
}

assert.Equal(t, clean(tt.query), clean(query))
assert.Equal(t, tt.params, params)
})
}
}
Loading

0 comments on commit 639c210

Please sign in to comment.