Skip to content

Commit

Permalink
Clean up the way local promises work.
Browse files Browse the repository at this point in the history
Push the logic for flushing the answerqueue into the Promise type
itself. This is much cleaner, and avoids some racy logic that I'm not
sure was correct.
  • Loading branch information
zenhack committed Jun 22, 2023
1 parent f103d94 commit 51facf7
Show file tree
Hide file tree
Showing 6 changed files with 37 additions and 58 deletions.
17 changes: 15 additions & 2 deletions answer.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ type Promise struct {
// - Resolved. Fulfill or Reject has finished.

state mutex.Mutex[promiseState]

resolver Resolver[Ptr]
}

type promiseState struct {
Expand Down Expand Up @@ -64,11 +66,13 @@ type clientAndPromise struct {
}

// NewPromise creates a new unresolved promise. The PipelineCaller will
// be used to make pipelined calls before the promise resolves.
func NewPromise(m Method, pc PipelineCaller) *Promise {
// be used to make pipelined calls before the promise resolves. If resolver
// is not nil, calls to Fulfill will be forwarded to it.
func NewPromise(m Method, pc PipelineCaller, resolver Resolver[Ptr]) *Promise {
if pc == nil {
panic("NewPromise(nil)")
}

resolved := make(chan struct{})
p := &Promise{
method: m,
Expand All @@ -77,6 +81,7 @@ func NewPromise(m Method, pc PipelineCaller) *Promise {
signals: []func(){func() { close(resolved) }},
caller: pc,
}),
resolver: resolver,
}
p.ans.f.promise = p
p.ans.metadata = *NewMetadata()
Expand Down Expand Up @@ -152,6 +157,14 @@ func (p *Promise) Resolve(r Ptr, e error) {
return p.clients
})

if p.resolver != nil {
if e == nil {
p.resolver.Fulfill(r)
} else {
p.resolver.Reject(e)
}
}

// Pending resolution state: wait for clients to be fulfilled
// and calls to have answers.
res := resolution{p.method, r, e}
Expand Down
12 changes: 6 additions & 6 deletions answer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ var dummyMethod = Method{

func TestPromiseReject(t *testing.T) {
t.Run("Done", func(t *testing.T) {
p := NewPromise(dummyMethod, dummyPipelineCaller{})
p := NewPromise(dummyMethod, dummyPipelineCaller{}, nil)
done := p.Answer().Done()
p.Reject(errors.New("omg bbq"))
select {
Expand All @@ -27,7 +27,7 @@ func TestPromiseReject(t *testing.T) {
}
})
t.Run("Struct", func(t *testing.T) {
p := NewPromise(dummyMethod, dummyPipelineCaller{})
p := NewPromise(dummyMethod, dummyPipelineCaller{}, nil)
defer p.ReleaseClients()
ans := p.Answer()
p.Reject(errors.New("omg bbq"))
Expand All @@ -36,7 +36,7 @@ func TestPromiseReject(t *testing.T) {
}
})
t.Run("Client", func(t *testing.T) {
p := NewPromise(dummyMethod, dummyPipelineCaller{})
p := NewPromise(dummyMethod, dummyPipelineCaller{}, nil)
defer p.ReleaseClients()
pc := p.Answer().Field(1, nil).Client()
p.Reject(errors.New("omg bbq"))
Expand All @@ -57,7 +57,7 @@ func TestPromiseFulfill(t *testing.T) {
t.Parallel()

t.Run("Done", func(t *testing.T) {
p := NewPromise(dummyMethod, dummyPipelineCaller{})
p := NewPromise(dummyMethod, dummyPipelineCaller{}, nil)
done := p.Answer().Done()
msg, seg, _ := NewMessage(SingleSegment(nil))
defer msg.Release()
Expand All @@ -72,7 +72,7 @@ func TestPromiseFulfill(t *testing.T) {
}
})
t.Run("Struct", func(t *testing.T) {
p := NewPromise(dummyMethod, dummyPipelineCaller{})
p := NewPromise(dummyMethod, dummyPipelineCaller{}, nil)
defer p.ReleaseClients()
ans := p.Answer()
msg, seg, _ := NewMessage(SingleSegment(nil))
Expand All @@ -92,7 +92,7 @@ func TestPromiseFulfill(t *testing.T) {
}
})
t.Run("Client", func(t *testing.T) {
p := NewPromise(dummyMethod, dummyPipelineCaller{})
p := NewPromise(dummyMethod, dummyPipelineCaller{}, nil)
defer p.ReleaseClients()
pc := p.Answer().Field(1, nil).Client()

Expand Down
2 changes: 1 addition & 1 deletion answerqueue.go
Original file line number Diff line number Diff line change
Expand Up @@ -282,7 +282,7 @@ func (sr *StructReturner) Answer(m Method, pcall PipelineCaller) (*Answer, Relea
}
}
}
sr.p = NewPromise(m, pcall)
sr.p = NewPromise(m, pcall, nil)
ans := sr.p.Answer()
return ans, func() {
<-ans.Done()
Expand Down
60 changes: 13 additions & 47 deletions localpromise.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,5 @@
package capnp

import (
"context"
)

// ClientHook for a promise that will be resolved to some other capability
// at some point. Buffers calls in a queue until the promsie is fulfilled,
// then forwards them.
Expand All @@ -12,59 +8,29 @@ type localPromise struct {
}

// NewLocalPromise returns a client that will eventually resolve to a capability,
// supplied via the fulfiller.
// supplied via the resolver.
func NewLocalPromise[C ~ClientKind]() (C, Resolver[C]) {
lp := newLocalPromise()
p, f := NewPromisedClient(lp)
aq := NewAnswerQueue(Method{})
f := NewPromise(Method{}, aq, aq)
p := f.Answer().Client().AddRef()
return C(p), localResolver[C]{
lp: lp,
clientResolver: f,
p: f,
}
}

func newLocalPromise() localPromise {
return localPromise{aq: NewAnswerQueue(Method{})}
}

func (lp localPromise) Send(ctx context.Context, s Send) (*Answer, ReleaseFunc) {
return lp.aq.PipelineSend(ctx, nil, s)
}

func (lp localPromise) Recv(ctx context.Context, r Recv) PipelineCaller {
return lp.aq.PipelineRecv(ctx, nil, r)
}

func (lp localPromise) Brand() Brand {
return Brand{}
}

func (lp localPromise) Shutdown() {}

func (lp localPromise) String() string {
return "localPromise{...}"
}

func (lp localPromise) Fulfill(c Client) {
msg, seg := NewSingleSegmentMessage(nil)
capID := msg.CapTable().Add(c)
lp.aq.Fulfill(NewInterface(seg, capID).ToPtr())
}

func (lp localPromise) Reject(err error) {
lp.aq.Reject(err)
}

type localResolver[C ~ClientKind] struct {
lp localPromise
clientResolver Resolver[Client]
p *Promise
}

func (lf localResolver[C]) Fulfill(c C) {
lf.lp.Fulfill(Client(c))
lf.clientResolver.Fulfill(Client(c))
msg, seg := NewSingleSegmentMessage(nil)
capID := msg.CapTable().Add(Client(c))
iface := NewInterface(seg, capID)
lf.p.Fulfill(iface.ToPtr())
lf.p.ReleaseClients()
}

func (lf localResolver[C]) Reject(err error) {
lf.lp.Reject(err)
lf.clientResolver.Reject(err)
lf.p.Reject(err)
lf.p.ReleaseClients()
}
2 changes: 1 addition & 1 deletion rpc/answer.go
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ func (c *Conn) newReturn() (_ rpccp.Return, sendMsg func(), _ *rc.Releaser, _ er
func (ans *ansent) setPipelineCaller(m capnp.Method, pcall capnp.PipelineCaller) {
if !ans.flags.Contains(resultsReady) {
ans.pcall = pcall
ans.promise = capnp.NewPromise(m, pcall)
ans.promise = capnp.NewPromise(m, pcall, nil)
}
}

Expand Down
2 changes: 1 addition & 1 deletion rpc/question.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ func (c *lockedConn) newQuestion(method capnp.Method) *question {
release: func() {},
finishMsgSend: make(chan struct{}),
}
q.p = capnp.NewPromise(method, q) // TODO(someday): customize error message for bootstrap
q.p = capnp.NewPromise(method, q, nil) // TODO(someday): customize error message for bootstrap
c.setAnswerQuestion(q.p.Answer(), q)
if int(q.id) == len(c.lk.questions) {
c.lk.questions = append(c.lk.questions, q)
Expand Down

0 comments on commit 51facf7

Please sign in to comment.