Skip to content

Commit

Permalink
Merge pull request #530 from zenhack/handle-resolve-take-2
Browse files Browse the repository at this point in the history
WIP: Handle incoming resolve messages, take 2
  • Loading branch information
zenhack authored Jun 29, 2023
2 parents 390b049 + 4f8d2d8 commit 7e1bedd
Show file tree
Hide file tree
Showing 14 changed files with 572 additions and 192 deletions.
17 changes: 15 additions & 2 deletions answer.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ type Promise struct {
// - Resolved. Fulfill or Reject has finished.

state mutex.Mutex[promiseState]

resolver Resolver[Ptr]
}

type promiseState struct {
Expand Down Expand Up @@ -64,11 +66,13 @@ type clientAndPromise struct {
}

// NewPromise creates a new unresolved promise. The PipelineCaller will
// be used to make pipelined calls before the promise resolves.
func NewPromise(m Method, pc PipelineCaller) *Promise {
// be used to make pipelined calls before the promise resolves. If resolver
// is not nil, calls to Fulfill will be forwarded to it.
func NewPromise(m Method, pc PipelineCaller, resolver Resolver[Ptr]) *Promise {
if pc == nil {
panic("NewPromise(nil)")
}

resolved := make(chan struct{})
p := &Promise{
method: m,
Expand All @@ -77,6 +81,7 @@ func NewPromise(m Method, pc PipelineCaller) *Promise {
signals: []func(){func() { close(resolved) }},
caller: pc,
}),
resolver: resolver,
}
p.ans.f.promise = p
p.ans.metadata = *NewMetadata()
Expand Down Expand Up @@ -152,6 +157,14 @@ func (p *Promise) Resolve(r Ptr, e error) {
return p.clients
})

if p.resolver != nil {
if e == nil {
p.resolver.Fulfill(r)
} else {
p.resolver.Reject(e)
}
}

// Pending resolution state: wait for clients to be fulfilled
// and calls to have answers.
res := resolution{p.method, r, e}
Expand Down
12 changes: 6 additions & 6 deletions answer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ var dummyMethod = Method{

func TestPromiseReject(t *testing.T) {
t.Run("Done", func(t *testing.T) {
p := NewPromise(dummyMethod, dummyPipelineCaller{})
p := NewPromise(dummyMethod, dummyPipelineCaller{}, nil)
done := p.Answer().Done()
p.Reject(errors.New("omg bbq"))
select {
Expand All @@ -27,7 +27,7 @@ func TestPromiseReject(t *testing.T) {
}
})
t.Run("Struct", func(t *testing.T) {
p := NewPromise(dummyMethod, dummyPipelineCaller{})
p := NewPromise(dummyMethod, dummyPipelineCaller{}, nil)
defer p.ReleaseClients()
ans := p.Answer()
p.Reject(errors.New("omg bbq"))
Expand All @@ -36,7 +36,7 @@ func TestPromiseReject(t *testing.T) {
}
})
t.Run("Client", func(t *testing.T) {
p := NewPromise(dummyMethod, dummyPipelineCaller{})
p := NewPromise(dummyMethod, dummyPipelineCaller{}, nil)
defer p.ReleaseClients()
pc := p.Answer().Field(1, nil).Client()
p.Reject(errors.New("omg bbq"))
Expand All @@ -57,7 +57,7 @@ func TestPromiseFulfill(t *testing.T) {
t.Parallel()

t.Run("Done", func(t *testing.T) {
p := NewPromise(dummyMethod, dummyPipelineCaller{})
p := NewPromise(dummyMethod, dummyPipelineCaller{}, nil)
done := p.Answer().Done()
msg, seg, _ := NewMessage(SingleSegment(nil))
defer msg.Release()
Expand All @@ -72,7 +72,7 @@ func TestPromiseFulfill(t *testing.T) {
}
})
t.Run("Struct", func(t *testing.T) {
p := NewPromise(dummyMethod, dummyPipelineCaller{})
p := NewPromise(dummyMethod, dummyPipelineCaller{}, nil)
defer p.ReleaseClients()
ans := p.Answer()
msg, seg, _ := NewMessage(SingleSegment(nil))
Expand All @@ -92,7 +92,7 @@ func TestPromiseFulfill(t *testing.T) {
}
})
t.Run("Client", func(t *testing.T) {
p := NewPromise(dummyMethod, dummyPipelineCaller{})
p := NewPromise(dummyMethod, dummyPipelineCaller{}, nil)
defer p.ReleaseClients()
pc := p.Answer().Field(1, nil).Client()

Expand Down
2 changes: 1 addition & 1 deletion answerqueue.go
Original file line number Diff line number Diff line change
Expand Up @@ -282,7 +282,7 @@ func (sr *StructReturner) Answer(m Method, pcall PipelineCaller) (*Answer, Relea
}
}
}
sr.p = NewPromise(m, pcall)
sr.p = NewPromise(m, pcall, nil)
ans := sr.p.Answer()
return ans, func() {
<-ans.Done()
Expand Down
19 changes: 19 additions & 0 deletions capability.go
Original file line number Diff line number Diff line change
Expand Up @@ -603,6 +603,22 @@ func (cs ClientSnapshot) IsPromise() bool {
return ret
}

// IsResolved returns true if the snapshot has resolved to its final value.
// If IsPromise() returns false, then this will also return false. Otherwise,
// it returns false before resolution and true afterwards.
func (cs ClientSnapshot) IsResolved() bool {
if cs.hook == nil {
return false
}
res, ok := cs.hook.Value().resolution.Get()
if !ok {
return false
}
return mutex.With1(res, func(s *resolveState) bool {
return s.isResolved()
})
}

// Send implements ClientHook.Send
func (cs ClientSnapshot) Send(ctx context.Context, s Send) (*Answer, ReleaseFunc) {
if cs.hook == nil {
Expand Down Expand Up @@ -817,6 +833,9 @@ func SetClientLeakFunc(clientLeakFunc func(msg string)) {
clientLeakFunc("leaked client created at:\n\n" + stack)
})
case ClientSnapshot:
if !c.IsValid() {
return
}
runtime.SetFinalizer(c.hook, func(c *rc.Ref[clientHook]) {
if !c.IsValid() {
return
Expand Down
2 changes: 1 addition & 1 deletion capability_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ func TestResolve(t *testing.T) {
}
t.Run("Clients", func(t *testing.T) {
test(t, "Waits for the full chain", func(t *testing.T, p1, p2 Client, r1, r2 Resolver[Client]) {
r1.Fulfill(p2)
r1.Fulfill(p2.AddRef())
ctx, cancel := context.WithTimeout(context.Background(), time.Second/10)
defer cancel()
require.NotNil(t, p1.Resolve(ctx), "blocks on second promise")
Expand Down
61 changes: 14 additions & 47 deletions localpromise.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,5 @@
package capnp

import (
"context"
)

// ClientHook for a promise that will be resolved to some other capability
// at some point. Buffers calls in a queue until the promsie is fulfilled,
// then forwards them.
Expand All @@ -12,59 +8,30 @@ type localPromise struct {
}

// NewLocalPromise returns a client that will eventually resolve to a capability,
// supplied via the fulfiller.
// supplied via the resolver.
func NewLocalPromise[C ~ClientKind]() (C, Resolver[C]) {
lp := newLocalPromise()
p, f := NewPromisedClient(lp)
aq := NewAnswerQueue(Method{})
f := NewPromise(Method{}, aq, aq)
p := f.Answer().Client().AddRef()
return C(p), localResolver[C]{
lp: lp,
clientResolver: f,
p: f,
}
}

func newLocalPromise() localPromise {
return localPromise{aq: NewAnswerQueue(Method{})}
}

func (lp localPromise) Send(ctx context.Context, s Send) (*Answer, ReleaseFunc) {
return lp.aq.PipelineSend(ctx, nil, s)
}

func (lp localPromise) Recv(ctx context.Context, r Recv) PipelineCaller {
return lp.aq.PipelineRecv(ctx, nil, r)
}

func (lp localPromise) Brand() Brand {
return Brand{}
}

func (lp localPromise) Shutdown() {}

func (lp localPromise) String() string {
return "localPromise{...}"
}

func (lp localPromise) Fulfill(c Client) {
msg, seg := NewSingleSegmentMessage(nil)
capID := msg.CapTable().Add(c)
lp.aq.Fulfill(NewInterface(seg, capID).ToPtr())
}

func (lp localPromise) Reject(err error) {
lp.aq.Reject(err)
}

type localResolver[C ~ClientKind] struct {
lp localPromise
clientResolver Resolver[Client]
p *Promise
}

func (lf localResolver[C]) Fulfill(c C) {
lf.lp.Fulfill(Client(c))
lf.clientResolver.Fulfill(Client(c))
msg, seg := NewSingleSegmentMessage(nil)
capID := msg.CapTable().Add(Client(c))
iface := NewInterface(seg, capID)
lf.p.Fulfill(iface.ToPtr())
lf.p.ReleaseClients()
msg.Release()
}

func (lf localResolver[C]) Reject(err error) {
lf.lp.Reject(err)
lf.clientResolver.Reject(err)
lf.p.Reject(err)
lf.p.ReleaseClients()
}
2 changes: 1 addition & 1 deletion rpc/answer.go
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ func (c *Conn) newReturn() (_ rpccp.Return, sendMsg func(), _ *rc.Releaser, _ er
func (ans *ansent) setPipelineCaller(m capnp.Method, pcall capnp.PipelineCaller) {
if !ans.flags.Contains(resultsReady) {
ans.pcall = pcall
ans.promise = capnp.NewPromise(m, pcall)
ans.promise = capnp.NewPromise(m, pcall, nil)
}
}

Expand Down
51 changes: 29 additions & 22 deletions rpc/export.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,8 @@ type exportID uint32

// expent is an entry in a Conn's export table.
type expent struct {
snapshot capnp.ClientSnapshot
wireRefs uint32
isPromise bool
snapshot capnp.ClientSnapshot
wireRefs uint32

// Should be called when removing this entry from the exports table:
cancel context.CancelFunc
Expand Down Expand Up @@ -74,9 +73,11 @@ func (c *lockedConn) releaseExport(id exportID, count uint32) (capnp.ClientSnaps
c.lk.exports[id] = nil
c.lk.exportID.remove(id)
metadata := snapshot.Metadata()
syncutil.With(metadata, func() {
c.clearExportID(metadata)
})
if metadata != nil {
syncutil.With(metadata, func() {
c.clearExportID(metadata)
})
}
return snapshot, nil
case count > ent.wireRefs:
return capnp.ClientSnapshot{}, rpcerr.Failed(errors.New("export ID " + str.Utod(id) + " released too many references"))
Expand Down Expand Up @@ -203,7 +204,7 @@ func (c *lockedConn) sendSenderPromise(id exportID, d rpccp.CapDescriptor) {
// Conn before trying to use it again:
unlockedConn := (*Conn)(c)

waitErr := waitRef.Resolve(ctx)
waitErr := waitRef.Resolve1(ctx)
unlockedConn.withLocked(func(c *lockedConn) {
if len(c.lk.exports) <= int(id) || c.lk.exports[id] != ee {
// Export was removed from the table at some point;
Expand Down Expand Up @@ -366,33 +367,39 @@ func (e *embargo) Shutdown() {
// senderLoopback holds the salient information for a sender-loopback
// Disembargo message.
type senderLoopback struct {
id embargoID
question questionID
transform []capnp.PipelineOp
id embargoID
target parsedMessageTarget
}

func (sl *senderLoopback) buildDisembargo(msg rpccp.Message) error {
d, err := msg.NewDisembargo()
if err != nil {
return rpcerr.WrapFailed("build disembargo", err)
}
d.Context().SetSenderLoopback(uint32(sl.id))
tgt, err := d.NewTarget()
if err != nil {
return rpcerr.WrapFailed("build disembargo", err)
}
pa, err := tgt.NewPromisedAnswer()
if err != nil {
return rpcerr.WrapFailed("build disembargo", err)
}
oplist, err := pa.NewTransform(int32(len(sl.transform)))
if err != nil {
return rpcerr.WrapFailed("build disembargo", err)
}
switch sl.target.which {
case rpccp.MessageTarget_Which_promisedAnswer:
pa, err := tgt.NewPromisedAnswer()
if err != nil {
return rpcerr.WrapFailed("build disembargo", err)
}
oplist, err := pa.NewTransform(int32(len(sl.target.transform)))
if err != nil {
return rpcerr.WrapFailed("build disembargo", err)
}

d.Context().SetSenderLoopback(uint32(sl.id))
pa.SetQuestionId(uint32(sl.question))
for i, op := range sl.transform {
oplist.At(i).SetGetPointerField(op.Field)
pa.SetQuestionId(uint32(sl.target.promisedAnswer))
for i, op := range sl.target.transform {
oplist.At(i).SetGetPointerField(op.Field)
}
case rpccp.MessageTarget_Which_importedCap:
tgt.SetImportedCap(uint32(sl.target.importedCap))
default:
return errors.New("unknown variant for MessageTarget: " + str.Utod(sl.target.which))
}
return nil
}
21 changes: 18 additions & 3 deletions rpc/import.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,14 +45,19 @@ type impent struct {
// importClient's generation matches the entry's generation before
// removing the entry from the table and sending a release message.
generation uint64

// If resolver is non-nil, then this is a promise (received as
// CapDescriptor_Which_senderPromise), and when a resolve message
// arrives we should use this to fulfill the promise locally.
resolver capnp.Resolver[capnp.Client]
}

// addImport returns a client that represents the given import,
// incrementing the number of references to this import from this vat.
// This is separate from the reference counting that capnp.Client does.
//
// The caller must be holding onto c.mu.
func (c *lockedConn) addImport(id importID) capnp.Client {
func (c *lockedConn) addImport(id importID, isPromise bool) capnp.Client {
if ent := c.lk.imports[id]; ent != nil {
ent.wireRefs++
client, ok := ent.wc.AddRef()
Expand All @@ -67,13 +72,23 @@ func (c *lockedConn) addImport(id importID) capnp.Client {
}
return client
}
client := capnp.NewClient(&importClient{
hook := &importClient{
c: (*Conn)(c),
id: id,
})
}
var (
client capnp.Client
resolver capnp.Resolver[capnp.Client]
)
if isPromise {
client, resolver = capnp.NewPromisedClient(hook)
} else {
client = capnp.NewClient(hook)
}
c.lk.imports[id] = &impent{
wc: client.WeakRef(),
wireRefs: 1,
resolver: resolver,
}
return client
}
Expand Down
Loading

0 comments on commit 7e1bedd

Please sign in to comment.