diff --git a/source/extensions/common/wasm/context.cc b/source/extensions/common/wasm/context.cc index 41235f914f41..dc14b943e93a 100644 --- a/source/extensions/common/wasm/context.cc +++ b/source/extensions/common/wasm/context.cc @@ -787,11 +787,17 @@ uint32_t Context::getHeaderMapSize(HeaderMapType type) { // Buffer -Buffer::Instance* Context::getBuffer(BufferType type) { +const Buffer::Instance* Context::getBuffer(BufferType type) { switch (type) { case BufferType::HttpRequestBody: + if (buffering_request_body_) { + return decoder_callbacks_->decodingBuffer(); + } return request_body_buffer_; case BufferType::HttpResponseBody: + if (buffering_response_body_) { + return encoder_callbacks_->encodingBuffer(); + } return response_body_buffer_; case BufferType::NetworkDownstreamData: return network_downstream_data_buffer_; @@ -1196,22 +1202,28 @@ Http::FilterHeadersStatus Context::onRequestHeaders() { return Http::FilterHeadersStatus::StopIteration; } -Http::FilterDataStatus Context::onRequestBody(int body_buffer_length, bool end_of_stream) { +Http::FilterDataStatus Context::onRequestBody(bool end_of_stream) { if (!wasm_->on_request_body_) { return Http::FilterDataStatus::Continue; } DeferAfterCallActions actions(this); + const auto buffer = getBuffer(BufferType::HttpRequestBody); + const auto body_len = (buffer == nullptr) ? 0 : buffer->length(); switch (wasm_ - ->on_request_body_(this, id_, static_cast(body_buffer_length), + ->on_request_body_(this, id_, static_cast(body_len), static_cast(end_of_stream)) .u64_) { case 0: + buffering_request_body_ = false; return Http::FilterDataStatus::Continue; case 1: + buffering_request_body_ = true; return Http::FilterDataStatus::StopIterationAndBuffer; case 2: + buffering_request_body_ = false; return Http::FilterDataStatus::StopIterationAndWatermark; default: + buffering_request_body_ = false; return Http::FilterDataStatus::StopIterationNoBuffer; } } @@ -1257,22 +1269,28 @@ Http::FilterHeadersStatus Context::onResponseHeaders() { return Http::FilterHeadersStatus::StopIteration; } -Http::FilterDataStatus Context::onResponseBody(int body_buffer_length, bool end_of_stream) { +Http::FilterDataStatus Context::onResponseBody(bool end_of_stream) { if (!wasm_->on_response_body_) { return Http::FilterDataStatus::Continue; } DeferAfterCallActions actions(this); + const auto buffer = getBuffer(BufferType::HttpResponseBody); + const auto body_len = (buffer == nullptr) ? 0 : buffer->length(); switch (wasm_ - ->on_response_body_(this, id_, static_cast(body_buffer_length), + ->on_response_body_(this, id_, static_cast(body_len), static_cast(end_of_stream)) .u64_) { case 0: + buffering_response_body_ = false; return Http::FilterDataStatus::Continue; case 1: + buffering_response_body_ = true; return Http::FilterDataStatus::StopIterationAndBuffer; case 2: + buffering_response_body_ = false; return Http::FilterDataStatus::StopIterationAndWatermark; default: + buffering_response_body_ = false; return Http::FilterDataStatus::StopIterationNoBuffer; } } @@ -1590,7 +1608,7 @@ Http::FilterHeadersStatus Context::decodeHeaders(Http::HeaderMap& headers, bool Http::FilterDataStatus Context::decodeData(Buffer::Instance& data, bool end_stream) { request_body_buffer_ = &data; end_of_stream_ = end_stream; - auto result = onRequestBody(data.length(), end_stream); + auto result = onRequestBody(end_stream); request_body_buffer_ = nullptr; return result; } @@ -1628,7 +1646,7 @@ Http::FilterHeadersStatus Context::encodeHeaders(Http::HeaderMap& headers, bool Http::FilterDataStatus Context::encodeData(Buffer::Instance& data, bool end_stream) { response_body_buffer_ = &data; end_of_stream_ = end_stream; - auto result = onResponseBody(data.length(), end_stream); + auto result = onResponseBody(end_stream); response_body_buffer_ = nullptr; return result; } diff --git a/source/extensions/common/wasm/context.h b/source/extensions/common/wasm/context.h index 1944a27604f4..30d4107d8803 100644 --- a/source/extensions/common/wasm/context.h +++ b/source/extensions/common/wasm/context.h @@ -129,12 +129,12 @@ class Context : public Logger::Loggable, virtual void onUpstreamConnectionClose(PeerType); // HTTP Filter Stream Request Downcalls. virtual Http::FilterHeadersStatus onRequestHeaders(); - virtual Http::FilterDataStatus onRequestBody(int body_buffer_length, bool end_of_stream); + virtual Http::FilterDataStatus onRequestBody(bool end_of_stream); virtual Http::FilterTrailersStatus onRequestTrailers(); virtual Http::FilterMetadataStatus onRequestMetadata(); // HTTP Filter Stream Response Downcalls. virtual Http::FilterHeadersStatus onResponseHeaders(); - virtual Http::FilterDataStatus onResponseBody(int body_buffer_length, bool end_of_stream); + virtual Http::FilterDataStatus onResponseBody(bool end_of_stream); virtual Http::FilterTrailersStatus onResponseTrailers(); virtual Http::FilterMetadataStatus onResponseMetadata(); // Async Response Downcalls on any Context. @@ -265,7 +265,7 @@ class Context : public Logger::Loggable, virtual uint32_t getHeaderMapSize(HeaderMapType type); // Buffer - virtual Buffer::Instance* getBuffer(BufferType type); + virtual const Buffer::Instance* getBuffer(BufferType type); bool end_of_stream() { return end_of_stream_; } // HTTP @@ -478,6 +478,8 @@ class Context : public Logger::Loggable, // Temporary state. ProtobufWkt::Struct temporary_metadata_; bool end_of_stream_; + bool buffering_request_body_ = false; + bool buffering_response_body_ = false; // MB: must be a node-type map as we take persistent references to the entries. std::map http_request_; diff --git a/test/extensions/filters/http/wasm/test_data/Makefile b/test/extensions/filters/http/wasm/test_data/Makefile index 782b1e8eb872..b2124c0fcaa9 100644 --- a/test/extensions/filters/http/wasm/test_data/Makefile +++ b/test/extensions/filters/http/wasm/test_data/Makefile @@ -1,3 +1,3 @@ -all: headers_cpp.wasm async_call_cpp.wasm metadata_cpp.wasm grpc_call_cpp.wasm shared_cpp.wasm queue_cpp.wasm http_callout_cpp.wasm grpc_callout_cpp.wasm +all: headers_cpp.wasm async_call_cpp.wasm metadata_cpp.wasm grpc_call_cpp.wasm shared_cpp.wasm queue_cpp.wasm body_cpp.wasm http_callout_cpp.wasm grpc_callout_cpp.wasm include ../../../../../../api/wasm/cpp/Makefile.base_lite diff --git a/test/extensions/filters/http/wasm/test_data/body_cpp.cc b/test/extensions/filters/http/wasm/test_data/body_cpp.cc new file mode 100644 index 000000000000..74415d391db2 --- /dev/null +++ b/test/extensions/filters/http/wasm/test_data/body_cpp.cc @@ -0,0 +1,76 @@ +// NOLINT(namespace-envoy) +#include +#include +#include +#include + +#include "proxy_wasm_intrinsics.h" + +class ExampleContext : public Context { +public: + explicit ExampleContext(uint32_t id, RootContext* root) : Context(id, root) {} + + FilterHeadersStatus onRequestHeaders(uint32_t) override; + FilterHeadersStatus onResponseHeaders(uint32_t) override; + FilterDataStatus onRequestBody(size_t body_buffer_length, bool end_of_stream) override; + FilterDataStatus onResponseBody(size_t body_buffer_length, bool end_of_stream) override; + +private: + FilterDataStatus onBody(BufferType bt, size_t bufLen, bool end); + static void logBody(BufferType bt); + + std::string test_op_; + int num_chunks_ = 0; +}; +static RegisterContextFactory register_ExampleContext(CONTEXT_FACTORY(ExampleContext)); + +FilterHeadersStatus ExampleContext::onRequestHeaders(uint32_t) { + test_op_ = getRequestHeader("x-test-operation")->toString(); + return FilterHeadersStatus::Continue; +} + +FilterHeadersStatus ExampleContext::onResponseHeaders(uint32_t) { + test_op_ = getResponseHeader("x-test-operation")->toString(); + return FilterHeadersStatus::Continue; +} + +void ExampleContext::logBody(BufferType bt) { + size_t bufferedSize; + uint32_t flags; + getBufferStatus(bt, &bufferedSize, &flags); + auto body = getBufferBytes(bt, 0, bufferedSize); + logError(std::string("onRequestBody ") + std::string(body->view())); +} + +FilterDataStatus ExampleContext::onBody(BufferType bt, size_t bufLen, bool end) { + if (test_op_ == "ReadBody") { + auto body = getBufferBytes(bt, 0, bufLen); + logError("onRequestBody " + std::string(body->view())); + + } else if (test_op_ == "BufferBody") { + logBody(bt); + return end ? FilterDataStatus::Continue : FilterDataStatus::StopIterationAndBuffer; + + } else if (test_op_ == "BufferTwoBodies") { + logBody(bt); + num_chunks_++; + if (end || num_chunks_ > 2) { + return FilterDataStatus::Continue; + } + return FilterDataStatus::StopIterationAndBuffer; + + } else { + // This is a test and the test was configured incorrectly. + logError("Invalid test op " + test_op_); + abort(); + } + return FilterDataStatus::Continue; +} + +FilterDataStatus ExampleContext::onRequestBody(size_t body_buffer_length, bool end_of_stream) { + return onBody(BufferType::HttpRequestBody, body_buffer_length, end_of_stream); +} + +FilterDataStatus ExampleContext::onResponseBody(size_t body_buffer_length, bool end_of_stream) { + return onBody(BufferType::HttpResponseBody, body_buffer_length, end_of_stream); +} \ No newline at end of file diff --git a/test/extensions/filters/http/wasm/test_data/body_cpp.wasm b/test/extensions/filters/http/wasm/test_data/body_cpp.wasm new file mode 100644 index 000000000000..df0a34292e50 Binary files /dev/null and b/test/extensions/filters/http/wasm/test_data/body_cpp.wasm differ diff --git a/test/extensions/filters/http/wasm/wasm_filter_test.cc b/test/extensions/filters/http/wasm/wasm_filter_test.cc index 561378f257e2..26de3ede73bf 100644 --- a/test/extensions/filters/http/wasm/wasm_filter_test.cc +++ b/test/extensions/filters/http/wasm/wasm_filter_test.cc @@ -203,6 +203,122 @@ TEST_P(WasmHttpFilterTest, HeadersOnlyRequestHeadersAndBody) { filter_->onDestroy(); } +// Script that reads the body. +TEST_P(WasmHttpFilterTest, BodyRequestReadBody) { + setupConfig(TestEnvironment::readFileToStringForTest(TestEnvironment::substitute( + "{{ test_rundir }}/test/extensions/filters/http/wasm/test_data/body_cpp.wasm"))); + setupFilter(); + EXPECT_CALL(*filter_, + scriptLog_(spdlog::level::err, Eq(absl::string_view("onRequestBody hello")))); + Http::TestHeaderMapImpl request_headers{{":path", "/"}, {"x-test-operation", "ReadBody"}}; + EXPECT_EQ(Http::FilterHeadersStatus::Continue, filter_->decodeHeaders(request_headers, true)); + Buffer::OwnedImpl data("hello"); + EXPECT_EQ(Http::FilterDataStatus::Continue, filter_->decodeData(data, true)); + filter_->onDestroy(); +} + +// Script that buffers the body. +TEST_P(WasmHttpFilterTest, BodyRequestBufferBody) { + setupConfig(TestEnvironment::readFileToStringForTest(TestEnvironment::substitute( + "{{ test_rundir }}/test/extensions/filters/http/wasm/test_data/body_cpp.wasm"))); + setupFilter(); + + Http::TestHeaderMapImpl request_headers{{":path", "/"}, + {"x-test-operation", "BufferBody"}}; + EXPECT_EQ(Http::FilterHeadersStatus::Continue, filter_->decodeHeaders(request_headers, false)); + + Buffer::OwnedImpl bufferedBody; + EXPECT_CALL(decoder_callbacks_, decodingBuffer()).WillRepeatedly(Return(&bufferedBody)); + + Buffer::OwnedImpl data1("hello"); + bufferedBody.add(data1); + EXPECT_CALL(*filter_, + scriptLog_(spdlog::level::err, Eq(absl::string_view("onRequestBody hello")))) + .Times(1); + EXPECT_EQ(Http::FilterDataStatus::StopIterationAndBuffer, filter_->decodeData(data1, false)); + + Buffer::OwnedImpl data2(" again "); + bufferedBody.add(data2); + EXPECT_CALL(*filter_, + scriptLog_(spdlog::level::err, Eq(absl::string_view("onRequestBody hello again ")))) + .Times(1); + EXPECT_EQ(Http::FilterDataStatus::StopIterationAndBuffer, filter_->decodeData(data2, false)); + + EXPECT_CALL(*filter_, scriptLog_(spdlog::level::err, + Eq(absl::string_view("onRequestBody hello again hello")))) + .Times(1); + Buffer::OwnedImpl data3("hello"); + bufferedBody.add(data3); + EXPECT_EQ(Http::FilterDataStatus::Continue, filter_->decodeData(data3, true)); + + // Verify that the response still works even though we buffered the request. + Http::TestHeaderMapImpl response_headers{{":status", "200"}, + {"x-test-operation", "ReadBody"}}; + EXPECT_EQ(Http::FilterHeadersStatus::Continue, filter_->encodeHeaders(response_headers, false)); + // Should not buffer this time + EXPECT_CALL(*filter_, + scriptLog_(spdlog::level::err, Eq(absl::string_view("onRequestBody hello")))) + .Times(2); + EXPECT_EQ(Http::FilterDataStatus::Continue, filter_->encodeData(data1, false)); + EXPECT_EQ(Http::FilterDataStatus::Continue, filter_->encodeData(data1, true)); + + filter_->onDestroy(); +} + +// Script that buffers the first part of the body and streams the rest +TEST_P(WasmHttpFilterTest, BodyRequestBufferThenStreamBody) { + setupConfig(TestEnvironment::readFileToStringForTest(TestEnvironment::substitute( + "{{ test_rundir }}/test/extensions/filters/http/wasm/test_data/body_cpp.wasm"))); + setupFilter(); + + Http::TestHeaderMapImpl request_headers{{":path", "/"}}; + EXPECT_EQ(Http::FilterHeadersStatus::Continue, filter_->decodeHeaders(request_headers, true)); + + Buffer::OwnedImpl bufferedBody; + EXPECT_CALL(decoder_callbacks_, decodingBuffer()).WillRepeatedly(Return(&bufferedBody)); + + Http::TestHeaderMapImpl response_headers{{":status", "200"}, + {"x-test-operation", "BufferTwoBodies"}}; + EXPECT_EQ(Http::FilterHeadersStatus::Continue, filter_->encodeHeaders(response_headers, false)); + + Buffer::OwnedImpl data1("hello"); + EXPECT_CALL(*filter_, + scriptLog_(spdlog::level::err, Eq(absl::string_view("onRequestBody hello")))) + .Times(1); + EXPECT_EQ(Http::FilterDataStatus::StopIterationAndBuffer, filter_->decodeData(data1, false)); + bufferedBody.add(data1); + + Buffer::OwnedImpl data2(", there, "); + bufferedBody.add(data2); + EXPECT_CALL(*filter_, + scriptLog_(spdlog::level::err, Eq(absl::string_view("onRequestBody hello, there, ")))) + .Times(1); + EXPECT_EQ(Http::FilterDataStatus::StopIterationAndBuffer, filter_->decodeData(data2, false)); + + // Previous callbacks returned "Buffer" so we have buffered so far + Buffer::OwnedImpl data3("world!"); + bufferedBody.add(data3); + EXPECT_CALL(*filter_, scriptLog_(spdlog::level::err, + Eq(absl::string_view("onRequestBody hello, there, world!")))) + .Times(1); + EXPECT_EQ(Http::FilterDataStatus::Continue, filter_->decodeData(data3, false)); + + // Last callback returned "continue" so we just see individual chunks. + Buffer::OwnedImpl data4("So it's "); + EXPECT_CALL(*filter_, + scriptLog_(spdlog::level::err, Eq(absl::string_view("onRequestBody So it's ")))) + .Times(1); + EXPECT_EQ(Http::FilterDataStatus::Continue, filter_->decodeData(data4, false)); + + Buffer::OwnedImpl data5("goodbye, then!"); + EXPECT_CALL(*filter_, + scriptLog_(spdlog::level::err, Eq(absl::string_view("onRequestBody goodbye, then!")))) + .Times(1); + EXPECT_EQ(Http::FilterDataStatus::Continue, filter_->decodeData(data5, true)); + + filter_->onDestroy(); +} + // Script testing AccessLog::Instance::log. TEST_P(WasmHttpFilterTest, AccessLog) { setupConfig(TestEnvironment::readFileToStringForTest(TestEnvironment::substitute(