Skip to content

Commit

Permalink
Blockwise: Use hash instead of token for cache
Browse files Browse the repository at this point in the history
Credit to @zworks-okada for initial work regarding rx transfers. Expanded to include tx.

Closes plgd-dev#512
  • Loading branch information
mpenate-ellenbytech committed May 15, 2024
1 parent 2a95bd6 commit bb2fd9b
Show file tree
Hide file tree
Showing 11 changed files with 105 additions and 51 deletions.
4 changes: 2 additions & 2 deletions dtls/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,8 +90,8 @@ func Client(conn *dtls.Conn, opts ...udp.Option) *udpClient.Conn {
v,
cfg.BlockwiseTransferTimeout,
cfg.Errors,
func(token message.Token) (*pool.Message, bool) {
return v.GetObservationRequest(token)
func(hash uint64) (*pool.Message, bool) {
return v.GetObservationRequest(hash)
},
)
}
Expand Down
4 changes: 2 additions & 2 deletions dtls/server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -198,8 +198,8 @@ func (s *Server) createConn(connection *coapNet.Conn, inactivityMonitor udpClien
v,
s.cfg.BlockwiseTransferTimeout,
s.cfg.Errors,
func(token message.Token) (*pool.Message, bool) {
return v.GetObservationRequest(token)
func(hash uint64) (*pool.Message, bool) {
return v.GetObservationRequest(hash)
},
)
}
Expand Down
4 changes: 4 additions & 0 deletions message/option.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ type OptionID uint16
| 35 | x | x | - | | Proxy-Uri | string | 1-1034 | (none) |
| 39 | x | x | - | | Proxy-Scheme | string | 1-255 | (none) |
| 60 | | | x | | Size1 | uint | 0-4 | (none) |
| 292 | | | | x | Request-Tag | opaque | 0-8 | (none) |
+-----+----+---+---+---+----------------+--------+--------+---------+
C=Critical, U=Unsafe, N=NoCacheKey, R=Repeatable
*/
Expand All @@ -73,6 +74,7 @@ const (
ProxyScheme OptionID = 39
Size1 OptionID = 60
NoResponse OptionID = 258
RequestTag OptionID = 292
)

