Skip to content

Commit 5cf9df2

Browse files
committed
stop iteration
Signed-off-by: mathetake <takeshi@tetrate.io>
1 parent 31f3184 commit 5cf9df2

File tree

2 files changed

+25
-8
lines changed

2 files changed

+25
-8
lines changed

include/proxy-wasm/context.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,11 @@ class ContextBase : public RootInterface,
170170
// Called before deleting the context.
171171
virtual void destroy();
172172

173+
// Called to raise the flag which indicates that the context should stop iteration regardless of
174+
// returned filter status from WASM extensions. For example, we ignore
175+
// FilterHeadersStatus::Continue after a local reponse is sent by the host.
176+
void stopIteration() { stop_iteration_ = true; };
177+
173178
/**
174179
* Calls into the VM.
175180
* These are implemented by the proxy-independent host code. They are virtual to support some
@@ -388,6 +393,7 @@ class ContextBase : public RootInterface,
388393
std::shared_ptr<PluginBase> temp_plugin_; // Remove once ABI v0.1.0 is gone.
389394
bool in_vm_context_created_ = false;
390395
bool destroyed_ = false;
396+
bool stop_iteration_ = false;
391397
};
392398

393399
class DeferAfterCallActions {

src/context.cc

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -476,8 +476,12 @@ FilterHeadersStatus ContextBase::onRequestHeaders(uint32_t headers, bool end_of_
476476
->on_request_headers_abi_02_(this, id_, headers,
477477
static_cast<uint32_t>(end_of_stream))
478478
.u64_;
479-
if (result > static_cast<uint64_t>(FilterHeadersStatus::StopAllIterationAndWatermark))
479+
480+
if (stop_iteration_) {
481+
return FilterHeadersStatus::StopIteration;
482+
} else if (result > static_cast<uint64_t>(FilterHeadersStatus::StopAllIterationAndWatermark)) {
480483
return FilterHeadersStatus::StopAllIterationAndWatermark;
484+
}
481485
return static_cast<FilterHeadersStatus>(result);
482486
}
483487

@@ -486,7 +490,7 @@ FilterDataStatus ContextBase::onRequestBody(uint32_t data_length, bool end_of_st
486490
DeferAfterCallActions actions(this);
487491
auto result =
488492
wasm_->on_request_body_(this, id_, data_length, static_cast<uint32_t>(end_of_stream)).u64_;
489-
if (result > static_cast<uint64_t>(FilterDataStatus::StopIterationNoBuffer))
493+
if (stop_iteration_ || result > static_cast<uint64_t>(FilterDataStatus::StopIterationNoBuffer))
490494
return FilterDataStatus::StopIterationNoBuffer;
491495
return static_cast<FilterDataStatus>(result);
492496
}
@@ -495,8 +499,9 @@ FilterTrailersStatus ContextBase::onRequestTrailers(uint32_t trailers) {
495499
CHECK_HTTP(on_request_trailers_, FilterTrailersStatus::Continue,
496500
FilterTrailersStatus::StopIteration);
497501
DeferAfterCallActions actions(this);
498-
if (static_cast<FilterTrailersStatus>(wasm_->on_request_trailers_(this, id_, trailers).u64_) ==
499-
FilterTrailersStatus::Continue) {
502+
if (!stop_iteration_ &&
503+
static_cast<FilterTrailersStatus>(wasm_->on_request_trailers_(this, id_, trailers).u64_) ==
504+
FilterTrailersStatus::Continue) {
500505
return FilterTrailersStatus::Continue;
501506
}
502507
return FilterTrailersStatus::StopIteration;
@@ -522,8 +527,12 @@ FilterHeadersStatus ContextBase::onResponseHeaders(uint32_t headers, bool end_of
522527
->on_response_headers_abi_02_(this, id_, headers,
523528
static_cast<uint32_t>(end_of_stream))
524529
.u64_;
525-
if (result > static_cast<uint64_t>(FilterHeadersStatus::StopAllIterationAndWatermark))
530+
531+
if (stop_iteration_) {
532+
return FilterHeadersStatus::StopIteration;
533+
} else if (result > static_cast<uint64_t>(FilterHeadersStatus::StopAllIterationAndWatermark)) {
526534
return FilterHeadersStatus::StopAllIterationAndWatermark;
535+
}
527536
return static_cast<FilterHeadersStatus>(result);
528537
}
529538

@@ -533,7 +542,8 @@ FilterDataStatus ContextBase::onResponseBody(uint32_t body_length, bool end_of_s
533542
DeferAfterCallActions actions(this);
534543
auto result =
535544
wasm_->on_response_body_(this, id_, body_length, static_cast<uint32_t>(end_of_stream)).u64_;
536-
if (result > static_cast<uint64_t>(FilterDataStatus::StopIterationNoBuffer))
545+
546+
if (stop_iteration_ || result > static_cast<uint64_t>(FilterDataStatus::StopIterationNoBuffer))
537547
return FilterDataStatus::StopIterationNoBuffer;
538548
return static_cast<FilterDataStatus>(result);
539549
}
@@ -542,8 +552,9 @@ FilterTrailersStatus ContextBase::onResponseTrailers(uint32_t trailers) {
542552
CHECK_HTTP(on_response_trailers_, FilterTrailersStatus::Continue,
543553
FilterTrailersStatus::StopIteration);
544554
DeferAfterCallActions actions(this);
545-
if (static_cast<FilterTrailersStatus>(wasm_->on_response_trailers_(this, id_, trailers).u64_) ==
546-
FilterTrailersStatus::Continue) {
555+
if (!stop_iteration_ &&
556+
static_cast<FilterTrailersStatus>(wasm_->on_response_trailers_(this, id_, trailers).u64_) ==
557+
FilterTrailersStatus::Continue) {
547558
return FilterTrailersStatus::Continue;
548559
}
549560
return FilterTrailersStatus::StopIteration;

0 commit comments

Comments
 (0)