From 541bc9f4593dc55fb51a4cc9592058a6a6f075eb Mon Sep 17 00:00:00 2001 From: Adam Azarchs Date: Thu, 3 Aug 2023 14:56:32 -0700 Subject: [PATCH] feat: make SerializeState cancellable SerializeState is called during post-processing, but also from the web UI. Under some circumstances, it can take a very long time to serialize the data, so it's entirely possible for requests to time out before they're returned. Worse, clients may retry, which can result in a pile-up. When a request times out, we want to abort collection of that data. While we're at it, add more tasks and regions for tracing the process. Note this includes breaking changes to the public go API. --- cmd/mrp/webserver.go | 22 +++++++++++----- martian/core/node.go | 36 ++++++++++++++++++-------- martian/core/pipestance.go | 42 +++++++++++++++++++------------ martian/core/post_process.go | 6 ++++- martian/core/post_process_test.go | 4 +-- martian/core/resolve_test.go | 4 +-- martian/core/runloop_test.go | 2 +- martian/core/stage.go | 29 +++++++++++++++++---- 8 files changed, 102 insertions(+), 43 deletions(-) diff --git a/cmd/mrp/webserver.go b/cmd/mrp/webserver.go index 89c8e957..cafee123 100644 --- a/cmd/mrp/webserver.go +++ b/cmd/mrp/webserver.go @@ -29,20 +29,20 @@ import ( "github.com/martian-lang/martian/martian/util" ) -func getFinalState(rt *core.Runtime, pipestance *core.Pipestance) []*core.NodeInfo { +func getFinalState(ctx context.Context, rt *core.Runtime, pipestance *core.Pipestance) []*core.NodeInfo { var target []*core.NodeInfo if err := rt.GetSerializationInto(pipestance.GetPath(), core.FinalState, &target); err == nil { return target } - return pipestance.SerializeState() + return pipestance.SerializeState(ctx) } -func getPerf(rt *core.Runtime, pipestance *core.Pipestance) []*core.NodePerfInfo { +func getPerf(ctx context.Context, rt *core.Runtime, pipestance *core.Pipestance) []*core.NodePerfInfo { var target []*core.NodePerfInfo if err := rt.GetSerializationInto(pipestance.GetPath(), core.Perf, &target); err == nil { return target } - return pipestance.SerializePerf() + return pipestance.SerializePerf(ctx) } func runWebServer( @@ -309,9 +309,14 @@ func (self *mrpWebServer) getState(w http.ResponseWriter, req *http.Request) { } pipestance := self.pipestanceBox.getPipestance() state := api.PipestanceState{ - Nodes: getFinalState(self.rt, pipestance), + Nodes: getFinalState(req.Context(), self.rt, pipestance), Info: self.pipestanceBox.info, } + if err := req.Context().Err(); err != nil { + // Don't sending bytes if the request was canceled. + http.Error(w, err.Error(), http.StatusRequestTimeout) + return + } self.mutex.Lock() bytes, err := json.Marshal(&state) self.mutex.Unlock() @@ -342,7 +347,12 @@ func (self *mrpWebServer) getPerf(w http.ResponseWriter, req *http.Request) { } pipestance := self.pipestanceBox.getPipestance() state := api.PerfInfo{ - Nodes: getPerf(self.rt, pipestance), + Nodes: getPerf(req.Context(), self.rt, pipestance), + } + if err := req.Context().Err(); err != nil { + // Don't sending bytes if the request was canceled. + http.Error(w, err.Error(), http.StatusRequestTimeout) + return } bytes, err := json.Marshal(&state) if err != nil { diff --git a/martian/core/node.go b/martian/core/node.go index b9d7c14d..5c341ab5 100644 --- a/martian/core/node.go +++ b/martian/core/node.go @@ -7,11 +7,13 @@ package core import ( + "context" "fmt" "math" "os" "path" "regexp" + "runtime/trace" "sort" "strconv" "strings" @@ -761,13 +763,14 @@ func (self *Node) kill(message string) { } } -func (self *Node) postProcess() { +func (self *Node) postProcess(ctx context.Context) { + defer trace.StartRegion(ctx, "Node_postProcess").End() os.RemoveAll(self.top.journalPath) os.RemoveAll(self.top.tmpPath) var errs syntax.ErrorList for _, fork := range self.forks { - if err := fork.postProcess(); err != nil { + if err := fork.postProcess(ctx); err != nil { errs = append(errs, err) } } @@ -776,11 +779,11 @@ func (self *Node) postProcess() { } } -func (self *Node) cachePerf() { +func (self *Node) cachePerf(ctx context.Context) { if _, ok := self.vdrKill(); ok { // Cache all fork performance info if node can be VDR-ed. for _, fork := range self.forks { - fork.cachePerf() + fork.cachePerf(ctx) } } } @@ -891,11 +894,13 @@ func (self *Node) step() bool { } self.addFrontierNode(self) case Complete: + ctx, task := trace.NewTask(context.Background(), "step_Complete") + defer task.End() if vdr := self.top.rt.Config.VdrMode; vdr == VdrRolling || vdr == VdrStrict { for _, node := range self.prenodes { - node.getNode().cachePerf() + node.getNode().cachePerf(ctx) } - self.cachePerf() + self.cachePerf(ctx) } fallthrough case DisabledState: @@ -986,13 +991,20 @@ func (self *Node) refreshState(readOnly bool) { } // Serialization. -func (self *Node) serializeState() *NodeInfo { +func (self *Node) serializeState(ctx context.Context) *NodeInfo { + defer trace.StartRegion(ctx, "Node_serializeState").End() forks := make([]*ForkInfo, 0, len(self.forks)) for _, fork := range self.forks { - forks = append(forks, fork.serializeState()) + if ctx.Err() != nil { + return nil + } + forks = append(forks, fork.serializeState(ctx)) } edges := make([]EdgeInfo, 0, len(self.directPrenodes)) for _, prenode := range self.directPrenodes { + if ctx.Err() != nil { + return nil + } edges = append(edges, EdgeInfo{ From: prenode.GetFQName(), To: self.call.GetFqid(), @@ -1035,11 +1047,15 @@ func (self *Node) serializeState() *NodeInfo { return info } -func (self *Node) serializePerf() (*NodePerfInfo, []*VdrEvent) { +func (self *Node) serializePerf(ctx context.Context) (*NodePerfInfo, []*VdrEvent) { + defer trace.StartRegion(ctx, "Node_serializePerf").End() forks := make([]*ForkPerfInfo, 0, len(self.forks)) var storageEvents []*VdrEvent for _, fork := range self.forks { - forkSer, vdrKill := fork.serializePerf() + if ctx.Err() != nil { + return nil, nil + } + forkSer, vdrKill := fork.serializePerf(ctx) forks = append(forks, forkSer) if vdrKill != nil && self.call.Kind() != syntax.KindPipeline { storageEvents = append(storageEvents, vdrKill.Events...) diff --git a/martian/core/pipestance.go b/martian/core/pipestance.go index cace9bca..596cd26c 100644 --- a/martian/core/pipestance.go +++ b/martian/core/pipestance.go @@ -653,47 +653,54 @@ func (self *Pipestance) Reset() error { return nil } -func (self *Pipestance) SerializeState() []*NodeInfo { +func (self *Pipestance) SerializeState(ctx context.Context) []*NodeInfo { nodes := self.allNodes() ser := make([]*NodeInfo, 0, len(nodes)) for _, node := range nodes { - ser = append(ser, node.serializeState()) + if ctx.Err() != nil { + return nil + } + ser = append(ser, node.serializeState(ctx)) } return ser } -func (self *Pipestance) SerializePerf() []*NodePerfInfo { +func (self *Pipestance) SerializePerf(ctx context.Context) []*NodePerfInfo { nodes := self.allNodes() ser := make([]*NodePerfInfo, 0, len(nodes)) for _, node := range nodes { - perf, _ := node.serializePerf() + if ctx.Err() != nil { + return nil + } + perf, _ := node.serializePerf(ctx) ser = append(ser, perf) } util.LogInfo("perform", "Serializing pipestance performance data.") - if len(ser) > 0 { + if len(ser) > 0 && ctx.Err() == nil { overallPerf := ser[0] - self.ComputeDiskUsage(overallPerf) + self.ComputeDiskUsage(ctx, overallPerf) overallPerf.HighMem = &self.node.top.rt.LocalJobManager.highMem } return ser } -func (self *Pipestance) Serialize(name MetadataFileName) interface{} { +func (self *Pipestance) Serialize(ctx context.Context, name MetadataFileName) interface{} { switch name { case FinalState: - return self.SerializeState() + return self.SerializeState(ctx) case Perf: - return self.SerializePerf() + return self.SerializePerf(ctx) default: panic(fmt.Sprintf("Unsupported serialization type: %v", name)) } } -func (self *Pipestance) ComputeDiskUsage(nodePerf *NodePerfInfo) *NodePerfInfo { +func (self *Pipestance) ComputeDiskUsage(ctx context.Context, nodePerf *NodePerfInfo) *NodePerfInfo { + defer trace.StartRegion(ctx, "ComputeDiskUsage").End() nodes := self.allNodes() allStorageEvents := make(StorageEventByTimestamp, 0, len(nodes)*2) for _, node := range nodes { - _, storageEvents := node.serializePerf() + _, storageEvents := node.serializePerf(ctx) for _, ev := range storageEvents { if ev.DeltaBytes != 0 { allStorageEvents = append(allStorageEvents, @@ -797,14 +804,16 @@ func (self *Pipestance) GetVersions() (string, string, error) { } func (self *Pipestance) PostProcess() { - self.node.postProcess() + ctx, task := trace.NewTask(context.Background(), "PostProcess") + defer task.End() + self.node.postProcess(ctx) start, _ := self.metadata.readRawBytes(TimestampFile) start = append(start, "\nend: "...) if err := self.metadata.WriteRawBytes(TimestampFile, append(start, util.Timestamp()...)); err != nil { util.LogError(err, "runtime", "Error writing completion timestamp.") } - if err := self.Immortalize(false); err != nil { + if err := self.Immortalize(ctx, false); err != nil { util.LogError(err, "runtime", "Error finalizing pipestance state.") } @@ -814,19 +823,20 @@ func (self *Pipestance) PostProcess() { // for posterity. // // Unless force is true, this is only permitted for locked pipestances. -func (self *Pipestance) Immortalize(force bool) error { +func (self *Pipestance) Immortalize(ctx context.Context, force bool) error { + defer trace.StartRegion(ctx, "Immortalize").End() if !force && self.readOnly() { return &RuntimeError{"Pipestance is in read only mode."} } self.metadata.loadCache() var errs syntax.ErrorList if !self.metadata.exists(Perf) { - if err := self.metadata.Write(Perf, self.SerializePerf()); err != nil { + if err := self.metadata.Write(Perf, self.SerializePerf(ctx)); err != nil { errs = append(errs, err) } } if !self.metadata.exists(FinalState) { - if err := self.metadata.Write(FinalState, self.SerializeState()); err != nil { + if err := self.metadata.Write(FinalState, self.SerializeState(ctx)); err != nil { errs = append(errs, err) } } diff --git a/martian/core/post_process.go b/martian/core/post_process.go index 4cb86e26..d017e64d 100644 --- a/martian/core/post_process.go +++ b/martian/core/post_process.go @@ -8,11 +8,13 @@ package core import ( "bytes" + "context" "encoding/json" "fmt" "os" "path" "path/filepath" + "runtime/trace" "sort" "strconv" "strings" @@ -21,7 +23,9 @@ import ( "github.com/martian-lang/martian/martian/util" ) -func (self *Fork) postProcess() error { +func (self *Fork) postProcess(ctx context.Context) error { + defer trace.StartRegion(ctx, "Fork_postProcess").End() + ro := self.node.call.ResolvedOutputs() if ro == nil { return nil diff --git a/martian/core/post_process_test.go b/martian/core/post_process_test.go index a11c0794..f2f22677 100644 --- a/martian/core/post_process_test.go +++ b/martian/core/post_process_test.go @@ -181,7 +181,7 @@ func TestPostProcess(t *testing.T) { } var buf strings.Builder util.SetPrintLogger(&buf) - if err := fork.postProcess(); err != nil { + if err := fork.postProcess(context.Background()); err != nil { t.Error(err) } util.SetPrintLogger(&devNull) @@ -419,7 +419,7 @@ func TestPostProcessEmpties(t *testing.T) { } var buf strings.Builder util.SetPrintLogger(&buf) - if err := fork.postProcess(); err != nil { + if err := fork.postProcess(context.Background()); err != nil { t.Error(err) } util.SetPrintLogger(&devNull) diff --git a/martian/core/resolve_test.go b/martian/core/resolve_test.go index e7ac707c..0f860a9e 100644 --- a/martian/core/resolve_test.go +++ b/martian/core/resolve_test.go @@ -395,7 +395,7 @@ func TestResolveSimplePipelineOutputs(t *testing.T) { } var buf strings.Builder util.SetPrintLogger(&buf) - if err := fork.postProcess(); err != nil { + if err := fork.postProcess(context.Background()); err != nil { t.Error(err) } util.SetPrintLogger(&devNull) @@ -595,7 +595,7 @@ func TestResolvePipelineOutputs(t *testing.T) { } var buf strings.Builder util.SetPrintLogger(&buf) - if err := fork.postProcess(); err != nil { + if err := fork.postProcess(context.Background()); err != nil { t.Error(err) } util.SetPrintLogger(&devNull) diff --git a/martian/core/runloop_test.go b/martian/core/runloop_test.go index 842e9e4d..d682803c 100644 --- a/martian/core/runloop_test.go +++ b/martian/core/runloop_test.go @@ -150,7 +150,7 @@ func TestPipestanceRun(t *testing.T) { } } } - nodeInfos := pipestance.SerializeState() + nodeInfos := pipestance.SerializeState(context.Background()) if len(nodeInfos) != 22 { t.Errorf("node count %d != 22", len(nodeInfos)) } diff --git a/martian/core/stage.go b/martian/core/stage.go index 79817c4c..f843aa1b 100644 --- a/martian/core/stage.go +++ b/martian/core/stage.go @@ -8,12 +8,14 @@ package core import ( "bytes" + "context" "encoding/json" "errors" "fmt" "math" "os" "path" + "runtime/trace" "strings" "sync" "time" @@ -1415,8 +1417,8 @@ func (self *Fork) printUpdateIfNeeded() { } } -func (self *Fork) cachePerf() { - perfInfo, vdrKillReport := self.serializePerf() +func (self *Fork) cachePerf(ctx context.Context) { + perfInfo, vdrKillReport := self.serializePerf(ctx) self.perfCache = &ForkPerfCache{perfInfo, vdrKillReport} } @@ -1480,7 +1482,8 @@ func (self *Fork) getAlarms(alarms *strings.Builder) { } } -func (self *Fork) serializeState() *ForkInfo { +func (self *Fork) serializeState(ctx context.Context) *ForkInfo { + defer trace.StartRegion(ctx, "Fork_serializeState").End() argbindings := self.node.inputBindingInfo(self.forkId) outputs := self.node.outputBindingInfo(self.forkId) bindings := &ForkBindingsInfo{ @@ -1489,6 +1492,9 @@ func (self *Fork) serializeState() *ForkInfo { } chunks := make([]*ChunkInfo, 0, len(self.chunks)) for _, chunk := range self.chunks { + if ctx.Err() != nil { + return nil + } chunks = append(chunks, chunk.serializeState()) } return &ForkInfo{ @@ -1520,7 +1526,8 @@ func (self *Fork) getStages() []*StagePerfInfo { return stages } -func (self *Fork) serializePerf() (*ForkPerfInfo, *VDRKillReport) { +func (self *Fork) serializePerf(ctx context.Context) (*ForkPerfInfo, *VDRKillReport) { + defer trace.StartRegion(ctx, "Fork_serializePerf").End() if self.perfCache != nil { // Use cached performance information if it exists. return self.perfCache.perfInfo, self.perfCache.vdrKillReport @@ -1529,6 +1536,9 @@ func (self *Fork) serializePerf() (*ForkPerfInfo, *VDRKillReport) { chunks := make([]*ChunkPerfInfo, 0, len(self.chunks)) stats := make([]*PerfInfo, 0, len(self.chunks)+len(self.node.subnodes)+2) for _, chunk := range self.chunks { + if ctx.Err() != nil { + return nil, nil + } chunkSer := chunk.serializePerf() chunks = append(chunks, chunkSer) if chunkSer.ChunkStats != nil { @@ -1557,14 +1567,23 @@ func (self *Fork) serializePerf() (*ForkPerfInfo, *VDRKillReport) { killReports := make([]*VDRKillReport, 1, len(self.node.subnodes)+1) killReports[0], _ = self.getVdrKillReport() for _, node := range self.node.subnodes { + if ctx.Err() != nil { + return nil, nil + } for _, subfork := range node.matchForks(self.forkId) { - subforkSer, subforkKillReport := subfork.serializePerf() + if ctx.Err() != nil { + return nil, nil + } + subforkSer, subforkKillReport := subfork.serializePerf(ctx) stats = append(stats, subforkSer.ForkStats) if subforkKillReport != nil { killReports = append(killReports, subforkKillReport) } } } + if ctx.Err() != nil { + return nil, nil + } killReport := mergeVDRKillReports(killReports) fpaths, _ := self.metadata.enumerateFiles()