var optionIDToString = map[OptionID]string{
Expand All @@ -96,6 +98,7 @@ var optionIDToString = map[OptionID]string{
ProxyScheme: "ProxyScheme",
Size1: "Size1",
NoResponse: "NoResponse",
RequestTag: "RequestTag",
}

func (o OptionID) String() string {
Expand Down Expand Up @@ -153,6 +156,7 @@ var CoapOptionDefs = map[OptionID]OptionDef{
ProxyScheme: {ValueFormat: ValueString, MinLen: 1, MaxLen: 255},
Size1: {ValueFormat: ValueUint, MinLen: 0, MaxLen: 4},
NoResponse: {ValueFormat: ValueUint, MinLen: 0, MaxLen: 1},
RequestTag: {ValueFormat: ValueOpaque, MinLen: 0, MaxLen: 8},
}

// MediaType specifies the content format of a message.
Expand Down
113 changes: 80 additions & 33 deletions net/blockwise/blockwise.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@ import (
"context"
"errors"
"fmt"
"hash/crc64"
"io"
"net"
"time"

"github.com/dsnet/golib/memfile"
Expand Down Expand Up @@ -131,14 +133,17 @@ type Client interface {
AcquireMessage(ctx context.Context) *pool.Message
// return back the message to the pool for next use
ReleaseMessage(m *pool.Message)

// The remote address for determining the endpoint pair
RemoteAddr() net.Addr
}

type BlockWise[C Client] struct {
cc C
receivingMessagesCache *cache.Cache[uint64, *messageGuard]
sendingMessagesCache *cache.Cache[uint64, *pool.Message]
errors func(error)
getSentRequestFromOutside func(token message.Token) (*pool.Message, bool)
getSentRequestFromOutside func(hash uint64) (*pool.Message, bool)
expiration time.Duration
}

Expand All @@ -160,10 +165,10 @@ func New[C Client](
cc C,
expiration time.Duration,
errors func(error),
getSentRequestFromOutside func(token message.Token) (*pool.Message, bool),
getSentRequestFromOutside func(hash uint64) (*pool.Message, bool),
) *BlockWise[C] {
if getSentRequestFromOutside == nil {
getSentRequestFromOutside = func(message.Token) (*pool.Message, bool) { return nil, false }
getSentRequestFromOutside = func(uint64) (*pool.Message, bool) { return nil, false }
}
return &BlockWise[C]{
cc: cc,
Expand Down Expand Up @@ -214,11 +219,12 @@ func (b *BlockWise[C]) Do(r *pool.Message, maxSzx SZX, maxMessageSize uint32, do
if !ok {
expire = time.Now().Add(b.expiration)
}
_, loaded := b.sendingMessagesCache.LoadOrStore(r.Token().Hash(), cache.NewElement(r, expire, nil))
matchableHash := generateMatchableHash(r.Options(), b.cc.RemoteAddr(), r.Code())
_, loaded := b.sendingMessagesCache.LoadOrStore(matchableHash, cache.NewElement(r, expire, nil))
if loaded {
return nil, errors.New("invalid token")
}
defer b.sendingMessagesCache.Delete(r.Token().Hash())
defer b.sendingMessagesCache.Delete(matchableHash)
if r.Body() == nil {
return do(r)
}
Expand Down Expand Up @@ -282,9 +288,9 @@ func (b *BlockWise[C]) WriteMessage(request *pool.Message, maxSZX SZX, maxMessag
if err != nil {
return fmt.Errorf("cannot encode start sending message block option(%v,%v,%v): %w", maxSZX, 0, true, err)
}

matchableHash := generateMatchableHash(request.Options(), b.cc.RemoteAddr(), request.Code())
w := newWriteRequestResponse(b.cc, request)
err = b.startSendingMessage(w, maxSZX, maxMessageSize, startSendingMessageBlock)
err = b.startSendingMessage(w, maxSZX, maxMessageSize, startSendingMessageBlock, matchableHash)
if err != nil {
return fmt.Errorf("cannot start writing request: %w", err)
}
Expand Down Expand Up @@ -333,8 +339,8 @@ func wantsToBeReceived(r *pool.Message) bool {
return true
}

func (b *BlockWise[C]) getSendingMessageCode(token uint64) (codes.Code, bool) {
v := b.sendingMessagesCache.Load(token)
func (b *BlockWise[C]) getSendingMessageCode(hash uint64) (codes.Code, bool) {
v := b.sendingMessagesCache.Load(hash)
if v == nil {
return codes.Empty, false
}
Expand All @@ -348,19 +354,20 @@ func (b *BlockWise[C]) Handle(w *responsewriter.ResponseWriter[C], r *pool.Messa
}
token := r.Token()

matchableHash := generateMatchableHash(r.Options(), w.Conn().RemoteAddr(), r.Code())

if len(token) == 0 {
err := b.handleReceivedMessage(w, r, maxSZX, maxMessageSize, next)
err := b.handleReceivedMessage(w, r, maxSZX, maxMessageSize, next, matchableHash)
if err != nil {
b.sendEntityIncomplete(w, token)
b.errors(fmt.Errorf("handleReceivedMessage(%v): %w", r, err))
}
return
}
tokenStr := token.Hash()

sendingMessageCode, sendingMessageExist := b.getSendingMessageCode(tokenStr)
sendingMessageCode, sendingMessageExist := b.getSendingMessageCode(matchableHash)
if !sendingMessageExist || wantsToBeReceived(r) {
err := b.handleReceivedMessage(w, r, maxSZX, maxMessageSize, next)
err := b.handleReceivedMessage(w, r, maxSZX, maxMessageSize, next, matchableHash)
if err != nil {
b.sendEntityIncomplete(w, token)
b.errors(fmt.Errorf("handleReceivedMessage(%v): %w", r, err))
Expand All @@ -369,17 +376,17 @@ func (b *BlockWise[C]) Handle(w *responsewriter.ResponseWriter[C], r *pool.Messa
}
more, err := b.continueSendingMessage(w, r, maxSZX, maxMessageSize, sendingMessageCode)
if err != nil {
b.sendingMessagesCache.Delete(tokenStr)
b.sendingMessagesCache.Delete(matchableHash)
b.errors(fmt.Errorf("continueSendingMessage(%v): %w", r, err))
return
}
// For codes GET,POST,PUT,DELETE, we want them to wait for pairing response and then delete them when the full response comes in or when timeout occurs.
if !more && sendingMessageCode > codes.DELETE {
b.sendingMessagesCache.Delete(tokenStr)
b.sendingMessagesCache.Delete(matchableHash)
}
}

func (b *BlockWise[C]) handleReceivedMessage(w *responsewriter.ResponseWriter[C], r *pool.Message, maxSZX SZX, maxMessageSize uint32, next func(w *responsewriter.ResponseWriter[C], r *pool.Message)) error {
func (b *BlockWise[C]) handleReceivedMessage(w *responsewriter.ResponseWriter[C], r *pool.Message, maxSZX SZX, maxMessageSize uint32, next func(w *responsewriter.ResponseWriter[C], r *pool.Message), rx_hash uint64) error {
startSendingMessageBlock, err := EncodeBlockOption(maxSZX, 0, true)
if err != nil {
return fmt.Errorf("cannot encode start sending message block option(%v,%v,%v): %w", maxSZX, 0, true, err)
Expand Down Expand Up @@ -411,7 +418,7 @@ func (b *BlockWise[C]) handleReceivedMessage(w *responsewriter.ResponseWriter[C]
return errP
}
}
return b.startSendingMessage(w, maxSZX, maxMessageSize, startSendingMessageBlock)
return b.startSendingMessage(w, maxSZX, maxMessageSize, startSendingMessageBlock, rx_hash)
}

func (b *BlockWise[C]) createSendingMessage(sendingMessage *pool.Message, maxSZX SZX, maxMessageSize uint32, block uint32) (sendMessage *pool.Message, more bool, err error) {
Expand Down Expand Up @@ -504,7 +511,8 @@ func (b *BlockWise[C]) continueSendingMessage(w *responsewriter.ResponseWriter[C
}
var sendMessage *pool.Message
var more bool
b.sendingMessagesCache.LoadWithFunc(r.Token().Hash(), func(value *cache.Element[*pool.Message]) *cache.Element[*pool.Message] {
matchableHash := generateMatchableHash(r.Options(), w.Conn().RemoteAddr(), r.Code())
b.sendingMessagesCache.LoadWithFunc(matchableHash, func(value *cache.Element[*pool.Message]) *cache.Element[*pool.Message] {
sendMessage, more, err = b.createSendingMessage(value.Data(), maxSZX, maxMessageSize, block)
if err != nil {
err = fmt.Errorf("cannot create sending message: %w", err)
Expand All @@ -529,7 +537,7 @@ func isObserveResponse(msg *pool.Message) bool {
return msg.Code() >= codes.Created
}

func (b *BlockWise[C]) startSendingMessage(w *responsewriter.ResponseWriter[C], maxSZX SZX, maxMessageSize uint32, block uint32) error {
func (b *BlockWise[C]) startSendingMessage(w *responsewriter.ResponseWriter[C], maxSZX SZX, maxMessageSize uint32, block uint32, rx_hash uint64) error {
payloadSize, err := w.Message().BodySize()
if err != nil {
return payloadSizeError(err)
Expand All @@ -552,16 +560,16 @@ func (b *BlockWise[C]) startSendingMessage(w *responsewriter.ResponseWriter[C],
if !ok {
expire = time.Now().Add(b.expiration)
}
el, loaded := b.sendingMessagesCache.LoadOrStore(sendingMessage.Token().Hash(), cache.NewElement(originalSendingMessage, expire, nil))
el, loaded := b.sendingMessagesCache.LoadOrStore(rx_hash, cache.NewElement(originalSendingMessage, expire, nil))
if loaded {
defer b.cc.ReleaseMessage(originalSendingMessage)
return fmt.Errorf("cannot add message (%v) to sending message cache: message(%v) with token(%v) already exist", originalSendingMessage, el.Data(), sendingMessage.Token())
}
return nil
}

func (b *BlockWise[C]) getSentRequest(token message.Token) *pool.Message {
data, ok := b.sendingMessagesCache.LoadWithFunc(token.Hash(), func(value *cache.Element[*pool.Message]) *cache.Element[*pool.Message] {
func (b *BlockWise[C]) getSentRequest(hash uint64) *pool.Message {
data, ok := b.sendingMessagesCache.LoadWithFunc(hash, func(value *cache.Element[*pool.Message]) *cache.Element[*pool.Message] {
if value == nil {
return nil
}
Expand All @@ -576,7 +584,7 @@ func (b *BlockWise[C]) getSentRequest(token message.Token) *pool.Message {
if ok {
return data.Data()
}
globalRequest, ok := b.getSentRequestFromOutside(token)
globalRequest, ok := b.getSentRequestFromOutside(hash)
if ok {
return globalRequest
}
Expand All @@ -595,7 +603,8 @@ func (b *BlockWise[C]) handleObserveResponse(sentRequest *pool.Message) (message
validUntil := time.Now().Add(b.expiration) // context of observation can be expired.
bwSentRequest := b.cloneMessage(sentRequest)
bwSentRequest.SetToken(token)
_, loaded := b.sendingMessagesCache.LoadOrStore(token.Hash(), cache.NewElement(bwSentRequest, validUntil, nil))
matchableHash := generateMatchableHash(sentRequest.Options(), b.cc.RemoteAddr(), sentRequest.Code())
_, loaded := b.sendingMessagesCache.LoadOrStore(matchableHash, cache.NewElement(bwSentRequest, validUntil, nil))
if loaded {
return nil, time.Time{}, errors.New("cannot process message: message with token already exist")
}
Expand Down Expand Up @@ -674,7 +683,7 @@ func copyToPayloadFromOffset(r *pool.Message, payloadFile *memfile.File, offset
return payloadSize, nil
}

func (b *BlockWise[C]) getCachedReceivedMessage(mg *messageGuard, r *pool.Message, tokenStr uint64, validUntil time.Time) (*pool.Message, func(), error) {
func (b *BlockWise[C]) getCachedReceivedMessage(mg *messageGuard, r *pool.Message, tokenStr uint64, matchableHash uint64, validUntil time.Time) (*pool.Message, func(), error) {
cannotLockError := func(err error) error {
return fmt.Errorf("processReceivedMessage: cannot lock message: %w", err)
}
Expand Down Expand Up @@ -708,11 +717,11 @@ func (b *BlockWise[C]) getCachedReceivedMessage(mg *messageGuard, r *pool.Messag
return nil, nil, cannotLockError(errA)
}
appendToClose(mg)
element, loaded := b.receivingMessagesCache.LoadOrStore(tokenStr, cache.NewElement(mg, validUntil, func(d *messageGuard) {
element, loaded := b.receivingMessagesCache.LoadOrStore(matchableHash, cache.NewElement(mg, validUntil, func(d *messageGuard) {
if d == nil {
return
}
b.sendingMessagesCache.Delete(tokenStr)
b.sendingMessagesCache.Delete(matchableHash)
}))
// request was already stored in cache, silently
if loaded {
Expand All @@ -732,6 +741,43 @@ func (b *BlockWise[C]) getCachedReceivedMessage(mg *messageGuard, r *pool.Messag
return mg.Message, closeFn, nil
}

/*
RFC9175 1.1:
Two request messages are said to be "matchable" if they occur between
the same endpoint pair, have the same code, and have the same set of
options, with the exception that elective NoCacheKey options and
options involved in block-wise transfer (Block1, Block2, and Request-
Tag) need not be the same. Two blockwise request operations are said
to be matchable if their request messages are matchable.
This function concatenates the IDs and values of relevant options, the string representation of the remote address,
and the code of the message to generate a hash that can be used to match requests.
*/
func generateMatchableHash(options message.Options, remoteAddr net.Addr, code codes.Code) uint64 {
options_str := ""

input := make([]byte, 0, 512)

for _, opt := range options {
options_str += opt.ID.String() + ","
switch opt.ID {
// Skip Blockwise Options and NoCacheKey Options
case message.Block1, message.Block2, message.Size1, message.Size2, message.RequestTag:
continue
}
input = append(input, byte(opt.ID))
input = append(input, opt.Value...)
}

input = append(input, []byte(remoteAddr.Network())...)
input = append(input, []byte(remoteAddr.String())...)
input = append(input, byte(code))

hash := crc64.Checksum(input, crc64.MakeTable(crc64.ISO))
fmt.Printf("%s %v %v %v\r\n", options_str, remoteAddr, code, hash)
return crc64.Checksum(input, crc64.MakeTable(crc64.ISO))
}

//nolint:gocyclo,gocognit
func (b *BlockWise[C]) processReceivedMessage(w *responsewriter.ResponseWriter[C], r *pool.Message, maxSzx SZX, next func(w *responsewriter.ResponseWriter[C], r *pool.Message), blockType message.OptionID, sizeType message.OptionID) error {
token := r.Token()
Expand All @@ -755,7 +801,8 @@ func (b *BlockWise[C]) processReceivedMessage(w *responsewriter.ResponseWriter[C
if err != nil {
return fmt.Errorf("cannot decode block option: %w", err)
}
sentRequest := b.getSentRequest(token)
matchableHash := generateMatchableHash(r.Options(), w.Conn().RemoteAddr(), r.Code())
sentRequest := b.getSentRequest(matchableHash)
if sentRequest != nil {
defer b.cc.ReleaseMessage(sentRequest)
}
Expand All @@ -772,7 +819,7 @@ func (b *BlockWise[C]) processReceivedMessage(w *responsewriter.ResponseWriter[C

tokenStr := token.Hash()
var cachedReceivedMessageGuard *messageGuard
if e := b.receivingMessagesCache.Load(tokenStr); e != nil {
if e := b.receivingMessagesCache.Load(matchableHash); e != nil {
cachedReceivedMessageGuard = e.Data()
}
if cachedReceivedMessageGuard == nil {
Expand All @@ -783,15 +830,15 @@ func (b *BlockWise[C]) processReceivedMessage(w *responsewriter.ResponseWriter[C
return nil
}
}
cachedReceivedMessage, closeCachedReceivedMessage, err := b.getCachedReceivedMessage(cachedReceivedMessageGuard, r, tokenStr, validUntil)
cachedReceivedMessage, closeCachedReceivedMessage, err := b.getCachedReceivedMessage(cachedReceivedMessageGuard, r, tokenStr, matchableHash, validUntil)
if err != nil {
return err
}
defer closeCachedReceivedMessage()

defer func(err *error) {
if *err != nil {
b.receivingMessagesCache.Delete(tokenStr)
b.receivingMessagesCache.Delete(matchableHash)
}
}(&err)
payloadFile, payloadSize, err := b.getPayloadFromCachedReceivedMessage(r, cachedReceivedMessage)
Expand All @@ -805,12 +852,12 @@ func (b *BlockWise[C]) processReceivedMessage(w *responsewriter.ResponseWriter[C
return fmt.Errorf("cannot copy data to payload: %w", err)
}
if !more {
b.receivingMessagesCache.Delete(tokenStr)
b.receivingMessagesCache.Delete(matchableHash)
cachedReceivedMessage.Remove(blockType)
cachedReceivedMessage.Remove(sizeType)
cachedReceivedMessage.SetType(r.Type())
if !bytes.Equal(cachedReceivedMessage.Token(), token) {
b.sendingMessagesCache.Delete(tokenStr)
b.sendingMessagesCache.Delete(matchableHash)
}
_, errS := cachedReceivedMessage.Body().Seek(0, io.SeekStart)
if errS != nil {
Expand Down
5 changes: 5 additions & 0 deletions net/blockwise/blockwise_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"bytes"
"context"
"io"
"net"
"testing"
"time"

Expand Down Expand Up @@ -54,6 +55,10 @@ type testClient struct {
p *pool.Pool
}

func (c *testClient) RemoteAddr() net.Addr {
return &net.IPAddr{IP: net.IPv4(127, 0, 0, 1)}
}

func newTestClient() *testClient {
return &testClient{
p: pool.New(100, 1024),
Expand Down
4 changes: 2 additions & 2 deletions net/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -105,8 +105,8 @@ func (c *Client[C]) Observe(ctx context.Context, path string, observeFunc func(r
return c.DoObserve(req, observeFunc)
}

func (c *Client[C]) GetObservationRequest(token message.Token) (*pool.Message, bool) {
return c.observationHandler.GetObservationRequest(token)
func (c *Client[C]) GetObservationRequest(hash uint64) (*pool.Message, bool) {
return c.observationHandler.GetObservationRequest(hash)
}

// NewPostRequest creates post request.
Expand Down
Loading

0 comments on commit bb2fd9b

Please sign in to comment.